Unverified commit efa47334 authored by Paiiii, committed by GitHub

[Spec Decoding] Support MTP for dsv3.2 (#11652)


Co-authored-by: Paiiiiiiiiiiiiii <zengpai@baidu.com>
parent d658f049
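The recurring pattern in this change: for target-verify (MTP) batches, each request's KV length is expanded into one KV length per draft token, and the page table is repeat-interleaved to match. A minimal standalone sketch with made-up lengths (not code from this commit):

```python
import torch

# Made-up example: 2 requests with 10 and 7 cached tokens, verifying
# 3 draft tokens per request (speculative_num_draft_tokens = 3).
speculative_num_draft_tokens = 3
seq_lens_cpu = torch.tensor([10, 7])

extend_seq_lens_cpu = [speculative_num_draft_tokens] * len(seq_lens_cpu)
seqlens_int32_cpu = [
    speculative_num_draft_tokens + kv_len for kv_len in seq_lens_cpu.tolist()
]

# Draft token i of a request sees a causally growing KV prefix, so its
# KV length runs from kv_len - qo_len + 1 up to kv_len.
seqlens_expanded = torch.cat(
    [
        torch.arange(kv_len - qo_len + 1, kv_len + 1, dtype=torch.int32)
        for qo_len, kv_len in zip(extend_seq_lens_cpu, seqlens_int32_cpu)
    ]
)
print(seqlens_expanded)  # tensor([11, 12, 13,  8,  9, 10], dtype=torch.int32)

# The per-request page table is repeated per draft token so every expanded
# row has its own page-table row (here a fake 2 x 4 table).
page_table = torch.arange(8, dtype=torch.int32).view(2, 4)
page_table = torch.repeat_interleave(
    page_table, repeats=speculative_num_draft_tokens, dim=0
)
print(page_table.shape)  # torch.Size([6, 4])
```

Each expanded entry is the causal KV budget of one draft token, which is why the new branches set nsa_extend_seq_lens_list to a list of 1s: every expanded row behaves like a one-token decode query.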
......@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0]
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
in [
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLMNextN",
]
and getattr(config, "index_topk", None) is not None
)
......
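The hunk above widens the NSA architecture check so the MTP draft head (DeepseekV3ForCausalLMNextN) also takes the NSA path. A tiny standalone sketch of the same check against a stand-in config; the function name and the index_topk value here are placeholders, not repository code:

```python
from types import SimpleNamespace


def is_deepseek_nsa_like(config) -> bool:
    # Mirrors the check above, using a stand-in config object.
    return (
        getattr(config, "architectures", None) is not None
        and config.architectures[0]
        in [
            "DeepseekV3ForCausalLM",
            "DeepseekV32ForCausalLM",
            "DeepseekV3ForCausalLMNextN",
        ]
        and getattr(config, "index_topk", None) is not None
    )


# index_topk is an arbitrary placeholder value for illustration.
draft_config = SimpleNamespace(
    architectures=["DeepseekV3ForCausalLMNextN"], index_topk=2048
)
print(is_deepseek_nsa_like(draft_config))  # True: the MTP draft head now uses NSA
```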
......@@ -266,7 +266,10 @@ class Indexer(CustomOp):
)
blocksize = page_size
seqlens_32 = metadata.get_seqlens_int32()
if forward_batch.forward_mode.is_target_verify():
seqlens_32 = metadata.get_seqlens_expanded()
else:
seqlens_32 = metadata.get_seqlens_int32()
# NOTE(dark): 132 is the SM count on H200/B200, not a magic number
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
seqlens_32, blocksize, self.sm_count
......@@ -317,8 +320,9 @@ class Indexer(CustomOp):
k_fp8_list = []
k_scale_list = []
ks_list = []
ke_list = []
offset = 0
seq_lens_expanded = metadata.get_seqlens_expanded()
block_tables = metadata.get_page_table_64()
assert (
......@@ -341,30 +345,34 @@ class Indexer(CustomOp):
)
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
k_fp8_list.append(k_fp8)
k_scale_list.append(k_scale)
ks_list.append(ks)
ke_list.append(ke)
offset += extend_seq_len
k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
kv_fp8 = (k_fp8, k_scale)
ks = torch.cat(ks_list, dim=0)
seq_lens_expanded = metadata.get_seqlens_expanded()
ke = ks + seq_lens_expanded
ke = torch.cat(ke_list, dim=0)
logits = deep_gemm.fp8_mqa_logits(
q_fp8,
q_fp8[:offset],
kv_fp8,
weights,
weights[:offset],
ks,
ke,
clean_logits=False,
)
token_nums, _, _ = q_fp8.shape
assert logits.shape[0] == len(seq_lens_expanded)
topk_result = metadata.topk_transform(logits, self.index_topk)
raw_topk_result = metadata.topk_transform(logits, self.index_topk)
topk_result = torch.full(
(token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
)
topk_result[:offset] = raw_topk_result
return topk_result
def forward_indexer(
......@@ -500,6 +508,8 @@ class Indexer(CustomOp):
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
# k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
# k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
if not forward_batch.out_cache_loc.is_contiguous():
forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
......@@ -521,7 +531,10 @@ class Indexer(CustomOp):
(x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
)
if forward_batch.forward_mode.is_decode_or_idle():
if (
forward_batch.forward_mode.is_decode_or_idle()
or forward_batch.forward_mode.is_target_verify()
):
topk_result = self._get_topk_paged(
forward_batch, layer_id, q_fp8, weights, metadata
)
......
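In the _get_topk_paged changes above, only the first offset rows of q_fp8 and weights are passed to fp8_mqa_logits, and the raw top-k result is then padded back to the full token count with -1 sentinels. A minimal sketch of that padding with hypothetical shapes:

```python
import torch

# Hypothetical sizes: 6 query tokens in the batch, index_topk = 4, but only
# the first `offset` rows were produced by the logits kernel.
token_nums, index_topk, offset = 6, 4, 5
raw_topk_result = torch.randint(0, 32, (offset, index_topk), dtype=torch.int32)

# Pad back to the full token count; -1 marks rows with no selected KV entry,
# mirroring topk_result[:offset] = raw_topk_result in the diff.
topk_result = torch.full((token_nums, index_topk), -1, dtype=torch.int32)
topk_result[:offset] = raw_topk_result
print(topk_result[offset:])  # the padded tail is all -1
```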
......@@ -29,6 +29,7 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput
_is_hip = is_hip()
if _is_hip:
......@@ -148,7 +149,14 @@ NSA_DECODE_IMPL: _NSA_IMPL_T
class NativeSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
speculative_step_id=0,
topk=0,
speculative_num_steps=0,
):
super().__init__()
self.forward_metadata: NSAMetadata
self.device = model_runner.device
......@@ -185,6 +193,14 @@ class NativeSparseAttnBackend(AttentionBackend):
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
# Speculative decoding
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
)
self.speculative_step_id = speculative_step_id
def get_device_int32_arange(self, l: int) -> torch.Tensor:
if l > len(self._arange_buf):
next_pow_of_2 = 1 << (l - 1).bit_length()
......@@ -208,13 +224,15 @@ class NativeSparseAttnBackend(AttentionBackend):
batch_size = forward_batch.batch_size
device = forward_batch.seq_lens.device
assert (
forward_batch.spec_info is None
), "Spec decoding is not supported for NSA backend now"
cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
if forward_batch.forward_mode.is_target_verify():
draft_token_num = self.speculative_num_draft_tokens
else:
draft_token_num = 0
cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
assert forward_batch.seq_lens_cpu is not None
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :max_seqlen_k
]
......@@ -224,6 +242,41 @@ class NativeSparseAttnBackend(AttentionBackend):
max_seqlen_q = 1
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
seqlens_expanded = cache_seqlens_int32
elif forward_batch.forward_mode.is_target_verify():
max_seqlen_q = self.speculative_num_draft_tokens
nsa_max_seqlen_q = self.speculative_num_draft_tokens
cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
1,
dtype=torch.int32,
device=device,
)
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
seqlens_int32_cpu = [
self.speculative_num_draft_tokens + kv_len
for kv_len in forward_batch.seq_lens_cpu.tolist()
]
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=device,
)
for qo_len, kv_len in zip(
extend_seq_lens_cpu,
seqlens_int32_cpu,
strict=True,
)
]
)
page_table = torch.repeat_interleave(
page_table, repeats=self.speculative_num_draft_tokens, dim=0
)
elif forward_batch.forward_mode.is_extend():
assert (
forward_batch.extend_seq_lens_cpu is not None
......@@ -232,7 +285,11 @@ class NativeSparseAttnBackend(AttentionBackend):
), "All of them must not be None"
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
assert forward_batch.extend_seq_lens is not None
if any(forward_batch.extend_prefix_lens_cpu):
if (
any(forward_batch.extend_prefix_lens_cpu)
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
max_seqlen_q = max(extend_seq_lens_cpu)
cu_seqlens_q = compute_cu_seqlens(
forward_batch.extend_seq_lens.to(torch.int32)
......@@ -277,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
flashmla_metadata=(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
seq_len_q=1,
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
......@@ -288,6 +345,7 @@ class NativeSparseAttnBackend(AttentionBackend):
nsa_seqlens_expanded=seqlens_expanded,
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
real_page_table=self._transform_table_1_to_real(page_table),
nsa_max_seqlen_q=1,
)
self.forward_metadata = metadata
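The branches above lean on two helpers, compute_cu_seqlens and compute_nsa_seqlens, whose definitions are outside this diff. The sketch below is only an assumed reading of their behavior (prefix-sum with a leading zero, and clamping KV lengths to the index top-k budget), not the repository implementation:

```python
import torch


def compute_cu_seqlens_sketch(seqlens: torch.Tensor) -> torch.Tensor:
    # Assumption: cumulative lengths with a leading zero, i.e. the usual
    # cu_seqlens layout expected by varlen attention kernels.
    return torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )


def compute_nsa_seqlens_sketch(seqlens: torch.Tensor, nsa_index_topk: int) -> torch.Tensor:
    # Assumption: sparse attention reads at most index_topk KV entries per
    # query token, so effective lengths are clamped to that budget.
    return torch.clamp(seqlens, max=nsa_index_topk).to(torch.int32)


seqlens = torch.tensor([11, 12, 13, 8, 9, 10], dtype=torch.int32)
print(compute_cu_seqlens_sketch(seqlens))       # tensor([ 0, 11, 23, 36, 44, 53, 63], ...)
print(compute_nsa_seqlens_sketch(seqlens, 10))  # tensor([10, 10, 10,  8,  9, 10], ...)
```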
......@@ -302,7 +360,9 @@ class NativeSparseAttnBackend(AttentionBackend):
to avoid memory allocations.
"""
self.decode_cuda_graph_metadata: Dict = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cache_seqlens": torch.ones(
max_num_tokens, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
),
......@@ -311,7 +371,7 @@ class NativeSparseAttnBackend(AttentionBackend):
),
# fake page_table for sparse_prefill
"page_table": torch.zeros(
max_bs,
max_num_tokens,
self.max_context_len,
dtype=torch.int32,
device=self.device,
......@@ -319,9 +379,9 @@ class NativeSparseAttnBackend(AttentionBackend):
"flashmla_metadata": (
self._compute_flashmla_metadata(
cache_seqlens=torch.ones(
max_bs, dtype=torch.int32, device=self.device
max_num_tokens, dtype=torch.int32, device=self.device
),
seq_len_q=1, # TODO handle MTP which is not 1
seq_len_q=1,
)
if NSA_DECODE_IMPL == "flashmla_decode"
else None
......@@ -339,50 +399,166 @@ class NativeSparseAttnBackend(AttentionBackend):
spec_info: Optional[SpecInput],
):
"""Initialize forward metadata for capturing CUDA graph."""
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
assert (
spec_info is None
), "Speculative decoding is not supported for NSA backend now"
if forward_mode.is_decode_or_idle():
# Normal Decode
# Get sequence information
cache_seqlens_int32 = seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
# Use max context length for seq_len_k
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
max_seqlen_q = 1
max_seqlen_k = page_table_1.shape[1]
# Normal Decode
# Get sequence information
cache_seqlens_int32 = seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
# Precompute page table
# Precompute cumulative sequence lengths
# NOTE(dark): this is always arange, since we are decoding
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
)
seqlens_expanded = cache_seqlens_int32
nsa_extend_seq_lens_list = [1] * num_tokens
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, num_tokens + 1))
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1,
)
)
else:
flashmla_metadata = None
elif forward_mode.is_target_verify():
cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
torch.int32
)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
max_seqlen_q = 1
page_table_1 = self.decode_cuda_graph_metadata["page_table"][
: bs * self.speculative_num_draft_tokens, :
]
max_seqlen_k = page_table_1.shape[1]
cu_seqlens_q = torch.arange(
0,
bs * self.speculative_num_draft_tokens + 1,
1,
dtype=torch.int32,
device=self.device,
)
# Use max context length for seq_len_k
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
max_seq_len_k = page_table_1.shape[1]
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
# Precompute page table
# Precompute cumulative sequence lengths
seqlens_int32_cpu = [
self.speculative_num_draft_tokens + kv_len
for kv_len in seq_lens.tolist()
]
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=self.device,
)
for qo_len, kv_len in zip(
extend_seq_lens_cpu,
seqlens_int32_cpu,
strict=True,
)
]
)
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
seqlens_expanded, nsa_index_topk=self.nsa_index_topk
)
nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
# NOTE(dark): this is always arange, since we are decoding
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
)
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
real_page_table = self._transform_table_1_to_real(page_table_1)
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs + 1))
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1, # TODO handle MTP which is not 1
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1,
)
)
else:
flashmla_metadata = None
elif forward_mode.is_draft_extend():
cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
torch.int32
)
else:
flashmla_metadata = None
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
max_seqlen_k = page_table_1.shape[1]
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
extend_seq_lens = torch.full(
(bs,),
self.speculative_num_draft_tokens,
device=self.device,
dtype=torch.int32,
)
max_seqlen_q = max(extend_seq_lens_cpu)
cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
seqlens_int32_cpu = [
self.speculative_num_draft_tokens + kv_len
for kv_len in seq_lens.tolist()
]
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=self.device,
)
for qo_len, kv_len in zip(
extend_seq_lens_cpu,
seqlens_int32_cpu,
strict=True,
)
]
)
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
seqlens_expanded, nsa_index_topk=self.nsa_index_topk
)
nsa_extend_seq_lens_list = [1] * bs
if NSA_DECODE_IMPL == "flashmla_decode":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
# As DeepGEMM does not support q_len = 3/4 in the Indexer and every token has independent topk_indices,
# we reshape Q to [bs * speculative_num_draft_tokens, 1, head_nums, dim].
# So seq_len_q is 1 for flashmla_metadata in target_verify and draft_extend modes.
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1,
)
)
else:
flashmla_metadata = None
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
real_page_table = self._transform_table_1_to_real(page_table_1)
metadata = NSAMetadata(
page_size=self.real_page_size,
cache_seqlens_int32=cache_seqlens_int32,
max_seq_len_q=1,
max_seq_len_k=max_seq_len_k,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
page_table_1=page_table_1,
......@@ -390,9 +566,9 @@ class NativeSparseAttnBackend(AttentionBackend):
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
nsa_seqlens_expanded=cache_seqlens_int32,
nsa_seqlens_expanded=seqlens_expanded,
real_page_table=real_page_table,
nsa_extend_seq_lens_list=[1] * bs,
nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
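The comment in the draft-extend branch above explains why seq_len_q stays 1: every draft token is treated as an independent length-1 query with its own top-k indices. A small shape sketch with hypothetical head and dim sizes:

```python
import torch

# Hypothetical sizes: 2 requests, 3 draft tokens each, 16 heads, head dim 64.
bs, num_draft, heads, dim = 2, 3, 16, 64
q = torch.randn(bs, num_draft, heads, dim)

# Flattening draft tokens into the batch dimension turns each draft token
# into a q_len = 1 query, which is why seq_len_q stays 1 when building
# the flashmla metadata for target-verify and draft-extend.
q_flat = q.view(bs * num_draft, 1, heads, dim)
print(q_flat.shape)  # torch.Size([6, 1, 16, 64])
```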
......@@ -411,33 +587,119 @@ class NativeSparseAttnBackend(AttentionBackend):
):
"""Initialize forward metadata for replaying CUDA graph."""
assert seq_lens_cpu is not None
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
assert (
spec_info is None
), "Speculative decoding is not supported for NSA backend now"
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
# Normal Decode
metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
max_len = int(seq_lens_cpu.max().item())
if forward_mode.is_decode_or_idle():
# Normal Decode
max_len = int(seq_lens_cpu.max().item())
cache_seqlens = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_len]
metadata.page_table_1[:, :max_len].copy_(page_indices)
nsa_cache_seqlens = compute_nsa_seqlens(
cache_seqlens, nsa_index_topk=self.nsa_index_topk
)
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
seqlens_expanded = cache_seqlens
elif forward_mode.is_target_verify():
max_seqlen_k = int(
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
)
cache_seqlens = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_len]
metadata.page_table_1[:, :max_len].copy_(page_indices)
cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
torch.int32
)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
page_indices = torch.repeat_interleave(
page_indices, repeats=self.speculative_num_draft_tokens, dim=0
)
metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
seqlens_int32_cpu = [
self.speculative_num_draft_tokens + kv_len
for kv_len in seq_lens_cpu.tolist()
]
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=self.device,
)
for qo_len, kv_len in zip(
extend_seq_lens_cpu,
seqlens_int32_cpu,
strict=True,
)
]
)
metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
nsa_cache_seqlens = compute_nsa_seqlens(
seqlens_expanded, self.nsa_index_topk
)
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
elif forward_mode.is_draft_extend():
max_seqlen_k = int(seq_lens_cpu.max().item())
cache_seqlens = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32.copy_(cache_seqlens)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)
page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
seqlens_int32_cpu = [
self.speculative_num_draft_tokens + kv_len
for kv_len in seq_lens_cpu.tolist()
]
seqlens_expanded = torch.cat(
[
torch.arange(
kv_len - qo_len + 1,
kv_len + 1,
dtype=torch.int32,
device=self.device,
)
for qo_len, kv_len in zip(
extend_seq_lens_cpu,
seqlens_int32_cpu,
strict=True,
)
]
)
metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
seqlens_expanded
)
nsa_cache_seqlens = compute_nsa_seqlens(
seqlens_expanded, self.nsa_index_topk
)
metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
nsa_cache_seqlens
)
seqlens_expanded_size = seqlens_expanded.size(0)
assert (
metadata.nsa_cache_seqlens_int32 is not None
and metadata.nsa_cu_seqlens_k is not None
and self.nsa_index_topk is not None
)
nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
metadata.nsa_cu_seqlens_k[1:].copy_(
metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
)
# NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
......@@ -451,10 +713,13 @@ class NativeSparseAttnBackend(AttentionBackend):
assert metadata.real_page_table is metadata.page_table_1
if NSA_DECODE_IMPL == "flashmla_decode":
metadata.flashmla_metadata.copy_(
flashmla_metadata = metadata.flashmla_metadata.slice(
slice(0, seqlens_expanded_size + 1)
)
flashmla_metadata.copy_(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens,
seq_len_q=1, # TODO handle MTP which is not 1
seq_len_q=1,
)
)
......@@ -473,10 +738,7 @@ class NativeSparseAttnBackend(AttentionBackend):
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert (
not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
), "NSA backend doesn't support speculative decoding"
if k is not None:
assert v is not None
if save_kv_cache:
......@@ -884,3 +1146,58 @@ class NativeSparseAttnBackend(AttentionBackend):
flashmla_metadata=flashmla_metadata,
num_splits=num_splits,
)
class NativeSparseAttnMultiStepBackend:
def __init__(
self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
):
self.model_runner = model_runner
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
NativeSparseAttnBackend(
model_runner,
speculative_step_id=i,
topk=self.topk,
speculative_num_steps=self.speculative_num_steps,
)
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata(forward_batch)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu,
)
......@@ -48,6 +48,7 @@ class DraftBackendFactory:
"flashmla": self._create_flashmla_decode_backend,
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
"nsa": self._create_nsa_decode_backend,
}
return self._create_backend(
......@@ -70,6 +71,7 @@ class DraftBackendFactory:
"flashmla": self._create_flashmla_prefill_backend,
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
"nsa": self._create_nsa_prefill_backend,
}
backend_name = (
"decode_attention_backend"
......@@ -82,6 +84,20 @@ class DraftBackendFactory:
"EAGLE is not supported in attention backend {backend_type}",
)
def _create_nsa_decode_backend(self):
from sglang.srt.layers.attention.nsa_backend import (
NativeSparseAttnMultiStepBackend,
)
return NativeSparseAttnMultiStepBackend(
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_nsa_prefill_backend(self):
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_flashinfer_decode_backend(self):
if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
......
......@@ -81,6 +81,7 @@ class EAGLEDraftCudaGraphRunner:
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
if self.enable_torch_compile:
set_torch_compile_config()
......@@ -92,6 +93,7 @@ class EAGLEDraftCudaGraphRunner:
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.zeros(
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
)
......@@ -165,6 +167,9 @@ class EAGLEDraftCudaGraphRunner:
# Graph inputs
req_pool_indices = self.req_pool_indices[:num_seqs]
seq_lens = self.seq_lens[:num_seqs]
seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
extend_seq_lens = self.extend_seq_lens[:num_seqs]
extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
positions = self.positions[:num_tokens]
mrope_positions = self.mrope_positions[:, :num_tokens]
......@@ -227,6 +232,9 @@ class EAGLEDraftCudaGraphRunner:
input_ids=None,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
extend_seq_lens=extend_seq_lens,
extend_seq_lens_cpu=extend_seq_lens_cpu,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
......
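The EAGLE draft CUDA-graph runner changes above follow the usual capture pattern: buffers are allocated once at max_bs, and per-capture inputs are slices of those buffers. A reduced standalone sketch with illustrative values:

```python
import torch

# Illustrative values: graphs captured at max_bs = 8, this capture uses 3 seqs.
max_bs, seq_len_fill_value, num_seqs = 8, 1, 3
seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
seq_lens_cpu = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
extend_seq_lens = torch.ones((max_bs,), dtype=torch.int32)
extend_seq_lens_cpu = [seq_len_fill_value] * max_bs

# Capture-time graph inputs are views/slices of the preallocated buffers,
# so replay only needs to copy fresh values into the leading entries.
capture_seq_lens = seq_lens[:num_seqs]
capture_extend_seq_lens_cpu = extend_seq_lens_cpu[:num_seqs]
print(capture_seq_lens.shape, len(capture_extend_seq_lens_cpu))  # torch.Size([3]) 3
```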
......@@ -78,6 +78,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
if self.enable_torch_compile:
set_torch_compile_config()
......@@ -196,7 +197,9 @@ class EAGLEDraftExtendCudaGraphRunner:
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
seq_lens_cpu = self.seq_lens_cpu[:bs]
extend_seq_lens = self.extend_seq_lens[:bs]
extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
accept_length = self.accept_length[:bs]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
......@@ -254,6 +257,7 @@ class EAGLEDraftExtendCudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
next_token_logits_buffer=next_token_logits_buffer,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
......@@ -271,6 +275,7 @@ class EAGLEDraftExtendCudaGraphRunner:
capture_hidden_mode=CaptureHiddenMode.LAST,
attn_backend=self.eagle_worker.draft_extend_attn_backend,
extend_seq_lens=extend_seq_lens,
extend_seq_lens_cpu=extend_seq_lens_cpu,
padded_static_len=self.padded_static_len,
)
......@@ -373,6 +378,9 @@ class EAGLEDraftExtendCudaGraphRunner:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
if forward_batch.extend_seq_lens_cpu is not None:
self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
if bs != raw_bs:
forward_batch.spec_info.positions = self.positions[:num_tokens]
forward_batch.spec_info.accept_length = self.accept_length[:bs]
......