sglang / Commits / efa47334

Commit efa47334 (unverified), authored Oct 19, 2025 by Paiiii, committed via GitHub on Oct 19, 2025

[Spec Decoding] Support MTP for dsv3.2 (#11652)

Co-authored-by: Paiiiiiiiiiiiiii <zengpai@baidu.com>

Parent: d658f049
Showing 6 changed files with 445 additions and 79 deletions (+445 / -79)
  python/sglang/srt/configs/model_config.py                               +5    -1
  python/sglang/srt/layers/attention/nsa/nsa_indexer.py                   +23   -10
  python/sglang/srt/layers/attention/nsa_backend.py                       +385  -68
  python/sglang/srt/speculative/draft_utils.py                            +16   -0
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py          +8    -0
  python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py   +8    -0
python/sglang/srt/configs/model_config.py
...
...
@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
     return (
         config.architectures is not None
         and config.architectures[0]
-        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+        in [
+            "DeepseekV3ForCausalLM",
+            "DeepseekV32ForCausalLM",
+            "DeepseekV3ForCausalLMNextN",
+        ]
         and getattr(config, "index_topk", None) is not None
     )
...
...
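With "DeepseekV3ForCausalLMNextN" added to the list, the MTP (NextN) draft model is detected as an NSA model just like the main DeepSeek V3.2 checkpoint. A minimal sketch of the effect, using a hypothetical stand-in config object (the real function receives a transformers PretrainedConfig):

from types import SimpleNamespace

# Hypothetical draft-model config: only the two attributes the check reads.
draft_cfg = SimpleNamespace(
    architectures=["DeepseekV3ForCausalLMNextN"],
    index_topk=2048,  # assumed value; any non-None index_topk enables NSA
)
# Before this commit the NextN architecture fell through and the check returned
# False; with the widened list, is_deepseek_nsa(draft_cfg) returns True, so the
# MTP draft model also takes the sparse-attention (NSA) code path.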
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
...
...
@@ -266,7 +266,10 @@ class Indexer(CustomOp):
         )
         blocksize = page_size
-        seqlens_32 = metadata.get_seqlens_int32()
+        if forward_batch.forward_mode.is_target_verify():
+            seqlens_32 = metadata.get_seqlens_expanded()
+        else:
+            seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
         schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
             seqlens_32, blocksize, self.sm_count
...
...
@@ -317,8 +320,9 @@ class Indexer(CustomOp):
         )
         k_fp8_list = []
         k_scale_list = []
+        ks_list = []
+        ke_list = []
         offset = 0
         seq_lens_expanded = metadata.get_seqlens_expanded()
         block_tables = metadata.get_page_table_64()
         assert (
...
...
@@ -341,30 +345,34 @@ class Indexer(CustomOp):
             )
             extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
+            ks = torch.full(
+                (extend_seq_len,), offset, dtype=torch.int32, device="cuda"
+            )
+            ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
             k_fp8_list.append(k_fp8)
             k_scale_list.append(k_scale)
+            ks_list.append(ks)
+            ke_list.append(ke)
             offset += extend_seq_len

         k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
         k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
         kv_fp8 = (k_fp8, k_scale)
         ks = torch.cat(ks_list, dim=0)
-        seq_lens_expanded = metadata.get_seqlens_expanded()
-        ke = ks + seq_lens_expanded
+        ke = torch.cat(ke_list, dim=0)
         logits = deep_gemm.fp8_mqa_logits(
-            q_fp8,
+            q_fp8[:offset],
             kv_fp8,
-            weights,
+            weights[:offset],
             ks,
             ke,
             clean_logits=False,
         )
+        token_nums, _, _ = q_fp8.shape
         assert logits.shape[0] == len(seq_lens_expanded)
-        topk_result = metadata.topk_transform(logits, self.index_topk)
+        raw_topk_result = metadata.topk_transform(logits, self.index_topk)
+        topk_result = torch.full(
+            (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
+        )
+        topk_result[:offset] = raw_topk_result
         return topk_result

     def forward_indexer(
...
...
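The per-request ks/ke ranges and the -1 padding of topk_result can be traced with made-up numbers. The following is an illustrative sketch, not code from this commit:

import torch

# Two requests with extend lengths [2, 3]; seq_lens_expanded holds one KV length
# per new (draft/verify) token.
seq_lens_expanded = torch.tensor([5, 6, 7, 8, 9], dtype=torch.int32)
offset = 0
ks_list, ke_list = [], []
for extend_seq_len in (2, 3):
    ks = torch.full((extend_seq_len,), offset, dtype=torch.int32)
    ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
    ks_list.append(ks)
    ke_list.append(ke)
    offset += extend_seq_len
ks, ke = torch.cat(ks_list), torch.cat(ke_list)
# ks = [0, 0, 2, 2, 2] and ke = [5, 6, 9, 10, 11]: each query row carries a
# start/end pair that bounds which K rows fp8_mqa_logits may score for it.
# If q_fp8 is padded to token_nums rows (e.g. for CUDA graphs), only the first
# `offset` rows receive real top-k indices; the rest keep the -1 fill value.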
@@ -500,6 +508,8 @@ class Indexer(CustomOp):
         # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
         # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
         # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
         forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
             layer_id=layer_id,
             loc=forward_batch.out_cache_loc,
...
...
@@ -521,7 +531,10 @@ class Indexer(CustomOp):
             (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
         )
-        if forward_batch.forward_mode.is_decode_or_idle():
+        if (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             topk_result = self._get_topk_paged(
                 forward_batch, layer_id, q_fp8, weights, metadata
             )
...
...
python/sglang/srt/layers/attention/nsa_backend.py
...
...
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInput

 _is_hip = is_hip()
 if _is_hip:
...
...
@@ -148,7 +149,14 @@ NSA_DECODE_IMPL: _NSA_IMPL_T
 class NativeSparseAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        speculative_step_id=0,
+        topk=0,
+        speculative_num_steps=0,
+    ):
         super().__init__()
         self.forward_metadata: NSAMetadata
         self.device = model_runner.device
...
...
@@ -185,6 +193,14 @@ class NativeSparseAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
+        # Speculative decoding
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id

     def get_device_int32_arange(self, l: int) -> torch.Tensor:
         if l > len(self._arange_buf):
             next_pow_of_2 = 1 << (l - 1).bit_length()
...
...
@@ -208,13 +224,15 @@ class NativeSparseAttnBackend(AttentionBackend):
         batch_size = forward_batch.batch_size
         device = forward_batch.seq_lens.device
-        assert (
-            forward_batch.spec_info is None
-        ), "Spec decoding is not supported for NSA backend now"
-        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+        if forward_batch.forward_mode.is_target_verify():
+            draft_token_num = self.speculative_num_draft_tokens
+        else:
+            draft_token_num = 0
+        cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
         assert forward_batch.seq_lens_cpu is not None
-        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
         page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, :max_seqlen_k
         ]
...
...
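compute_cu_seqlens is used throughout this backend; the copies into cu_seqlens_k[1:] later in this diff suggest it is a zero-prefixed cumulative sum. A sketch under that assumption (not the backend's actual implementation):

import torch

def compute_cu_seqlens_sketch(seqlens: torch.Tensor) -> torch.Tensor:
    # Assumed equivalent of compute_cu_seqlens: prepend 0, then cumulative sum.
    cu = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device=seqlens.device)
    cu[1:] = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    return cu

# Example: with draft_token_num = 4 and seq_lens = [10, 7],
# cache_seqlens_int32 = [14, 11] and cu_seqlens_k = [0, 14, 25].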
@@ -224,6 +242,41 @@ class NativeSparseAttnBackend(AttentionBackend):
             max_seqlen_q = 1
             cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
+            seqlens_expanded = cache_seqlens_int32
+        elif forward_batch.forward_mode.is_target_verify():
+            max_seqlen_q = self.speculative_num_draft_tokens
+            nsa_max_seqlen_q = self.speculative_num_draft_tokens
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=device,
+            )
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
+            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in forward_batch.seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            page_table = torch.repeat_interleave(
+                page_table, repeats=self.speculative_num_draft_tokens, dim=0
+            )
         elif forward_batch.forward_mode.is_extend():
             assert (
                 forward_batch.extend_seq_lens_cpu is not None
...
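For target-verify, every draft token gets its own effective KV length, which is what seqlens_expanded encodes. A sketch with made-up numbers (not part of the commit):

import torch

# speculative_num_draft_tokens = 4; two requests with cached lengths 10 and 7.
draft = 4
seq_lens_cpu = [10, 7]
seqlens_int32_cpu = [draft + s for s in seq_lens_cpu]  # [14, 11]
seqlens_expanded = torch.cat(
    [
        torch.arange(kv_len - draft + 1, kv_len + 1, dtype=torch.int32)
        for kv_len in seqlens_int32_cpu
    ]
)
# tensor([11, 12, 13, 14,  8,  9, 10, 11]): the i-th draft token of a request may
# attend to its prefix plus the first i draft tokens, and this per-row length is
# what the indexer and the attention kernels consume for the expanded batch.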
@@ -232,7 +285,11 @@ class NativeSparseAttnBackend(AttentionBackend):
             ), "All of them must not be None"
             extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
             assert forward_batch.extend_seq_lens is not None
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 max_seqlen_q = max(extend_seq_lens_cpu)
                 cu_seqlens_q = compute_cu_seqlens(
                     forward_batch.extend_seq_lens.to(torch.int32)
                 )
...
...
@@ -277,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         flashmla_metadata = (
             self._compute_flashmla_metadata(
                 cache_seqlens=nsa_cache_seqlens_int32,
-                seq_len_q=1,  # TODO handle MTP which is not 1
+                seq_len_q=1,
             )
             if NSA_DECODE_IMPL == "flashmla_decode"
             else None
...
...
@@ -288,6 +345,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_seqlens_expanded=seqlens_expanded,
             nsa_extend_seq_lens_list=extend_seq_lens_cpu,
             real_page_table=self._transform_table_1_to_real(page_table),
+            nsa_max_seqlen_q=1,
         )

         self.forward_metadata = metadata
...
...
@@ -302,7 +360,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata: Dict = {
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cache_seqlens": torch.ones(
+                max_num_tokens, dtype=torch.int32, device=self.device
+            ),
             "cu_seqlens_q": torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=self.device
             ),
...
...
@@ -311,7 +371,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             ),
             # fake page_table for sparse_prefill
             "page_table": torch.zeros(
-                max_bs,
+                max_num_tokens,
                 self.max_context_len,
                 dtype=torch.int32,
                 device=self.device,
...
...
@@ -319,9 +379,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             "flashmla_metadata": (
                 self._compute_flashmla_metadata(
                     cache_seqlens=torch.ones(
-                        max_bs, dtype=torch.int32, device=self.device
+                        max_num_tokens, dtype=torch.int32, device=self.device
                     ),
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
                 if NSA_DECODE_IMPL == "flashmla_decode"
                 else None
...
...
@@ -339,50 +399,166 @@ class NativeSparseAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
-        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
         if forward_mode.is_decode_or_idle():
             # Normal Decode
             # Get sequence information
             cache_seqlens_int32 = seq_lens.to(torch.int32)
             cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
             # Use max context length for seq_len_k
             page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
             max_seqlen_q = 1
             max_seqlen_k = page_table_1.shape[1]
             # Precompute page table
             # Precompute cumulative sequence lengths
             # NOTE(dark): this is always arange, since we are decoding
             cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
             nsa_cache_seqlens_int32 = compute_nsa_seqlens(
                 cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
             )
             seqlens_expanded = cache_seqlens_int32
             nsa_extend_seq_lens_list = [1] * num_tokens
             if NSA_DECODE_IMPL == "flashmla_decode":
                 flashmla_metadata = self.decode_cuda_graph_metadata[
                     "flashmla_metadata"
                 ].slice(slice(0, num_tokens + 1))
                 flashmla_metadata.copy_(
                     self._compute_flashmla_metadata(
                         cache_seqlens=nsa_cache_seqlens_int32,
                         seq_len_q=1,
                     )
                 )
             else:
                 flashmla_metadata = None
         elif forward_mode.is_target_verify():
             cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
                 torch.int32
             )
             cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
             max_seqlen_q = 1
             page_table_1 = self.decode_cuda_graph_metadata["page_table"][
                 : bs * self.speculative_num_draft_tokens, :
             ]
             max_seqlen_k = page_table_1.shape[1]
             cu_seqlens_q = torch.arange(
                 0,
                 bs * self.speculative_num_draft_tokens + 1,
                 1,
                 dtype=torch.int32,
                 device=self.device,
             )
             extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
             # Precompute page table
             # Precompute cumulative sequence lengths
             seqlens_int32_cpu = [
                 self.speculative_num_draft_tokens + kv_len
                 for kv_len in seq_lens.tolist()
             ]
             seqlens_expanded = torch.cat(
                 [
                     torch.arange(
                         kv_len - qo_len + 1,
                         kv_len + 1,
                         dtype=torch.int32,
                         device=self.device,
                     )
                     for qo_len, kv_len in zip(
                         extend_seq_lens_cpu,
                         seqlens_int32_cpu,
                         strict=True,
                     )
                 ]
             )
             nsa_cache_seqlens_int32 = compute_nsa_seqlens(
                 seqlens_expanded, nsa_index_topk=self.nsa_index_topk
             )
             nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
             if NSA_DECODE_IMPL == "flashmla_decode":
                 flashmla_metadata = self.decode_cuda_graph_metadata[
                     "flashmla_metadata"
                 ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
                 flashmla_metadata.copy_(
                     self._compute_flashmla_metadata(
                         cache_seqlens=nsa_cache_seqlens_int32,
                         seq_len_q=1,
                     )
                 )
             else:
                 flashmla_metadata = None
         elif forward_mode.is_draft_extend():
             cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
                 torch.int32
             )
             cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
             page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
             max_seqlen_k = page_table_1.shape[1]
             extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
             extend_seq_lens = torch.full(
                 (bs,),
                 self.speculative_num_draft_tokens,
                 device=self.device,
                 dtype=torch.int32,
             )
             max_seqlen_q = max(extend_seq_lens_cpu)
             cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
             seqlens_int32_cpu = [
                 self.speculative_num_draft_tokens + kv_len
                 for kv_len in seq_lens.tolist()
             ]
             seqlens_expanded = torch.cat(
                 [
                     torch.arange(
                         kv_len - qo_len + 1,
                         kv_len + 1,
                         dtype=torch.int32,
                         device=self.device,
                     )
                     for qo_len, kv_len in zip(
                         extend_seq_lens_cpu,
                         seqlens_int32_cpu,
                         strict=True,
                     )
                 ]
             )
             nsa_cache_seqlens_int32 = compute_nsa_seqlens(
                 seqlens_expanded, nsa_index_topk=self.nsa_index_topk
             )
             nsa_extend_seq_lens_list = [1] * bs
             if NSA_DECODE_IMPL == "flashmla_decode":
                 flashmla_metadata = self.decode_cuda_graph_metadata[
                     "flashmla_metadata"
                 ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
                 # As the DeepGemm is not support for q_len = 3/4 in Indexer and every token has independent topk_indices,
                 # we made the Q shape [bs * speculative_num_draft_tokens, 1, head_nums, dim].
                 # So seq_len_q is 1 for flashmla_metadata in target_verify and draft_extend mode.
                 flashmla_metadata.copy_(
                     self._compute_flashmla_metadata(
                         cache_seqlens=nsa_cache_seqlens_int32,
                         seq_len_q=1,
                     )
                 )
             else:
                 flashmla_metadata = None

         nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
         nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
         real_page_table = self._transform_table_1_to_real(page_table_1)
         metadata = NSAMetadata(
             page_size=self.real_page_size,
             cache_seqlens_int32=cache_seqlens_int32,
             max_seq_len_q=max_seqlen_q,
             max_seq_len_k=max_seqlen_k,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
             page_table_1=page_table_1,
...
...
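The comment above about the Q shape can be illustrated with a small sketch (made-up sizes, not code from the commit): the verify-phase queries are flattened so that each row is a length-1 query with its own expanded KV length and its own top-k indices, which is why seq_len_q stays 1 for the FlashMLA metadata.

import torch

bs, draft, head_nums, dim = 2, 4, 16, 128  # illustrative sizes only
q = torch.randn(bs, draft, head_nums, dim)
q_flat = q.view(bs * draft, 1, head_nums, dim)
# Each of the bs * draft rows is treated as an independent single-token query,
# paired with one entry of seqlens_expanded and one row of top-k indices.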
@@ -390,9 +566,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
             nsa_cu_seqlens_q=nsa_cu_seqlens_q,
             nsa_cu_seqlens_k=nsa_cu_seqlens_k,
-            nsa_seqlens_expanded=cache_seqlens_int32,
+            nsa_seqlens_expanded=seqlens_expanded,
             real_page_table=real_page_table,
-            nsa_extend_seq_lens_list=[1] * bs,
+            nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
...
...
@@ -411,33 +587,119 @@ class NativeSparseAttnBackend(AttentionBackend):
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         assert seq_lens_cpu is not None
-        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
         metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
         if forward_mode.is_decode_or_idle():
             # Normal Decode
             max_len = int(seq_lens_cpu.max().item())
             cache_seqlens = seq_lens.to(torch.int32)
             metadata.cache_seqlens_int32.copy_(cache_seqlens)
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
             )
             page_indices = self.req_to_token[req_pool_indices, :max_len]
             metadata.page_table_1[:, :max_len].copy_(page_indices)
             nsa_cache_seqlens = compute_nsa_seqlens(
                 cache_seqlens, nsa_index_topk=self.nsa_index_topk
             )
             metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
             seqlens_expanded = cache_seqlens
         elif forward_mode.is_target_verify():
             max_seqlen_k = int(
                 seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
             )
             cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
                 torch.int32
             )
             metadata.cache_seqlens_int32.copy_(cache_seqlens)
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
             )
             page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
             page_indices = torch.repeat_interleave(
                 page_indices, repeats=self.speculative_num_draft_tokens, dim=0
             )
             metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
             extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
             seqlens_int32_cpu = [
                 self.speculative_num_draft_tokens + kv_len
                 for kv_len in seq_lens_cpu.tolist()
             ]
             seqlens_expanded = torch.cat(
                 [
                     torch.arange(
                         kv_len - qo_len + 1,
                         kv_len + 1,
                         dtype=torch.int32,
                         device=self.device,
                     )
                     for qo_len, kv_len in zip(
                         extend_seq_lens_cpu,
                         seqlens_int32_cpu,
                         strict=True,
                     )
                 ]
             )
             metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
             nsa_cache_seqlens = compute_nsa_seqlens(seqlens_expanded, self.nsa_index_topk)
             metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
         elif forward_mode.is_draft_extend():
             max_seqlen_k = int(seq_lens_cpu.max().item())
             cache_seqlens = seq_lens.to(torch.int32)
             metadata.cache_seqlens_int32.copy_(cache_seqlens)
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
             )
             page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
             metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
             extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
             seqlens_int32_cpu = [
                 self.speculative_num_draft_tokens + kv_len
                 for kv_len in seq_lens_cpu.tolist()
             ]
             seqlens_expanded = torch.cat(
                 [
                     torch.arange(
                         kv_len - qo_len + 1,
                         kv_len + 1,
                         dtype=torch.int32,
                         device=self.device,
                     )
                     for qo_len, kv_len in zip(
                         extend_seq_lens_cpu,
                         seqlens_int32_cpu,
                         strict=True,
                     )
                 ]
             )
             metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
                 seqlens_expanded
             )
             nsa_cache_seqlens = compute_nsa_seqlens(seqlens_expanded, self.nsa_index_topk)
             metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
                 nsa_cache_seqlens
             )

         seqlens_expanded_size = seqlens_expanded.size(0)
         assert (
             metadata.nsa_cache_seqlens_int32 is not None
             and metadata.nsa_cu_seqlens_k is not None
             and self.nsa_index_topk is not None
         )
-        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
-        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
-        metadata.nsa_cu_seqlens_k[1:].copy_(
+        metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
             torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
         )
         # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
...
...
@@ -451,10 +713,13 @@ class NativeSparseAttnBackend(AttentionBackend):
         assert metadata.real_page_table is metadata.page_table_1
         if NSA_DECODE_IMPL == "flashmla_decode":
-            metadata.flashmla_metadata.copy_(
+            flashmla_metadata = metadata.flashmla_metadata.slice(
+                slice(0, seqlens_expanded_size + 1)
+            )
+            flashmla_metadata.copy_(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
             )
...
...
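In draft-extend replay the query length per request comes from spec_info.accept_length, so the number of expanded rows changes between replays and only the first seqlens_expanded_size rows of the preallocated CUDA-graph buffers are rewritten (and sliced for FlashMLA). A sketch with illustrative numbers (not from the commit):

import torch

accept_length = [3, 1]          # accepted tokens per request in this step
seq_lens_cpu = [12, 20]
draft = 4                       # speculative_num_draft_tokens
seqlens_int32_cpu = [draft + s for s in seq_lens_cpu]  # [16, 24]
seqlens_expanded = torch.cat(
    [
        torch.arange(kv_len - qo_len + 1, kv_len + 1, dtype=torch.int32)
        for qo_len, kv_len in zip(accept_length, seqlens_int32_cpu)
    ]
)
# tensor([14, 15, 16, 24]) -> seqlens_expanded_size == 4, even though the graph
# buffers were captured for bs * speculative_num_draft_tokens == 8 rows.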
@@ -473,10 +738,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        assert (
-            not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
-        ), "NSA backend doesn't support speculative decoding"
         if k is not None:
             assert v is not None
             if save_kv_cache:
...
...
@@ -884,3 +1146,58 @@ class NativeSparseAttnBackend(AttentionBackend):
             flashmla_metadata=flashmla_metadata,
             num_splits=num_splits,
         )


 class NativeSparseAttnMultiStepBackend:
     def __init__(
         self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
     ):
         self.model_runner = model_runner
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 NativeSparseAttnBackend(
                     model_runner,
                     speculative_step_id=i,
                     topk=self.topk,
                     speculative_num_steps=self.speculative_num_steps,
                 )
             )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
                 forward_batch.batch_size * self.topk,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )

     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
                 seq_lens_cpu=forward_batch.seq_lens_cpu,
             )
python/sglang/srt/speculative/draft_utils.py
...
...
@@ -48,6 +48,7 @@ class DraftBackendFactory:
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
+            "nsa": self._create_nsa_decode_backend,
         }

         return self._create_backend(
...
...
@@ -70,6 +71,7 @@ class DraftBackendFactory:
             "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+            "nsa": self._create_nsa_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
...
...
@@ -82,6 +84,20 @@ class DraftBackendFactory:
             "EAGLE is not supported in attention backend {backend_type}",
         )

+    def _create_nsa_decode_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import (
+            NativeSparseAttnMultiStepBackend,
+        )
+
+        return NativeSparseAttnMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_nsa_prefill_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+        return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
+
     def _create_flashinfer_decode_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
...
...
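The two new "nsa" entries route the speculative (EAGLE/MTP) draft worker to the NSA backends: a NativeSparseAttnMultiStepBackend (one NativeSparseAttnBackend per draft step) for decode, and a single NativeSparseAttnBackend for prefill/draft-extend. A rough usage sketch; the helper function name here is hypothetical, only the constructor calls mirror the diff:

from sglang.srt.layers.attention.nsa_backend import (
    NativeSparseAttnBackend,
    NativeSparseAttnMultiStepBackend,
)

def build_nsa_draft_backends(draft_model_runner, topk, speculative_num_steps):
    # Decode path: one NSA backend per draft step, wrapped by the multi-step class.
    decode_backend = NativeSparseAttnMultiStepBackend(
        draft_model_runner, topk, speculative_num_steps
    )
    # Prefill / draft-extend path: a single NSA backend.
    prefill_backend = NativeSparseAttnBackend(draft_model_runner, skip_prefill=False)
    return decode_backend, prefill_backend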
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...
...
@@ -81,6 +81,7 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs

         if self.enable_torch_compile:
             set_torch_compile_config()
...
...
@@ -92,6 +93,7 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros(
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
...
...
@@ -165,6 +167,9 @@ class EAGLEDraftCudaGraphRunner:
         # Graph inputs
         req_pool_indices = self.req_pool_indices[:num_seqs]
         seq_lens = self.seq_lens[:num_seqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
+        extend_seq_lens = self.extend_seq_lens[:num_seqs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
         mrope_positions = self.mrope_positions[:, :num_tokens]
...
...
@@ -227,6 +232,9 @@ class EAGLEDraftCudaGraphRunner:
             input_ids=None,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
...
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
...
...
@@ -78,6 +78,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs

         if self.enable_torch_compile:
             set_torch_compile_config()
...
...
@@ -196,7 +197,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         extend_seq_lens = self.extend_seq_lens[:bs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
...
...
@@ -254,6 +257,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
...
...
@@ -271,6 +275,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             capture_hidden_mode=CaptureHiddenMode.LAST,
             attn_backend=self.eagle_worker.draft_extend_attn_backend,
             extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             padded_static_len=self.padded_static_len,
         )
...
...
@@ -373,6 +378,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu.fill_(self.seq_len_fill_value)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+        if forward_batch.extend_seq_lens_cpu is not None:
+            self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu

         if bs != raw_bs:
             forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
...
...