Unverified commit 70f894b8, authored by Yineng Zhang and committed by GitHub

feat: support flashinfer mla attention for deepseek v3 (#3550)

parent 368de366
@@ -72,7 +72,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
@@ -98,7 +98,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
@@ -123,7 +123,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
@@ -163,7 +163,7 @@ jobs:
      - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
@@ -209,7 +209,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
@@ -243,7 +243,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
@@ -283,7 +283,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
           git clone https://github.com/merrymercy/human-eval.git
@@ -308,7 +308,7 @@ jobs:
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
           git clone https://github.com/merrymercy/human-eval.git
......
@@ -21,12 +21,13 @@ runtime_common = [
     "hf_transfer", "huggingface_hub", "interegular", "modelscope",
     "orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
-    "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10"
+    "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10", "ninja"
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
     "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
-    "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
+    "flashinfer_python>=0.2.1.post1",
+    "outlines>=0.0.44,<=0.1.11",
 ]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
......
@@ -38,5 +38,7 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True

+        self.enable_flashinfer_mla = False
+

 global_config = GlobalConfig()
@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.0.post2",
+            "0.2.1.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
......
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):

 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]


 @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
                 self.prefill_wrappers_verify.append(
                     BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                 )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
-                )
-            )
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
+                )

         # Create indices updater
         if not skip_prefill:
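The hunk above is the heart of the change: when the new flag is set for a DeepseekV3ForCausalLM model, decode goes through FlashInfer's MLA wrapper over a compressed latent KV cache instead of the standard paged decode wrapper. A minimal sketch of that choice follows; the 128 MB workspace size is an arbitrary illustrative value, and it assumes flashinfer_python>=0.2.1.post1 on a CUDA device.

```python
# Sketch of the decode-wrapper selection added in this diff.
# Assumes flashinfer_python>=0.2.1.post1 and a CUDA-capable GPU.
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.mla import BatchMLAPagedAttentionWrapper

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

def make_decode_wrapper(enable_flashinfer_mla: bool, use_tensor_cores: bool = True):
    if enable_flashinfer_mla:
        # MLA decode attends over a single compressed latent KV "head";
        # "fa2" is the FlashInfer backend this PR selects.
        return BatchMLAPagedAttentionWrapper(workspace, backend="fa2")
    # Regular MHA/GQA decode over a paged KV cache.
    return BatchDecodeWithPagedKVCacheWrapper(
        workspace, "NHD", use_tensor_cores=use_tensor_cores
    )
```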
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
-                    )
-                )
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
+                    )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-            self._get_wrapper_idx(layer)
-        ]
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
-
-        logits_soft_cap = layer.logit_cap
-
-        if not self.forward_metadata.use_ragged:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
-                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
-                logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
-
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=logits_soft_cap,
+            )
+
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=not layer.is_cross_attention,
+                    sm_scale=layer.scaling,
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )

+                    o, _ = merge_state(o1, s1, o2, s2)
+
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
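Before the decode hunk below, one detail in the MLA prefill branch above is worth calling out: DeepSeek V3 queries and keys carry 192 dims per head (128 non-positional plus 64 RoPE) while values carry 128, so the ragged prefill runs with mismatched QK and VO head dims and the output is reshaped with `layer.v_head_dim` rather than `layer.head_dim`. A toy shape check, with the head sizes stated as assumptions:

```python
# Toy shape check for the MLA ragged prefill path (assumed sizes:
# qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, 16 local heads).
import torch

num_tokens, tp_q_head_num = 4, 16
qk_head_dim = 128 + 64   # "nope" + RoPE parts of each query/key head -> 192
v_head_dim = 128         # value head dim differs from the QK head dim

q = torch.randn(num_tokens, tp_q_head_num * qk_head_dim)
k = torch.randn(num_tokens, tp_q_head_num * qk_head_dim)
v = torch.randn(num_tokens, tp_q_head_num * v_head_dim)

# The backend views the flat activations exactly like the diff does:
q = q.view(-1, tp_q_head_num, qk_head_dim)  # (4, 16, 192)
k = k.view(-1, tp_q_head_num, qk_head_dim)  # (4, 16, 192)
v = v.view(-1, tp_q_head_num, v_head_dim)   # (4, 16, 128)
print(q.shape, k.shape, v.shape)
```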
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
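The slicing in the MLA decode call above implements weight-absorbed attention: at decode time `layer.head_dim` is the compressed latent width plus the RoPE width (512 + 64 = 576 for DeepSeek V3) and `layer.v_head_dim` is 512, so `[..., :v_head_dim]` selects the latent ("nope") part and `[..., v_head_dim:]` the RoPE part of both the query and the single-head latent KV buffer. A small sketch of that split, treating the sizes as assumptions taken from the DeepSeek V3 config:

```python
# Sketch of the decode-time query/KV split consumed by the MLA wrapper
# (assumed sizes: kv_lora_rank=512, qk_rope_head_dim=64 -> head_dim=576).
import torch

num_tokens, tp_q_head_num = 2, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim   # 576
v_head_dim = kv_lora_rank                    # 512

q = torch.randn(num_tokens, tp_q_head_num, head_dim)
k_buffer = torch.randn(1024, 1, head_dim)    # one latent "head" per cache slot

q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]              # 512 / 64
ckv, k_rope = k_buffer[:, :, :v_head_dim], k_buffer[:, :, v_head_dim:]   # 512 / 64
# decode_wrapper.run(...) above consumes exactly these four views.
print(q_nope.shape, q_rope.shape, ckv.shape, k_rope.shape)
```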
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )


 class FlashInferIndicesUpdaterPrefill:
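Two constants in the new `plan` call above deserve a gloss: 512 and 64 are the compressed-KV and RoPE head widths, and the softmax scale comes from the original 192-dim query/key head (128 + 64), not from the 576-dim absorbed layout. An annotated restatement follows; the keyword names are an assumption about FlashInfer's MLA planner, since the diff passes everything positionally.

```python
# Annotated restatement of the positional arguments in wrapper.plan(...)
# (keyword names are assumed, not taken from the diff).
import math

qk_nope_head_dim, qk_rope_head_dim = 128, 64
sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)  # = 1/sqrt(192)

assumed_plan_kwargs = dict(
    head_dim_ckv=512,   # kv_lora_rank: width of the compressed latent KV
    head_dim_kpe=64,    # RoPE slice stored alongside the latent KV
    page_size=1,        # token-granular paging, same as the dense decode path
    causal=False,       # single-token decode needs no causal mask
    sm_scale=sm_scale,
)
print(assumed_plan_kwargs)
```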
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
         # extend part
         if use_ragged:
-            wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                q_data_type=self.q_data_type,
-            )
-
-        # cached part
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-            non_blocking=True,
-        )
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
+            )
 class FlashInferMultiStepDraftBackend:
     """
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
             window_left,
             logits_soft_cap,
             head_dim,
+            head_dim,
             empty_q_data,
             empty_kv_cache,
             stream.cuda_stream,
......
@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
 }

 logger = logging.getLogger(__name__)
......
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )

 logger = logging.getLogger(__name__)
@@ -110,8 +111,14 @@ class ModelRunner:
         ):
             # TODO: add MLA optimization on CPU
             if self.server_args.device != "cpu":
-                logger.info("MLA optimization is turned on. Use triton backend.")
-                self.server_args.attention_backend = "triton"
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    )
+                    self.server_args.attention_backend = "flashinfer"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    self.server_args.attention_backend = "triton"

         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -169,6 +176,7 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
             }
         )
@@ -292,6 +300,8 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")

+        set_cuda_arch()
+
         # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
......
@@ -510,14 +510,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        # Use normal computation for prefill and use weight absorption for extend/decode
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.extend_prefix_lens.sum() == 0
-        ):
-            return self.forward_normal(positions, hidden_states, forward_batch)
+        if global_server_args_dict["enable_flashinfer_mla"]:
+            if forward_batch.forward_mode.is_extend():
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            if (
+                forward_batch.forward_mode.is_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            ):
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)

     def forward_normal(
         self,
......
@@ -168,6 +168,8 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False

+    enable_flashinfer_mla: bool = False
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -693,6 +695,11 @@ class ServerArgs:
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mla",
+            action="store_true",
+            help="Enable FlashInfer MLA optimization",
+        )

         # Speculative decoding
         parser.add_argument(
......
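With the dataclass field and argparse flag above in place, enabling the feature on the CLI is just `--enable-flashinfer-mla`. A hedged offline usage sketch is below; it assumes sglang's `Engine` forwards keyword arguments to `ServerArgs` the way it does for other flags, and that a DeepseekV3ForCausalLM checkpoint plus enough GPUs are available.

```python
# Usage sketch only; the model path, tp_size, and the Engine kwargs-forwarding
# behavior are assumptions, not part of this diff.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # any DeepseekV3ForCausalLM checkpoint
    tp_size=8,
    trust_remote_code=True,
    enable_flashinfer_mla=True,            # new flag; defaults to False
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
```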
@@ -1444,3 +1444,10 @@ def launch_dummy_health_check_server(host, port):
         timeout_keep_alive=5,
         loop="uvloop",
     )
+
+
+def set_cuda_arch():
+    if is_flashinfer_available():
+        capability = torch.cuda.get_device_capability()
+        arch = f"{capability[0]}.{capability[1]}"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
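The new `set_cuda_arch` helper pins `TORCH_CUDA_ARCH_LIST` to the local device, which keeps FlashInfer's JIT kernel builds (note the new `ninja` dependency and the CI cache purge below) from compiling for every architecture. An illustration of the values it produces, with the capability tuples assumed for the named GPUs:

```python
# Illustration of set_cuda_arch()'s effect (capability tuples are assumptions
# for the listed GPUs; the +PTX suffix is added only for sm90, as in the diff).
def arch_list_for(capability: tuple) -> str:
    arch = f"{capability[0]}.{capability[1]}"
    return f"{arch}{'+PTX' if arch == '9.0' else ''}"

print(arch_list_for((9, 0)))  # H100 / H200     -> "9.0+PTX"
print(arch_list_for((8, 0)))  # A100            -> "8.0"
print(arch_list_for((8, 9)))  # L40S / RTX 4090 -> "8.9"
```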
@@ -4,17 +4,19 @@ set -euxo pipefail
 # Install the dependency in CI.
 # Use repo from environment variable, passed from GitHub Actions
-FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}"
+FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"

 SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
 bash "${SCRIPT_DIR}/killall_sglang.sh"

 pip install --upgrade pip
 pip uninstall flashinfer -y
-pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
+pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+
+rm -rf /root/.cache/flashinfer

 # Force reinstall flashinfer and torch_memory_saver
-pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+pip install flashinfer_python==0.2.1.post1 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
 pip install torch_memory_saver --force-reinstall
 pip install transformers==4.45.2 sentence_transformers accelerate peft
......
@@ -28,6 +28,7 @@ class TestEAGLEEngine(unittest.TestCase):
         "speculative_eagle_topk": 8,
         "speculative_num_draft_tokens": 64,
         "mem_fraction_static": 0.7,
+        "cuda_graph_max_bs": 32,
     }

     def setUp(self):
@@ -124,6 +125,8 @@ class TestEAGLEServer(unittest.TestCase):
                 "64",
                 "--mem-fraction-static",
                 "0.7",
+                "--cuda-graph-max-bs",
+                "32",
             ],
         )
......