Unverified Commit 799789af authored by Baizhou Zhang, committed by GitHub

Bump Flashinfer to 0.2.5 (#5870)


Co-authored-by: Yuhao Chen <yxckeis8@gmail.com>
parent cc4a80ca
@@ -96,8 +96,6 @@ jobs:
         uses: actions/checkout@v4
       - name: Install dependencies
-        env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
...
@@ -164,4 +164,4 @@ sky status --endpoint 30000 sglang
 - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
 - If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
 - The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
-- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.3" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
+- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.5" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
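
For reference, the fallback mentioned in the documentation hunk above (switching from FlashInfer to the Triton attention backend and the PyTorch sampling backend) amounts to passing two extra flags to the server launcher. A minimal sketch, assuming the standard `sglang.launch_server` entry point; the helper name and the model path are illustrative only:

```python
# Hypothetical helper: launch an SGLang server with the non-FlashInfer backends
# suggested in the note above. Only the two backend flags are specific to the
# fallback; everything else is a normal launch.
import subprocess


def launch_with_fallback_backends(model_path: str, port: int = 30000) -> subprocess.Popen:
    cmd = [
        "python", "-m", "sglang.launch_server",
        "--model-path", model_path,
        "--port", str(port),
        "--attention-backend", "triton",    # bypass the FlashInfer attention kernels
        "--sampling-backend", "pytorch",    # bypass the FlashInfer sampling kernels
    ]
    return subprocess.Popen(cmd)


if __name__ == "__main__":
    server = launch_with_fallback_backends("meta-llama/Llama-3.1-8B-Instruct")
    server.wait()
```
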
@@ -37,7 +37,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
-    "torchao>=0.7.0",
+    "torchao>=0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
@@ -47,7 +47,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.1.0",
-    "flashinfer_python==0.2.3",
+    "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
     "cuda-python",
...
@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.3",
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
...
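
The hunk above only raises the minimum `flashinfer_python` version that `assert_pkg_version` enforces at server start. As a rough illustration of what such a guard does (this is not sglang's actual `assert_pkg_version` implementation, and it assumes the `packaging` library is available):

```python
# Hedged sketch of a package-version guard similar in spirit to the
# assert_pkg_version call above; not sglang's actual helper.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def check_min_version(pkg: str, min_version: str, hint: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError as exc:
        raise RuntimeError(f"{pkg} is not installed. {hint}") from exc
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} found, but >= {min_version} is required. {hint}"
        )


check_min_version(
    "flashinfer_python",
    "0.2.5",
    "Please reinstall following https://docs.flashinfer.ai/installation.html.",
)
```
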
@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch

+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]

+        # Ensure tensors are properly allocated
+        for i in range(self.num_wrappers):
+            # Force allocation by performing a small operation
+            if len(self.cuda_graph_kv_indices[i]) > 0:
+                self.cuda_graph_kv_indices[i][0] = 0
+
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -414,7 +421,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )

             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@ class FlashInferAttnBackend(AttentionBackend):

             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )

+        # Call the wrapped function
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
     logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0

+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
@@ -1178,36 +1197,33 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies because we directly write to them during prepartion
-        # self._paged_kv_indptr_buf.copy_(indptr)
-        # self._paged_kv_indices_buf[: len(indices)] = indices
-        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )

-    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
-    if not q_data_type:
-        q_data_type = data_type
-    if not hasattr(self, "empty_q_data"):
-        self.empty_q_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, q_data_type)
-                if isinstance(q_data_type, str)
-                else q_data_type
-            ),
-        )
-        self.empty_kv_cache = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )

     indptr_host = (
         global_override_indptr_cpu
@@ -1215,11 +1231,16 @@ def fast_decode_plan(
         else indptr.cpu()
     )

+    with torch.cuda.device(self.device):
+
         if self.use_tensor_cores:
-            kv_lens_arr_host = get_seq_lens(
-                indptr_host, self.last_page_len[:batch_size], page_size
-            )
+            # ALSO convert last_page_len to CPU
+            last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
                 self._plan_info = self._cached_module.plan(
                     self._float_workspace_buffer,
                     self._int_workspace_buffer,
@@ -1236,9 +1257,12 @@ def fast_decode_plan(
                     head_dim,
                     head_dim,
                     False,  # causal
-                    torch.cuda.current_stream().cuda_stream,
                 )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
         else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
                 self._plan_info = self._cached_module.plan(
                     self._float_workspace_buffer,
                     self._int_workspace_buffer,
@@ -1253,10 +1277,11 @@ def fast_decode_plan(
                     logits_soft_cap,
                     head_dim,
                     head_dim,
-                    self.empty_q_data,
-                    self.empty_kv_cache,
-                    torch.cuda.current_stream().cuda_stream,
+                    empty_q_data,
+                    empty_kv_cache,
                 )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")

     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
...
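
The data-type handling added to `fast_decode_plan` above replaces the old single `data_type` argument with separate query and KV dtypes. Pulled out of the diff as a standalone function for illustration (the name `resolve_dtypes` is ours, not sglang's), the precedence it establishes looks like this:

```python
# Standalone illustration of the dtype precedence introduced in fast_decode_plan:
# an explicit legacy `data_type` fills in whichever of q/kv dtypes is missing,
# the query dtype otherwise defaults to float16, and the KV dtype falls back
# to the query dtype.
from typing import Optional, Tuple, Union

import torch


def resolve_dtypes(
    data_type: Optional[Union[str, torch.dtype]] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = None,
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
) -> Tuple[Union[str, torch.dtype], Union[str, torch.dtype]]:
    if data_type is not None:
        # Legacy argument: propagate it to any dtype not given explicitly.
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"
    if kv_data_type is None:
        kv_data_type = q_data_type
    return q_data_type, kv_data_type


assert resolve_dtypes() == ("float16", "float16")
assert resolve_dtypes(data_type="bfloat16") == ("bfloat16", "bfloat16")
assert resolve_dtypes(q_data_type=torch.float16, kv_data_type="float8_e4m3fn") == (
    torch.float16,
    "float8_e4m3fn",
)
```
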
@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """

+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton

+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k,
                 v,
             )
+
+        # Reshape inputs
         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             reshaped_q[:, :, : layer.v_head_dim],
             reshaped_q[:, :, layer.v_head_dim :],
-            reshaped_k[:, :, : layer.v_head_dim],
-            reshaped_k[:, :, layer.v_head_dim :],
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
         )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -825,8 +834,9 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale

     with self.device as device:
-        stream = torch.cuda.current_stream(device).cuda_stream
-        self._cached_module.plan(
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -836,5 +846,6 @@ def fast_mla_decode_plan(
                 num_heads,
                 head_dim_ckv,
                 causal,
-                stream,
             )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")