"vscode:/vscode.git/clone" did not exist on "f554bc4873b82e5b66605311a518673137b9071c"
Unverified Commit 36acd2ff authored by Shu Wang's avatar Shu Wang Committed by GitHub
Browse files

Fix chunked prefix cache for nvfp4 (#10180)


Co-authored-by: default avatarElfie Guo <elfieg@nvidia.com>
parent fe6cdf89
...@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import ( ...@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import (
create_flashmla_kv_indices_triton, create_flashmla_kv_indices_triton,
) )
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import is_flashinfer_available
...@@ -72,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ...@@ -72,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
kv_indptr_buf: Optional[torch.Tensor] = None, kv_indptr_buf: Optional[torch.Tensor] = None,
q_indptr_decode_buf: Optional[torch.Tensor] = None, q_indptr_decode_buf: Optional[torch.Tensor] = None,
): ):
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf) super().__init__(
model_runner,
skip_prefill,
kv_indptr_buf,
q_indptr_decode_buf,
)
config = model_runner.model_config config = model_runner.model_config
...@@ -112,6 +118,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ...@@ -112,6 +118,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
def _calc_padded_blocks(self, max_seq_len: int) -> int: def _calc_padded_blocks(self, max_seq_len: int) -> int:
""" """
Calculate padded block count that satisfies both TRT-LLM and Triton constraints. Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
...@@ -301,6 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ...@@ -301,6 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend() and not forward_batch.forward_mode.is_draft_extend()
): ):
if self.disable_chunked_prefix_cache:
super().init_forward_metadata(forward_batch)
seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
cum_seq_lens_q = torch.cat( cum_seq_lens_q = torch.cat(
( (
...@@ -540,6 +553,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ...@@ -540,6 +553,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
return super().forward_extend( return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
) )
# chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
if forward_batch.attn_attend_prefix_cache is None:
return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
if not forward_batch.attn_attend_prefix_cache: if not forward_batch.attn_attend_prefix_cache:
q = q.view(-1, layer.tp_q_head_num, layer.head_dim) q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
......
...@@ -560,18 +560,19 @@ class ModelRunner: ...@@ -560,18 +560,19 @@ class ModelRunner:
if not self.use_mla_backend: if not self.use_mla_backend:
server_args.disable_chunked_prefix_cache = True server_args.disable_chunked_prefix_cache = True
# TODO(kaixih@nvidia): remove this once we have a better solution for DP attention. # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
# For more details, see: https://github.com/sgl-project/sglang/issues/8616 # For more details, see: https://github.com/sgl-project/sglang/issues/8616
elif ( elif (
self.dp_size > 1 self.dp_size > 1
and is_sm100_supported() and is_sm100_supported()
and server_args.attention_backend != "triton" and server_args.attention_backend != "triton"
and server_args.attention_backend == "trtllm_mla"
): ):
logger.info( logger.info(
"Disable chunked prefix cache when dp size > 1 and attention backend is not triton." "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
) )
server_args.disable_chunked_prefix_cache = True server_args.disable_chunked_prefix_cache = True
if not server_args.disable_chunked_prefix_cache: if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.") logger.info("Chunked prefix cache is turned on.")
......
...@@ -1087,6 +1087,8 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1087,6 +1087,8 @@ class DeepseekV2AttentionMLA(nn.Module):
disable_ragged = ( disable_ragged = (
attention_backend == "flashinfer" or attention_backend == "flashmla" attention_backend == "flashinfer" or attention_backend == "flashmla"
) and self.flashinfer_mla_disable_ragged ) and self.flashinfer_mla_disable_ragged
original_mode = getattr(forward_batch, "_original_forward_mode", None)
if ( if (
not disable_ragged not disable_ragged
and forward_batch.forward_mode.is_extend() and forward_batch.forward_mode.is_extend()
...@@ -1099,15 +1101,40 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1099,15 +1101,40 @@ class DeepseekV2AttentionMLA(nn.Module):
) )
or sum_extend_prefix_lens == 0 or sum_extend_prefix_lens == 0
) )
# TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
# dp case. Redirect to mla kernel as a workaround.
# Tracked by https://github.com/sgl-project/sglang/issues/9806.
and not (
original_mode is not None
and original_mode.is_decode()
and is_sm100_supported()
and self.current_attention_backend in ("cutlass_mla", "flashinfer")
)
): ):
return AttnForwardMethod.MHA_CHUNKED_KV return AttnForwardMethod.MHA_CHUNKED_KV
else: else:
return _dispatch_mla_subtype() return _dispatch_mla_subtype()
elif attention_backend == "trtllm_mla": elif attention_backend == "trtllm_mla":
original_mode = getattr(forward_batch, "_original_forward_mode", None)
if (
original_mode is not None
and original_mode.is_decode()
and is_sm100_supported()
):
return _dispatch_mla_subtype()
sum_extend_prefix_lens = (
sum(forward_batch.extend_prefix_lens_cpu)
if forward_batch.extend_prefix_lens_cpu is not None
else 0
)
if ( if (
forward_batch.forward_mode.is_extend() forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend() and not forward_batch.forward_mode.is_draft_extend()
and (
not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
)
): ):
return AttnForwardMethod.MHA_CHUNKED_KV return AttnForwardMethod.MHA_CHUNKED_KV
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment