Unverified Commit 799789af authored by Baizhou Zhang, committed by GitHub

Bump Flashinfer to 0.2.5 (#5870)


Co-authored-by: Yuhao Chen <yxckeis8@gmail.com>
parent cc4a80ca
......@@ -96,8 +96,6 @@ jobs:
uses: actions/checkout@v4
- name: Install dependencies
env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: |
bash scripts/ci_install_dependency.sh
......
......@@ -164,4 +164,4 @@ sky status --endpoint 30000 sglang
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
- If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
- The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.3" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.5" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
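The following is a minimal sketch, not part of this change, of how one might confirm that the locally installed `flashinfer-python` matches the 0.2.5 pin before starting a server; it uses only the standard-library `importlib.metadata`:

```python
# Illustrative check only (not part of this commit): confirm the installed
# flashinfer-python distribution matches the version pinned by this change.
from importlib.metadata import PackageNotFoundError, version

REQUIRED = "0.2.5"

try:
    installed = version("flashinfer-python")
except PackageNotFoundError:
    raise SystemExit("flashinfer-python is not installed; see the pip command above.")

if installed != REQUIRED:
    raise SystemExit(
        f"Found flashinfer-python {installed}, expected {REQUIRED}; "
        "reinstall with the command above and clear ~/.cache/flashinfer."
    )
print(f"flashinfer-python {installed} OK")
```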
......@@ -37,7 +37,7 @@ runtime_common = [
"python-multipart",
"pyzmq>=25.1.2",
"soundfile==0.13.1",
"torchao>=0.7.0",
"torchao>=0.9.0",
"transformers==4.51.1",
"uvicorn",
"uvloop",
......@@ -47,7 +47,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.1.0",
"flashinfer_python==0.2.3",
"flashinfer_python==0.2.5",
"torch==2.6.0",
"torchvision==0.21.0",
"cuda-python",
......
......@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if server_args.attention_backend == "flashinfer":
assert_pkg_version(
"flashinfer_python",
"0.2.3",
"0.2.5",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
......
......@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
import torch._dynamo
torch._dynamo.config.suppress_errors = True
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
......@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
assert not (
model_runner.sliding_window_size is not None
......@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]
# Ensure tensors are properly allocated
for i in range(self.num_wrappers):
# Force allocation by performing a small operation
if len(self.cuda_graph_kv_indices[i]) > 0:
self.cuda_graph_kv_indices[i][0] = 0
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
......@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
......@@ -414,7 +421,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, k_scale, v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = prefill_wrapper_paged.forward(
......@@ -424,8 +431,8 @@ class FlashInferAttnBackend(AttentionBackend):
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
......@@ -452,7 +459,7 @@ class FlashInferAttnBackend(AttentionBackend):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, k_scale, v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
......@@ -466,8 +473,6 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
......@@ -481,16 +486,17 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, k_scale, v_scale
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
# Run paged decode attention through the FlashInfer decode wrapper
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
......@@ -1146,8 +1152,9 @@ def fast_decode_plan(
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Union[str, torch.dtype] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = None,
kv_data_type: Optional[Union[str, torch.dtype]] = None,
data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
......@@ -1163,6 +1170,18 @@ def fast_decode_plan(
if logits_soft_cap is None:
logits_soft_cap = 0.0
# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"
if kv_data_type is None:
kv_data_type = q_data_type
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
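To summarize the dtype handling added above: an explicit `data_type` back-fills whichever of `q_data_type` / `kv_data_type` is missing, otherwise `q_data_type` defaults to `"float16"` and `kv_data_type` falls back to `q_data_type`. The sketch below restates that precedence as a standalone helper; `resolve_plan_dtypes` is a hypothetical name used only for illustration:

```python
# Standalone restatement of the precedence in fast_decode_plan above.
# `resolve_plan_dtypes` is a hypothetical helper, not part of the patch.
from typing import Optional, Tuple, Union

import torch

DTypeLike = Union[str, torch.dtype]


def resolve_plan_dtypes(
    data_type: Optional[DTypeLike],
    q_data_type: Optional[DTypeLike],
    kv_data_type: Optional[DTypeLike],
) -> Tuple[DTypeLike, DTypeLike]:
    if data_type is not None:
        # The legacy argument back-fills both specific dtypes.
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"
    if kv_data_type is None:
        kv_data_type = q_data_type
    return q_data_type, kv_data_type


assert resolve_plan_dtypes("bfloat16", None, None) == ("bfloat16", "bfloat16")
assert resolve_plan_dtypes(None, None, None) == ("float16", "float16")
assert resolve_plan_dtypes(None, torch.float16, "float8_e4m3fn") == (
    torch.float16,
    "float8_e4m3fn",
)
```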
......@@ -1178,36 +1197,33 @@ def fast_decode_plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies because we directly write to them during preparation
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
# NOTE(Zihao): the following tensors act as placeholders to pass dtype info
if not q_data_type:
q_data_type = data_type
if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)
self.empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, data_type) if isinstance(data_type, str) else data_type
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
if self.use_tensor_cores:
self._qo_indptr_buf = qo_indptr_host.to(
self.device, non_blocking=non_blocking
)
# Create empty tensors for dtype info if needed
empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
),
device=self.device,
)
empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, kv_data_type)
if isinstance(kv_data_type, str)
else kv_data_type
),
device=self.device,
)
indptr_host = (
global_override_indptr_cpu
......@@ -1215,48 +1231,57 @@ def fast_decode_plan(
else indptr.cpu()
)
if self.use_tensor_cores:
kv_lens_arr_host = get_seq_lens(
indptr_host, self.last_page_len[:batch_size], page_size
)
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
torch.cuda.current_stream().cuda_stream,
)
else:
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
self.empty_q_data,
self.empty_kv_cache,
torch.cuda.current_stream().cuda_stream,
)
with torch.cuda.device(self.device):
if self.use_tensor_cores:
# Also convert last_page_len to CPU so kv sequence lengths can be computed host-side
last_page_len_host = last_page_len.cpu()
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
try:
# Make sure we pass exactly 15 arguments for tensor core version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}")
else:
try:
# Make sure we pass exactly 15 arguments for standard version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
empty_q_data,
empty_kv_cache,
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}")
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
......
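The `empty_q_data` and `empty_kv_cache` tensors built above carry nothing but dtype information into the cached `plan` call; the string-or-`torch.dtype` handling they use can be read in isolation as the sketch below (the helper name is hypothetical, not from the patch):

```python
# Illustration of the dtype-placeholder pattern used above; `to_torch_dtype`
# is a hypothetical helper name, not part of the commit.
import torch


def to_torch_dtype(dtype_like):
    # A string such as "float16" is looked up on the torch module;
    # an actual torch.dtype is passed through unchanged.
    return getattr(torch, dtype_like) if isinstance(dtype_like, str) else dtype_like


assert to_torch_dtype("float16") is torch.float16
assert to_torch_dtype(torch.bfloat16) is torch.bfloat16

# Zero-element tensors like these only communicate dtype to the planner.
empty_kv_cache = torch.empty(0, dtype=to_torch_dtype("float16"))
```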
......@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
import os
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union
......@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import triton
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
import torch._dynamo
torch._dynamo.config.suppress_errors = True
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
......@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
k,
v,
)
# Reshape inputs
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
# Call decode_wrapper.run() directly with the split nope/rope tensors
o = decode_wrapper.run(
reshaped_q[:, :, : layer.v_head_dim],
reshaped_q[:, :, layer.v_head_dim :],
reshaped_k[:, :, : layer.v_head_dim],
reshaped_k[:, :, layer.v_head_dim :],
k_buffer[:, :, : layer.v_head_dim],
k_buffer[:, :, layer.v_head_dim :],
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
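The MLA decode path above splits the last dimension of both the query and the cached key buffer at `layer.v_head_dim`, separating the compressed ("nope") portion from the rotary (RoPE) portion before handing them to the paged MLA wrapper. A tiny sketch of that split, with hypothetical sizes:

```python
# Illustrative only: the split at v_head_dim used by the MLA decode call above.
# The sizes below are hypothetical placeholders, not taken from the commit.
import torch

tp_q_head_num, v_head_dim, rope_dim = 16, 512, 64
head_dim = v_head_dim + rope_dim  # last dimension of q and of the key buffer

q = torch.randn(8, tp_q_head_num, head_dim)
q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]

assert q_nope.shape == (8, tp_q_head_num, v_head_dim)
assert q_rope.shape == (8, tp_q_head_num, rope_dim)
```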
......@@ -825,16 +834,18 @@ def fast_mla_decode_plan(
self._sm_scale = sm_scale
with self.device as device:
stream = torch.cuda.current_stream(device).cuda_stream
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
stream,
)
try:
# Standard version with just the required arguments (no use_profiler)
self._cached_module.plan.default(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
)
except Exception as e:
raise RuntimeError(f"Error in alternate MLA plan: {e}")