Unverified commit efa47334, authored by Paiiii, committed by GitHub

[Spec Decoding] Support MTP for dsv3.2 (#11652)


Co-authored-by: Paiiiiiiiiiiiiii <zengpai@baidu.com>
parent d658f049
@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
     return (
         config.architectures is not None
         and config.architectures[0]
-        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+        in [
+            "DeepseekV3ForCausalLM",
+            "DeepseekV32ForCausalLM",
+            "DeepseekV3ForCausalLMNextN",
+        ]
         and getattr(config, "index_topk", None) is not None
     )
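Context for the hunk above: the MTP draft model ships under the `DeepseekV3ForCausalLMNextN` architecture name, so without this entry the draft model would never take the NSA code path. A self-contained sketch of the extended predicate (configs faked with `SimpleNamespace`; the `index_topk` value is an assumption, not taken from the diff):

```python
from types import SimpleNamespace

def is_deepseek_nsa(config) -> bool:
    # Same predicate as above, minus the PretrainedConfig type hint.
    return (
        config.architectures is not None
        and config.architectures[0]
        in [
            "DeepseekV3ForCausalLM",
            "DeepseekV32ForCausalLM",
            "DeepseekV3ForCausalLMNextN",
        ]
        and getattr(config, "index_topk", None) is not None
    )

# The MTP/NextN draft model now routes through the NSA path too.
draft_cfg = SimpleNamespace(
    architectures=["DeepseekV3ForCausalLMNextN"], index_topk=2048  # assumed value
)
assert is_deepseek_nsa(draft_cfg)
```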
@@ -266,6 +266,9 @@ class Indexer(CustomOp):
         )
         blocksize = page_size
-        seqlens_32 = metadata.get_seqlens_int32()
+        if forward_batch.forward_mode.is_target_verify():
+            seqlens_32 = metadata.get_seqlens_expanded()
+        else:
+            seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
         schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
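For target-verify batches the indexer needs one KV length per draft-token query row rather than one per request, which is what `get_seqlens_expanded()` supplies. A hedged sketch of what "expanded" plausibly means here (the real construction lives in the NSA metadata code; the causal-growth assumption is mine):

```python
import torch

def expand_seqlens(seq_lens: torch.Tensor, draft_token_num: int) -> torch.Tensor:
    # Assumption: seq_lens already include the draft tokens, and row i of a
    # request may attend to one more token than row i - 1 (causal mask over
    # the draft tokens).
    steps = torch.arange(1, draft_token_num + 1, dtype=torch.int32)
    base = (seq_lens.to(torch.int32) - draft_token_num).unsqueeze(1)
    return (base + steps).flatten()

# Two requests with total lengths 10 and 6, verifying 3 draft tokens each:
print(expand_seqlens(torch.tensor([10, 6]), 3))  # tensor([ 8,  9, 10,  4,  5,  6])
```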
@@ -317,8 +320,9 @@
         k_fp8_list = []
         k_scale_list = []
         ks_list = []
+        ke_list = []
         offset = 0
+        seq_lens_expanded = metadata.get_seqlens_expanded()
         block_tables = metadata.get_page_table_64()
         assert (
@@ -341,30 +345,34 @@
             )
             extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
             ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
             k_fp8_list.append(k_fp8)
             k_scale_list.append(k_scale)
             ks_list.append(ks)
+            ke_list.append(ke)
             offset += extend_seq_len
         k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
         k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
         kv_fp8 = (k_fp8, k_scale)
         ks = torch.cat(ks_list, dim=0)
-        seq_lens_expanded = metadata.get_seqlens_expanded()
-        ke = ks + seq_lens_expanded
+        ke = torch.cat(ke_list, dim=0)
         logits = deep_gemm.fp8_mqa_logits(
-            q_fp8,
+            q_fp8[:offset],
             kv_fp8,
-            weights,
+            weights[:offset],
             ks,
             ke,
+            clean_logits=False,
         )
+        token_nums, _, _ = q_fp8.shape
         assert logits.shape[0] == len(seq_lens_expanded)
-        topk_result = metadata.topk_transform(logits, self.index_topk)
+        raw_topk_result = metadata.topk_transform(logits, self.index_topk)
+        topk_result = torch.full(
+            (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
+        )
+        topk_result[:offset] = raw_topk_result
         return topk_result
 
     def forward_indexer(
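Two coupled changes in the hunk above: `ke` is now accumulated per request inside the loop, so each query row carries its own end offset into the packed key buffer, and the top-k result is padded back to the full `token_nums` rows with `-1`, presumably because `q_fp8` can carry CUDA-graph padding rows beyond the `offset` real tokens. A toy re-enactment with invented sizes:

```python
import torch

# Invented per-request query counts and expanded KV lengths.
extend_seq_lens_cpu = [2, 3]
seq_lens_expanded = torch.tensor([4, 5, 7, 8, 9], dtype=torch.int32)

ks_list, ke_list, offset = [], [], 0
for extend_seq_len in extend_seq_lens_cpu:
    # Rows of request i start scoring at that request's offset...
    ks = torch.full((extend_seq_len,), offset, dtype=torch.int32)
    # ...and each row's end grows with its own expanded KV length.
    ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
    ks_list.append(ks)
    ke_list.append(ke)
    offset += extend_seq_len

ks, ke = torch.cat(ks_list), torch.cat(ke_list)
print(ks.tolist(), ke.tolist())  # [0, 0, 2, 2, 2] [4, 5, 9, 10, 11]

# The -1 padding mirrors the tail handling above: rows past `offset` get
# sentinel top-k indices, read downstream as "no selected KV blocks".
token_nums, index_topk = 8, 4
raw_topk_result = torch.zeros((offset, index_topk), dtype=torch.int32)
topk_result = torch.full((token_nums, index_topk), -1, dtype=torch.int32)
topk_result[:offset] = raw_topk_result
```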
@@ -500,6 +508,8 @@ class Indexer(CustomOp):
         # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
         # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
         # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
         forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
             layer_id=layer_id,
             loc=forward_batch.out_cache_loc,
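The new contiguity guard is cheap insurance: `out_cache_loc` may arrive as a non-contiguous view (slicing during speculative decoding is one plausible source), and the cache-writing kernel presumably consumes raw pointers. A minimal illustration of the pattern:

```python
import torch

# Slicing with a step yields a non-contiguous view over the same storage.
loc = torch.arange(16, dtype=torch.int64)[::2]
assert not loc.is_contiguous()
if not loc.is_contiguous():
    loc = loc.contiguous()  # materializes a compact copy only when needed
assert loc.is_contiguous()
```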
@@ -521,7 +531,10 @@
             (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
         )
-        if forward_batch.forward_mode.is_decode_or_idle():
+        if (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             topk_result = self._get_topk_paged(
                 forward_batch, layer_id, q_fp8, weights, metadata
             )
@@ -48,6 +48,7 @@ class DraftBackendFactory:
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
+            "nsa": self._create_nsa_decode_backend,
         }
         return self._create_backend(
@@ -70,6 +71,7 @@ class DraftBackendFactory:
             "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+            "nsa": self._create_nsa_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
@@ -82,6 +84,20 @@
             "EAGLE is not supported in attention backend {backend_type}",
         )
 
+    def _create_nsa_decode_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import (
+            NativeSparseAttnMultiStepBackend,
+        )
+
+        return NativeSparseAttnMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_nsa_prefill_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+        return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
+
     def _create_flashinfer_decode_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
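For orientation: draft decoding gets the multi-step wrapper, while prefill uses the plain backend. A hypothetical sketch of the multi-step pattern (class and helper names are invented; the actual `NativeSparseAttnMultiStepBackend` lives in `sglang.srt.layers.attention.nsa_backend`):

```python
# Hypothetical sketch, not the real implementation: one attention backend
# per speculative draft step, so each step's metadata can be captured
# separately in the draft CUDA graph.
class ToyMultiStepBackend:
    def __init__(self, model_runner, topk, speculative_num_steps, make_backend):
        self.topk = topk
        # make_backend is an assumed factory for a single-step backend.
        self.attn_backends = [
            make_backend(model_runner) for _ in range(speculative_num_steps)
        ]

    def init_forward_metadata(self, forward_batch):
        # Each step's backend prepares its own forward metadata.
        for backend in self.attn_backends:
            backend.init_forward_metadata(forward_batch)
```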
@@ -81,6 +81,7 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -92,6 +93,7 @@
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
+            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
             self.out_cache_loc = torch.zeros(
                 (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
             )
@@ -165,6 +167,9 @@
         # Graph inputs
         req_pool_indices = self.req_pool_indices[:num_seqs]
         seq_lens = self.seq_lens[:num_seqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
+        extend_seq_lens = self.extend_seq_lens[:num_seqs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
         mrope_positions = self.mrope_positions[:, :num_tokens]
@@ -227,6 +232,9 @@
             input_ids=None,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
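The draft decode runner produces exactly one token per request per step, hence `torch.ones` for `extend_seq_lens`, with a host-side list twin for metadata read on CPU (the indexer above iterates `forward_batch.extend_seq_lens_cpu`). The buffer discipline that makes this CUDA-graph safe, in isolation:

```python
import torch

# Allocate once at max_bs, then hand out slices, so the storage the
# captured graph points at never moves between replays.
max_bs, seq_len_fill_value = 8, 1
extend_seq_lens = torch.ones((max_bs,), dtype=torch.int32)
extend_seq_lens_cpu = [seq_len_fill_value] * max_bs  # host-side twin

num_seqs = 3
view = extend_seq_lens[:num_seqs]
assert view.data_ptr() == extend_seq_lens.data_ptr()  # same storage
```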
@@ -78,6 +78,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -196,7 +197,9 @@
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         extend_seq_lens = self.extend_seq_lens[:bs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
@@ -254,6 +257,7 @@
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
@@ -271,6 +275,7 @@
             capture_hidden_mode=CaptureHiddenMode.LAST,
             attn_backend=self.eagle_worker.draft_extend_attn_backend,
             extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             padded_static_len=self.padded_static_len,
         )
@@ -373,6 +378,9 @@
         self.seq_lens_cpu.fill_(self.seq_len_fill_value)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
+        if forward_batch.extend_seq_lens_cpu is not None:
+            self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
+
         if bs != raw_bs:
             forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
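On replay, only the first `raw_bs` entries are refreshed from the incoming batch; the `is not None` guard covers batches that carry no extend lengths, and the padded tail keeps its fill value so the captured graph still sees valid lengths. In miniature:

```python
# Real batch values overwrite the head of the preallocated host list;
# the tail keeps its fill value for the graph's padding requests.
raw_bs, fill = 2, 1
extend_seq_lens_cpu = [fill] * 8
batch_extend_seq_lens_cpu = [5, 7]  # assumed real per-request values
extend_seq_lens_cpu[:raw_bs] = batch_extend_seq_lens_cpu
print(extend_seq_lens_cpu)  # [5, 7, 1, 1, 1, 1, 1, 1]
```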