Unverified Commit 36a4cad7 authored by Qiaolin Yu, committed by GitHub

Support overlap-spec-v2 with trtllm_mla attention backend (#11821)

parent 65d376b4
@@ -24,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils.common import cached_triton_kernel
 
 if is_flashinfer_available():
     import flashinfer
@@ -50,6 +51,7 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 TRTLLM_BLOCK_CONSTRAINT = 128
 
 
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def pad_draft_extend_query_kernel(
     q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]
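For readability, here is a plain-PyTorch reference of what `pad_draft_extend_query_kernel` computes, inferred from the shape comment in its signature; the Triton version parallelizes this copy over a `BLOCK_SIZE`-tiled grid. Argument names other than `q_ptr`'s documented shape are illustrative.

```python
import torch

def pad_draft_extend_query_ref(
    q: torch.Tensor,         # [total_seq_len, num_heads, head_dim], ragged over requests
    seq_lens: torch.Tensor,  # [batch_size] query tokens per request (assumed argument)
    max_seq_len: int,        # padded per-request length
) -> torch.Tensor:
    # Scatter each request's tokens into a zero-padded dense batch so the
    # TRT-LLM MLA kernel can consume a fixed [batch, max_seq_len, ...] layout.
    bs = seq_lens.numel()
    padded = q.new_zeros(bs, max_seq_len, q.shape[1], q.shape[2])
    offset = 0
    for i in range(bs):
        n = int(seq_lens[i])
        padded[i, :n] = q[offset : offset + n]
        offset += n
    return padded
```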
@@ -123,6 +125,7 @@ def pad_draft_extend_query_kernel(
     )
 
 
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def unpad_draft_extend_output_kernel(
     raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
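`unpad_draft_extend_output_kernel` is the inverse operation. A reference sketch of its semantics, again inferred from the signature's shape comment (the `accept_lens` argument name is an assumption):

```python
import torch

def unpad_draft_extend_output_ref(
    raw_out: torch.Tensor,      # [batch_size, token_per_batch, tp_q_head_num, v_head_dim]
    accept_lens: torch.Tensor,  # [batch_size] accepted tokens per request (assumed name)
) -> torch.Tensor:
    # Gather only the first accept_lens[i] rows of each request, restoring
    # the ragged [total_tokens, tp_q_head_num, v_head_dim] layout.
    total = int(accept_lens.sum())
    out = raw_out.new_empty(total, raw_out.shape[2], raw_out.shape[3])
    offset = 0
    for i in range(accept_lens.numel()):
        n = int(accept_lens[i])
        out[offset : offset + n] = raw_out[i, :n]
        offset += n
    return out
```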
@@ -389,7 +392,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
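The recurring change in this commit is the new `include_v2=True` argument, which widens each mode check to also match the overlap-spec-v2 draft-extend mode. A minimal sketch of how such a flag can be exposed on the mode enum; member names other than `DRAFT_EXTEND` are hypothetical, and the real enum in `sglang.srt.model_executor.forward_batch_info` has more modes:

```python
from enum import Enum, auto

class ForwardMode(Enum):
    # Trimmed to the members this sketch needs.
    DECODE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()  # hypothetical name for the overlap-spec-v2 mode

    def is_draft_extend(self, include_v2: bool = False) -> bool:
        # Default False keeps every existing call site unchanged; callers
        # that can handle the v2 path opt in explicitly, as this backend does.
        if include_v2 and self is ForwardMode.DRAFT_EXTEND_V2:
            return True
        return self is ForwardMode.DRAFT_EXTEND
```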
@@ -429,7 +432,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             max_seq_len_val,
         )
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs + 1
             metadata.sum_seq_lens_q = num_tokens_per_bs * bs
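At CUDA-graph capture time every request is padded to the same number of draft tokens, so the metadata follows from `num_tokens // bs` alone. A worked example with illustrative numbers:

```python
bs, num_tokens = 4, 12                   # graph captured for 4 requests, 12 query tokens
num_tokens_per_bs = num_tokens // bs     # 3 draft tokens per request
max_seq_len_q = num_tokens_per_bs + 1    # 4: one above the uniform length, per the diff
sum_seq_lens_q = num_tokens_per_bs * bs  # 12 total query tokens across the batch
```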
@@ -462,7 +465,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
@@ -481,7 +484,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
 
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
             accept_length = spec_info.accept_length[:bs]
             if spec_info.accept_length_cpu:
                 metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
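At replay, the uniform capture-time values are overwritten with the real per-request accept lengths. A condensed sketch of that update; the `sum_seq_lens_q` line is an assumption mirroring the capture path, since the hunk is truncated:

```python
def update_replay_metadata(metadata, spec_info, bs: int):
    # Overwrite the capture-time upper bounds with real accept lengths.
    accept_length = spec_info.accept_length[:bs]  # device tensor, per-request
    if spec_info.accept_length_cpu:               # host-side mirror, when present
        metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
        # Assumed analogue of the capture path (not shown in the diff):
        metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
    return accept_length
```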
@@ -523,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_extend()
             and not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             if self.disable_chunked_prefix_cache:
                 super().init_forward_metadata(forward_batch)
@@ -544,7 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         elif (
             forward_batch.forward_mode.is_decode_or_idle()
             or forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             bs = forward_batch.batch_size
@@ -573,7 +576,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                 block_kv_indices, max_seq_len_val
             )
-            if forward_batch.forward_mode.is_draft_extend():
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 max_seq = forward_batch.seq_lens_cpu.max().item()
                 sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
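Outside CUDA graphs, the per-request extend lengths are known on the host, so the query-side metadata (including the `cu_seqlens_q` prefix sums consumed later by `unpad_draft_extend_output`) can be built directly. A sketch assuming only the field names that appear in this diff; the helper itself is hypothetical:

```python
import torch

def build_draft_extend_metadata(extend_seq_lens_cpu: list[int], device: torch.device):
    # Prefix sums give each request's start offset in the ragged query layout.
    seq_lens = torch.tensor(extend_seq_lens_cpu, dtype=torch.int32)
    cu_seqlens_q = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(seq_lens, dim=0)
    return {
        "max_seq_len_q": int(seq_lens.max().item()),
        "sum_seq_lens_q": int(seq_lens.sum().item()),
        "cu_seqlens_q": cu_seqlens_q.to(device),
    }
```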
@@ -922,7 +925,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             metadata = (
                 getattr(forward_batch, "decode_trtllm_mla_metadata", None)
@@ -994,7 +997,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Reshape output directly without slicing
-        if forward_batch.forward_mode.is_draft_extend():
+        if forward_batch.forward_mode.is_draft_extend(include_v2=True):
             raw_out = self.unpad_draft_extend_output(
                 raw_out,
                 metadata.cu_seqlens_q,
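Finally, a self-contained round-trip check tying the two kernels' semantics together: padding the ragged queries and then unpadding with the same lengths must reproduce the input exactly (pure PyTorch, illustrative shapes):

```python
import torch

accept_lens = torch.tensor([2, 1, 3])  # accepted draft tokens per request
heads, head_dim = 4, 8
q = torch.randn(int(accept_lens.sum()), heads, head_dim)  # ragged input
max_len = int(accept_lens.max())

# pad: ragged [total, H, D] -> dense [bs, max_len, H, D]
padded = q.new_zeros(accept_lens.numel(), max_len, heads, head_dim)
off = 0
for i, n in enumerate(accept_lens.tolist()):
    padded[i, :n] = q[off : off + n]
    off += n

# unpad: dense -> ragged, keeping only the accepted rows
out = torch.cat([padded[i, :n] for i, n in enumerate(accept_lens.tolist())])
assert torch.equal(out, q)  # unpad(pad(q)) == q
```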