Unverified commit efa47334, authored by Paiiii, committed by GitHub

[Spec Decoding] Support MTP for dsv3.2 (#11652)


Co-authored-by: Paiiiiiiiiiiiiii <zengpai@baidu.com>
parent d658f049
@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
     return (
         config.architectures is not None
         and config.architectures[0]
-        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+        in [
+            "DeepseekV3ForCausalLM",
+            "DeepseekV32ForCausalLM",
+            "DeepseekV3ForCausalLMNextN",
+        ]
         and getattr(config, "index_topk", None) is not None
     )
...
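This first hunk widens the NSA detection gate so the MTP/NextN draft architecture is recognized too; without it, the draft model would not route through the DeepSeek sparse-attention (NSA) code paths. A minimal sketch of the predicate's behavior, using `types.SimpleNamespace` as a stand-in for `transformers.PretrainedConfig` and an illustrative `index_topk` value:

```python
from types import SimpleNamespace

_NSA_ARCHS = [
    "DeepseekV3ForCausalLM",
    "DeepseekV32ForCausalLM",
    "DeepseekV3ForCausalLMNextN",  # MTP/NextN draft model, newly recognized
]

def is_deepseek_nsa_sketch(config) -> bool:
    # Same three-part check as the real predicate above.
    return (
        config.architectures is not None
        and config.architectures[0] in _NSA_ARCHS
        and getattr(config, "index_topk", None) is not None
    )

# Illustrative draft-model config; index_topk=2048 is an assumed value.
draft_cfg = SimpleNamespace(
    architectures=["DeepseekV3ForCausalLMNextN"], index_topk=2048
)
assert is_deepseek_nsa_sketch(draft_cfg)
```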
@@ -266,6 +266,9 @@ class Indexer(CustomOp):
         )
         blocksize = page_size
-        seqlens_32 = metadata.get_seqlens_int32()
+        if forward_batch.forward_mode.is_target_verify():
+            seqlens_32 = metadata.get_seqlens_expanded()
+        else:
+            seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
         schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
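Under target verify, each request carries several draft tokens, so the paged MQA-logits kernel needs one sequence length per query token rather than one per request; that is what `metadata.get_seqlens_expanded()` supplies in place of `get_seqlens_int32()`. A hedged sketch of such an expansion (the helper name `expand_seqlens` and the exact length convention are assumptions for illustration, not the real metadata API):

```python
import torch

# Hypothetical expansion: draft position j of a request with cached length L
# attends to L + j + 1 tokens, giving one entry per draft token.
def expand_seqlens(seq_lens: torch.Tensor, num_draft_tokens: int) -> torch.Tensor:
    steps = torch.arange(1, num_draft_tokens + 1, dtype=torch.int32)
    return (seq_lens.to(torch.int32)[:, None] + steps[None, :]).flatten()

print(expand_seqlens(torch.tensor([7, 12]), 3))
# tensor([ 8,  9, 10, 13, 14, 15], dtype=torch.int32)
```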
@@ -317,8 +320,9 @@ class Indexer(CustomOp):
         k_fp8_list = []
         k_scale_list = []
         ks_list = []
+        ke_list = []
         offset = 0
+        seq_lens_expanded = metadata.get_seqlens_expanded()
         block_tables = metadata.get_page_table_64()
         assert (
@@ -341,30 +345,34 @@ class Indexer(CustomOp):
             )
             extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
             ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
             k_fp8_list.append(k_fp8)
             k_scale_list.append(k_scale)
             ks_list.append(ks)
+            ke_list.append(ke)
             offset += extend_seq_len
         k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
         k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
         kv_fp8 = (k_fp8, k_scale)
         ks = torch.cat(ks_list, dim=0)
-        seq_lens_expanded = metadata.get_seqlens_expanded()
-        ke = ks + seq_lens_expanded
+        ke = torch.cat(ke_list, dim=0)
         logits = deep_gemm.fp8_mqa_logits(
-            q_fp8,
+            q_fp8[:offset],
             kv_fp8,
-            weights,
+            weights[:offset],
             ks,
             ke,
             clean_logits=False,
         )
+        token_nums, _, _ = q_fp8.shape
         assert logits.shape[0] == len(seq_lens_expanded)
-        topk_result = metadata.topk_transform(logits, self.index_topk)
+        raw_topk_result = metadata.topk_transform(logits, self.index_topk)
+        topk_result = torch.full(
+            (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
+        )
+        topk_result[:offset] = raw_topk_result
         return topk_result

     def forward_indexer(
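This hunk (apparently the ragged, per-request top-k path) now builds the key windows inside the loop: `ks` marks where request i's keys start in the packed buffer, and `ke = ks + seq_lens_expanded[...]` gives each query token its own end, so target-verify batches with several draft tokens per request get a correct half-open window [ks[t], ke[t]) per token, which is our reading of the `fp8_mqa_logits` contract. A toy sketch of that arithmetic with stand-in values (CPU tensors, illustrative lengths):

```python
import torch

extend_seq_lens = [2, 3]  # draft tokens contributed by each of two requests
seq_lens_expanded = torch.tensor([5, 6, 9, 10, 11], dtype=torch.int32)

ks_list, ke_list, offset = [], [], 0
for extend_seq_len in extend_seq_lens:
    # All tokens of a request share the same key-start offset...
    ks = torch.full((extend_seq_len,), offset, dtype=torch.int32)
    # ...but each token gets its own end, from the expanded lengths.
    ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
    ks_list.append(ks)
    ke_list.append(ke)
    offset += extend_seq_len

ks, ke = torch.cat(ks_list), torch.cat(ke_list)
print(ks.tolist())  # [0, 0, 2, 2, 2]
print(ke.tolist())  # [5, 6, 11, 12, 13]
```

Since `q_fp8` and `weights` can be padded past the `offset` real tokens (for example under CUDA-graph capture), only the first `offset` rows are scored, and the result is re-padded to `token_nums` rows with -1, the same fill value this file already uses for invalid top-k indices.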
@@ -500,6 +508,8 @@ class Indexer(CustomOp):
         # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
         # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
         # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
         forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
             layer_id=layer_id,
             loc=forward_batch.out_cache_loc,
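The contiguity guard protects `set_index_k_and_scale_buffer` against `out_cache_loc` arriving as a strided view, which can happen once speculative paths start slicing preallocated buffers; kernels that consume raw pointers assume a dense layout. A minimal illustration of the failure mode being guarded:

```python
import torch

buf = torch.arange(12, dtype=torch.int64)
view = buf[::2]                  # stride-2 view over the same storage
assert not view.is_contiguous()  # a raw-pointer kernel would misread this
dense = view.contiguous()        # materializes a dense copy
assert dense.is_contiguous() and torch.equal(dense, view)
```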
@@ -521,7 +531,10 @@ class Indexer(CustomOp):
             (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
         )
-        if forward_batch.forward_mode.is_decode_or_idle():
+        if (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             topk_result = self._get_topk_paged(
                 forward_batch, layer_id, q_fp8, weights, metadata
             )
...
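Routing: target-verify batches now take the paged top-k path alongside decode/idle instead of the extend path. A toy sketch of the widened predicate (the `ForwardMode` stand-in below only includes the relevant members):

```python
from enum import Enum, auto

# Toy stand-in for sglang's ForwardMode enum.
class ForwardMode(Enum):
    DECODE = auto()
    IDLE = auto()
    TARGET_VERIFY = auto()
    EXTEND = auto()

def use_paged_topk(mode: ForwardMode) -> bool:
    # Mirrors the widened condition in the hunk above.
    return mode in (ForwardMode.DECODE, ForwardMode.IDLE, ForwardMode.TARGET_VERIFY)

assert use_paged_topk(ForwardMode.TARGET_VERIFY)
assert not use_paged_topk(ForwardMode.EXTEND)
```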
@@ -48,6 +48,7 @@ class DraftBackendFactory:
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
+            "nsa": self._create_nsa_decode_backend,
         }
         return self._create_backend(
@@ -70,6 +71,7 @@ class DraftBackendFactory:
             "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+            "nsa": self._create_nsa_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
@@ -82,6 +84,20 @@ class DraftBackendFactory:
             "EAGLE is not supported in attention backend {backend_type}",
         )

+    def _create_nsa_decode_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import (
+            NativeSparseAttnMultiStepBackend,
+        )
+
+        return NativeSparseAttnMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_nsa_prefill_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+        return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
+
     def _create_flashinfer_decode_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
...
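The factory registers NSA for both phases: multi-step decode drafting gets `NativeSparseAttnMultiStepBackend`, while draft-extend/prefill gets the plain `NativeSparseAttnBackend` with `skip_prefill=False`. A hedged sketch of the multi-step wrapper pattern (class internals are assumptions; only the shape of the pattern is shown, one per-step backend per speculative step so each step's metadata can live in the same captured CUDA graph):

```python
class PerStepBackendSketch:
    """Hypothetical stand-in for a single-step NSA attention backend."""

    def __init__(self, runner, topk: int):
        self.runner, self.topk = runner, topk

    def init_forward_metadata(self, forward_batch):
        pass  # would build this step's NSA indices / seqlens here

class MultiStepBackendSketch:
    """Sketch of the NativeSparseAttnMultiStepBackend idea (assumed internals)."""

    def __init__(self, runner, topk: int, speculative_num_steps: int):
        # One backend per draft step, matching the constructor arguments
        # passed in _create_nsa_decode_backend above.
        self.attn_backends = [
            PerStepBackendSketch(runner, topk) for _ in range(speculative_num_steps)
        ]

    def init_forward_metadata(self, forward_batch):
        for backend in self.attn_backends:
            backend.init_forward_metadata(forward_batch)
```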
@@ -81,6 +81,7 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -92,6 +93,7 @@ class EAGLEDraftCudaGraphRunner:
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
+            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
             self.out_cache_loc = torch.zeros(
                 (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
             )
@@ -165,6 +167,9 @@ class EAGLEDraftCudaGraphRunner:
         # Graph inputs
         req_pool_indices = self.req_pool_indices[:num_seqs]
         seq_lens = self.seq_lens[:num_seqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
+        extend_seq_lens = self.extend_seq_lens[:num_seqs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
         mrope_positions = self.mrope_positions[:, :num_tokens]
@@ -227,6 +232,9 @@ class EAGLEDraftCudaGraphRunner:
             input_ids=None,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
...
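Both CUDA-graph runners follow the same fixed-buffer discipline: allocate every input at `max_bs` once, capture graphs against prefix views, and refresh the views in place at replay. The new `extend_seq_lens` / `extend_seq_lens_cpu` fields (a tensor of ones for decode drafting, since each step extends every sequence by one token) join that scheme so the NSA backend can see per-request extend lengths. A small demonstration of the prefix-view pattern (illustrative sizes):

```python
import torch

max_bs, seq_len_fill_value = 8, 1
seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
extend_seq_lens = torch.ones((max_bs,), dtype=torch.int32)  # 1 token per draft step

num_seqs = 3
seq_lens_view = seq_lens[:num_seqs]  # prefix view, shares storage with the buffer
seq_lens_view.fill_(42)              # replay-time writes land in the big buffer
assert seq_lens[0].item() == 42      # the captured graph sees the update
```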
@@ -78,6 +78,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -196,7 +197,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         extend_seq_lens = self.extend_seq_lens[:bs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
@@ -254,6 +257,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
@@ -271,6 +275,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             capture_hidden_mode=CaptureHiddenMode.LAST,
             attn_backend=self.eagle_worker.draft_extend_attn_backend,
             extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             padded_static_len=self.padded_static_len,
         )
@@ -373,6 +378,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu.fill_(self.seq_len_fill_value)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+        if forward_batch.extend_seq_lens_cpu is not None:
+            self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
         if bs != raw_bs:
             forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
...
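At replay, the extend-runner copies the real batch's per-request extend lengths into the first `raw_bs` slots of its CPU-side list, leaving the padded tail at the capture-time fill value `num_tokens_per_bs`. A toy illustration (values assumed):

```python
max_bs, num_tokens_per_bs = 4, 2
extend_seq_lens_cpu = [num_tokens_per_bs] * max_bs  # capture-time fill

raw_bs = 2
batch_extend_lens = [3, 1]  # stand-in for forward_batch.extend_seq_lens_cpu
extend_seq_lens_cpu[:raw_bs] = batch_extend_lens

print(extend_seq_lens_cpu)  # [3, 1, 2, 2] -- padded tail keeps the fill value
```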