Unverified commit 36a4cad7, authored by Qiaolin Yu and committed by GitHub

Support overlap-spec-v2 with trtllm_mla attention backend (#11821)

parent 65d376b4
@@ -24,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils.common import cached_triton_kernel
 
 if is_flashinfer_available():
     import flashinfer
@@ -50,6 +51,7 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 TRTLLM_BLOCK_CONSTRAINT = 128
 
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def pad_draft_extend_query_kernel(
     q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]
@@ -123,6 +125,7 @@ def pad_draft_extend_query_kernel(
     )
 
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def unpad_draft_extend_output_kernel(
     raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
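Note on the new decorator (context, not part of this diff): `cached_triton_kernel` from `sglang.srt.utils.common` wraps a `@triton.jit` function and memoizes its compiled variants under a key computed from the launch kwargs, here just `BLOCK_SIZE`, so repeated launches of the two padding kernels avoid redundant per-call dispatch work. A minimal sketch of the idea, assuming a Triton-2.x-style API in which `jit_fn[grid](...)` returns the compiled kernel and the cached kernel can be relaunched with the same calling convention; the real helper's internals may well differ:

```python
def cached_triton_kernel(key_fn):
    """Sketch only: memoize compiled kernel variants by key_fn(args, kwargs)."""

    def decorator(jit_fn):
        cache = {}  # e.g. {BLOCK_SIZE: compiled kernel}

        class _CachedLauncher:
            def __getitem__(self, grid):
                def launch(*args, **kwargs):
                    key = key_fn(args, kwargs)
                    kernel = cache.get(key)
                    if kernel is None:
                        # Cold path: normal Triton dispatch compiles the
                        # kernel; remember what it hands back.
                        cache[key] = jit_fn[grid](*args, **kwargs)
                    else:
                        # Warm path: relaunch the cached kernel for this
                        # key, skipping recompilation and dispatch lookup.
                        kernel[grid](*args, **kwargs)

                return launch

        return _CachedLauncher()

    return decorator
```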
@@ -389,7 +392,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
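The recurring one-line change in this and the following hunks is the new `include_v2=True` flag: with overlap-spec-v2, draft extension runs under a separate forward mode, and every draft-extend branch in this backend now has to match it as well. A plausible shape for the predicate on `ForwardMode` (the member names besides `DRAFT_EXTEND` are assumptions; the real enum in `forward_batch_info` has more members):

```python
from enum import IntEnum, auto


class ForwardMode(IntEnum):
    # Only the members this diff touches; other names are illustrative.
    DECODE = auto()
    IDLE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()  # assumed name for the overlap-spec-v2 mode

    def is_draft_extend(self, include_v2: bool = False) -> bool:
        # Default keeps the legacy behavior; call sites that also handle
        # the v2 overlap path opt in explicitly, as this commit does.
        if include_v2:
            return self in (ForwardMode.DRAFT_EXTEND, ForwardMode.DRAFT_EXTEND_V2)
        return self is ForwardMode.DRAFT_EXTEND
```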
@@ -429,7 +432,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 block_kv_indices,
                 max_seq_len_val,
             )
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
            num_tokens_per_bs = num_tokens // bs
            metadata.max_seq_len_q = num_tokens_per_bs + 1
            metadata.sum_seq_lens_q = num_tokens_per_bs * bs
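At CUDA-graph capture time every request is padded to the same draft length, so the query-side metadata follows from the token count alone. A worked example with assumed capture sizes (the `+1`, presumably, leaves headroom for the bonus token kept at verification):

```python
# Assumed capture-time sizes, for illustration only.
bs, num_tokens = 16, 96

num_tokens_per_bs = num_tokens // bs     # 6 draft tokens per request
max_seq_len_q = num_tokens_per_bs + 1    # 7; +1 headroom per request
sum_seq_lens_q = num_tokens_per_bs * bs  # 96 total query tokens
```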
@@ -462,7 +465,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
@@ -481,7 +484,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
 
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
             accept_length = spec_info.accept_length[:bs]
             if spec_info.accept_length_cpu:
                 metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
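At replay time the batch is ragged: each request keeps only its accepted draft tokens, so the metadata is derived from `spec_info.accept_length` rather than from uniform padding. A hypothetical illustration of how per-request lengths would yield the `cu_seqlens_q` that the unpad path consumes later (not code from this diff):

```python
import torch

accept_length = torch.tensor([3, 1, 2])  # accepted tokens per request
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(accept_length, dim=0), (1, 0)
)  # -> tensor([0, 3, 4, 6]); prefix sums delimit each request's rows
max_seq_len_q = int(accept_length.max().item())  # -> 3
```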
@@ -523,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_extend()
             and not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             if self.disable_chunked_prefix_cache:
                 super().init_forward_metadata(forward_batch)
@@ -544,7 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         elif (
             forward_batch.forward_mode.is_decode_or_idle()
             or forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             bs = forward_batch.batch_size
@@ -573,7 +576,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                 block_kv_indices, max_seq_len_val
             )
-            if forward_batch.forward_mode.is_draft_extend():
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 max_seq = forward_batch.seq_lens_cpu.max().item()
                 sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
@@ -922,7 +925,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             metadata = (
                 getattr(forward_batch, "decode_trtllm_mla_metadata", None)
@@ -994,7 +997,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Reshape output directly without slicing
-        if forward_batch.forward_mode.is_draft_extend():
+        if forward_batch.forward_mode.is_draft_extend(include_v2=True):
             raw_out = self.unpad_draft_extend_output(
                 raw_out,
                 metadata.cu_seqlens_q,
...
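For context, `unpad_draft_extend_output` (backed by the `unpad_draft_extend_output_kernel` decorated above) compacts the padded attention output `(batch_size, token_per_batch, tp_q_head_num, v_head_dim)` back into a flat `(total_tokens, heads, head_dim)` layout. A torch-level sketch of the semantics, leaving out whatever trailing arguments the truncated call passes; the function name and signature below are illustrative:

```python
import torch


def unpad_draft_extend_output_reference(
    raw_out: torch.Tensor,       # (bs, token_per_batch, heads, head_dim), padded
    cu_seqlens_q: torch.Tensor,  # (bs + 1,) prefix sums of accepted lengths
) -> torch.Tensor:
    bs = raw_out.shape[0]
    total = int(cu_seqlens_q[-1].item())
    out = raw_out.new_empty(total, raw_out.shape[2], raw_out.shape[3])
    for i in range(bs):
        start, end = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
        # Keep only request i's valid rows; the padding rows are dropped.
        out[start:end] = raw_out[i, : end - start]
    return out
```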