sglang / Commits / ce6b17c0

[Feature] Support DeepSeek MTP on NPU (#11897)

Unverified commit ce6b17c0, authored Oct 30, 2025 by Even Zhou; committed by GitHub Oct 30, 2025.
Co-authored-by: liupeng374 <liupeng374@huawei.com>
Parent: cafebef1
Changes: 16 changed files with 850 additions and 117 deletions (+850 -117).
  .github/workflows/pr-test-npu.yml                      +6    -2
  python/sglang/srt/layers/attention/ascend_backend.py   +233  -5
  python/sglang/srt/managers/schedule_batch.py           +7    -1
  python/sglang/srt/model_executor/npu_graph_runner.py   +7    -3
  python/sglang/srt/models/deepseek_nextn.py             +11   -2
  python/sglang/srt/models/deepseek_v2.py                +7    -2
  python/sglang/srt/speculative/draft_utils.py           +16   -0
  python/sglang/srt/speculative/eagle_info.py            +42   -36
  python/sglang/srt/speculative/eagle_info_v2.py         +68   -25
  python/sglang/srt/speculative/eagle_utils.py           +261  -16
  python/sglang/srt/speculative/eagle_worker.py          +5    -2
  python/sglang/srt/speculative/eagle_worker_v2.py       +15   -9
  python/sglang/srt/speculative/spec_utils.py            +44   -8
  python/sglang/srt/utils/common.py                      +10   -6
  test/srt/ascend/test_ascend_deepseek_mtp.py            +117  -0
  test/srt/run_suite.py                                  +1    -0
.github/workflows/pr-test-npu.yml

@@ -65,7 +65,7 @@ jobs:
     if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
     runs-on: linux-arm64-npu-2
     strategy:
-      fail-fast: false
+      fail-fast: true
       matrix:
         part: [0, 1, 2]
     container:
@@ -144,6 +144,10 @@ jobs:
   per-commit-16-ascend-a3:
     if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
     runs-on: linux-aarch64-a3-16
+    strategy:
+      fail-fast: true
+      matrix:
+        part: [0, 1]
     container:
       image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
     steps:
@@ -177,4 +181,4 @@ jobs:
       run: |
         export PATH="/usr/local/Ascend/8.3.RC1/compiler/bishengir/bin:${PATH}"
         cd test/srt
-        python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 3600
+        python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2
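The --auto-partition-id / --auto-partition-size flags added above split one suite across the matrix shards. A minimal sketch of such a partitioning scheme, assuming a simple round-robin split (run_suite.py's actual balancing logic may differ):

# Hypothetical round-robin sharding; the real run_suite.py implementation
# may balance by test duration instead.
def auto_partition(files: list, partition_id: int, partition_size: int) -> list:
    # Keep every partition_size-th file, offset by this shard's id.
    return [f for i, f in enumerate(files) if i % partition_size == partition_id]

suite = ["test_a.py", "test_b.py", "test_c.py", "test_d.py"]
print(auto_partition(suite, 0, 2))  # ['test_a.py', 'test_c.py']
print(auto_partition(suite, 1, 2))  # ['test_b.py', 'test_d.py']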
python/sglang/srt/layers/attention/ascend_backend.py

@@ -59,6 +59,19 @@ class AscendAttnBackend(AttentionBackend):
             )
         self.mask_len = max_seq_len

+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers for verify attention kernels that needs to be filled after draft.
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        pass
+
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
         self.forward_metadata = None
@@ -87,15 +100,22 @@ class AscendAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
         )
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
+        self.mtp_mask = ~self.mtp_mask

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         tp_size = get_attention_tp_size()

         self.forward_metadata = ForwardMetadata()
+        seq_lens_max = forward_batch.seq_lens.max()
+        if forward_batch.forward_mode.is_target_verify():
+            seq_lens_max += self.speculative_num_draft_tokens
         self.forward_metadata.block_tables = (
             forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+                forward_batch.req_pool_indices, :seq_lens_max
             ][:, :: self.page_size]
             // self.page_size
         )
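The mtp_mask built in this hunk is an inverted causal mask: torch.tril marks the allowed lower-triangular region, and the bitwise NOT flips it so that True flags positions to exclude, which appears to be the convention the fused NPU attention kernel expects for atten_mask. A small CPU-sized illustration (4x4 instead of 2048x2048):

import torch

# Inverted lower-triangular (causal) mask: True = masked-out position.
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
mtp_mask = ~causal
print(mtp_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])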
@@ -104,16 +124,23 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
+        if (
+            not forward_batch.forward_mode.is_draft_extend_v2()
+            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+        ):
+            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
+        if forward_batch.forward_mode.is_target_verify():
+            self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens

         self.graph_mode = False

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.graph_metadata = {
             "block_tables": torch.empty(
-                (max_bs, self.max_context_len // self.page_size),
+                (max_bs, (self.max_context_len + self.page_size - 1) // self.page_size),
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -156,6 +183,8 @@ class AscendAttnBackend(AttentionBackend):
     ):
         metadata = self.graph_metadata[bs]
         max_len = seq_lens_cpu[:bs].max().item()
+        if forward_mode.is_target_verify():
+            max_len += self.speculative_num_draft_tokens
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size

         metadata.block_tables[:bs, :max_seq_pages].copy_(
@@ -257,6 +286,25 @@ class AscendAttnBackend(AttentionBackend):
                 k_rope,
                 topk_indices,
             )
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend_v2()
+        ):
+            if is_mla_preprocess_enabled():
+                save_kv_cache = False
+            return self.forward_mtp(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
+            )
+
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -393,6 +441,118 @@ class AscendAttnBackend(AttentionBackend):
             )
         return attn_output

+    def forward_mtp(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        k_rope_cache = k_rope.view(
+            -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+        )
+        c_kv_cache = c_kv.view(
+            -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+        )
+        q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
+        q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
+        if not self.graph_mode:
+            num_token_padding = q.shape[0]
+            q_nope = q_nope[: forward_batch.num_token_non_padded_cpu]
+            q_rope = q_rope[: forward_batch.num_token_non_padded_cpu]
+
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list
+        else:
+            actual_seq_lengths_kv = (
+                self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+            )
+        if forward_batch.forward_mode.is_draft_extend():
+            actual_seq_lengths = (
+                np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()
+            )
+        else:
+            actual_seq_lengths = np.arange(
+                self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens + q_nope.shape[0],
+                self.speculative_num_draft_tokens,
+            )
+
+        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+            q_nope,
+            c_kv_cache,
+            c_kv_cache,
+            query_rope=q_rope,
+            key_rope=k_rope_cache,
+            num_heads=layer.tp_q_head_num,
+            num_key_value_heads=layer.tp_k_head_num,
+            input_layout="TND",
+            scale=layer.scaling,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=self.forward_metadata.block_tables,
+            block_size=self.page_size,
+            sparse_mode=3,
+            atten_mask=self.mtp_mask,
+            actual_seq_lengths=actual_seq_lengths,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+        )
+        attn_output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
+        softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+        torch_npu.npu_fused_infer_attention_score.out(
+            q_nope,
+            c_kv_cache,
+            c_kv_cache,
+            query_rope=q_rope,
+            key_rope=k_rope_cache,
+            num_heads=layer.tp_q_head_num,
+            num_key_value_heads=layer.tp_k_head_num,
+            input_layout="TND",
+            scale=layer.scaling,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=self.forward_metadata.block_tables,
+            block_size=self.page_size,
+            sparse_mode=3,
+            atten_mask=self.mtp_mask,
+            actual_seq_lengths=actual_seq_lengths,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            workspace=workspace,
+            out=[attn_output, softmax_lse],
+        )
+        attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        if (
+            not self.graph_mode
+            and forward_batch.num_token_non_padded_cpu != num_token_padding
+        ):
+            attn_output = torch.cat(
+                [
+                    attn_output,
+                    attn_output.new_zeros(
+                        num_token_padding - attn_output.shape[0],
+                        *attn_output.shape[1:],
+                    ),
+                ],
+                dim=0,
+            )
+        return attn_output
+
     def forward_decode_graph(
         self,
         q: torch.Tensor,
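In the target-verify branch of forward_mtp, actual_seq_lengths encodes cumulative query end offsets for the flattened "TND" layout: one entry per request, each a multiple of speculative_num_draft_tokens. A worked example with assumed sizes (3 requests, 4 draft tokens each):

import numpy as np

speculative_num_draft_tokens = 4   # assumed
num_query_tokens = 12              # 3 requests x 4 draft tokens (assumed)
actual_seq_lengths = np.arange(
    speculative_num_draft_tokens,
    speculative_num_draft_tokens + num_query_tokens,
    speculative_num_draft_tokens,
)
print(actual_seq_lengths.tolist())  # [4, 8, 12] -- one end offset per request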
@@ -690,3 +850,71 @@ class AscendAttnBackend(AttentionBackend):
             out=attn_output,
         )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
+
+
+class AscendAttnMultiStepDraftBackend:
+    """
+    Wrap multiple Ascend attention backends as one for multiple consecutive
+    draft decoding steps
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.attn_backends = []
+        for _ in range(self.speculative_num_steps):
+            self.attn_backends.append(AscendAttnBackend(model_runner))
+
+    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs, max_num_tokens):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, call_fn)
python/sglang/srt/managers/schedule_batch.py

@@ -77,6 +77,9 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs, get_global_server_args
 from sglang.srt.utils import flatten_nested_list
+from sglang.srt.utils.common import is_npu
+
+_is_npu = is_npu()

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -1050,7 +1053,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     has_grammar: bool = False

     # Device
-    device: str = "cuda"
+    if not _is_npu:
+        device: str = "cuda"
+    else:
+        device: str = "npu"

     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
python/sglang/srt/model_executor/npu_graph_runner.py

@@ -75,6 +75,10 @@ class NPUGraphRunner(CudaGraphRunner):
         # Replay
         if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
-            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
-                self.bs - self.raw_bs
-            )
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+                seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
+            else:
+                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                    self.bs - self.raw_bs
+                )
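The replay path above pads seq_lens out to the captured graph batch size, and in target-verify mode first adds the per-request draft-token count. A toy illustration with assumed values:

# Captured graph of batch size 4 replayed with 2 live requests (values assumed).
raw_seq_lens = [7, 11]          # live requests
num_tokens_per_bs = 2           # draft tokens per request in verify mode
graph_bs, raw_bs = 4, 2
seq_lens = [s + num_tokens_per_bs for s in raw_seq_lens] + [0] * (graph_bs - raw_bs)
print(seq_lens)  # [9, 13, 0, 0]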
python/sglang/srt/models/deepseek_nextn.py

@@ -38,12 +38,13 @@ from sglang.srt.models.deepseek_v2 import (
     enable_nextn_moe_bf16_cast_to_fp8,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_npu

 logger = logging.getLogger(__name__)

 _is_cuda = is_cuda()
+_is_npu = is_npu()


 class DeepseekModelNextN(nn.Module):
@@ -85,13 +86,21 @@ class DeepseekModelNextN(nn.Module):
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+        layer_name = "decoder"
+        if _is_npu and (
+            get_global_server_args().speculative_draft_model_path
+            == get_global_server_args().model_path
+        ):
+            layer_name = "layers." + str(config.num_hidden_layers)
+
         self.decoder = DeepseekV2DecoderLayer(
             config,
             0,
             quant_config=quant_config,
             moe_quant_config=moe_quant_config,
             is_nextn=True,
-            prefix=add_prefix("decoder", prefix),
+            prefix=add_prefix(layer_name, prefix),
             alt_stream=self.alt_stream,
         )
python/sglang/srt/models/deepseek_v2.py

@@ -290,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
         and not forward_batch.forward_mode.is_draft_extend()
+        and not forward_batch.forward_mode.is_draft_extend_v2()
     ):
         if hasattr(attn, "indexer"):
             return AttnForwardMethod.NPU_MLA_SPARSE
@@ -3753,8 +3754,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         del self.lm_head.weight
         self.model.embed_tokens.weight = embed
         self.lm_head.weight = head
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        if not _is_npu:
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        else:
+            torch.npu.empty_cache()
+            torch.npu.synchronize()

     @classmethod
     def get_model_config_for_expert_location(cls, config):
python/sglang/srt/speculative/draft_utils.py

@@ -49,6 +49,7 @@ class DraftBackendFactory:
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,
+            "ascend": self._create_ascend_decode_backend,
         }

         return self._create_backend(
@@ -72,6 +73,7 @@ class DraftBackendFactory:
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,
+            "ascend": self._create_ascend_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
@@ -173,6 +175,15 @@ class DraftBackendFactory:
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )

+    def _create_ascend_decode_backend(self):
+        from sglang.srt.layers.attention.ascend_backend import (
+            AscendAttnMultiStepDraftBackend,
+        )
+
+        return AscendAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_flashinfer_prefill_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
@@ -219,6 +230,11 @@ class DraftBackendFactory:
         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

+    def _create_ascend_prefill_backend(self):
+        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+        return AscendAttnBackend(self.draft_model_runner)
+
     def _create_flashmla_prefill_backend(self):
         logger.warning(
             "flashmla prefill backend is not yet supported for draft extend."
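For context, the factory hunks above only register the string "ascend" against two constructor methods; dispatch itself is a plain dict lookup. A minimal sketch of that pattern, assuming an sglang environment (only the "ascend" key and the backend class name come from the diff; the rest is illustrative):

class TinyDraftBackendFactory:
    def __init__(self, draft_model_runner, topk, speculative_num_steps):
        self.draft_model_runner = draft_model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

    def _create_ascend_decode_backend(self):
        from sglang.srt.layers.attention.ascend_backend import (
            AscendAttnMultiStepDraftBackend,
        )

        return AscendAttnMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def create_decode_backend(self, name: str):
        # Name-to-constructor registry; unknown names fail loudly.
        constructors = {"ascend": self._create_ascend_decode_backend}
        if name not in constructors:
            raise ValueError(f"unsupported decode backend: {name}")
        return constructors[name]()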
python/sglang/srt/speculative/eagle_info.py

@@ -24,12 +24,13 @@ from sglang.srt.speculative.eagle_info_v2 import (
     EagleDraftInputV2Mixin,
     EagleVerifyInputV2Mixin,
 )
+from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
 from sglang.srt.speculative.spec_utils import (
     SIMULATE_ACC_LEN,
     TREE_SPEC_KERNEL_AVAILABLE,
     align_evict_mask_to_page_size,
-    assign_req_to_token_pool,
+    assign_req_to_token_pool_func,
     create_accept_length_filter,
     create_extend_after_decode_spec_info,
     filter_finished_cache_loc_kernel,
@@ -37,17 +38,16 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
+
+_is_npu = is_npu()

 if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
-        verify_tree_greedy,
     )
-elif is_hip():
-    from sgl_kernel import verify_tree_greedy

 logger = logging.getLogger(__name__)
@@ -77,18 +77,22 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
     @classmethod
     def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        if not _is_npu:
+            device = "cuda"
+        else:
+            device = "npu"
         return cls(
-            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
-            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
-            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            draft_token=torch.empty((0,), dtype=torch.long, device=device),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
+            positions=torch.empty((0,), dtype=torch.int64, device=device),
             retrive_index=torch.full(
-                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+                (0, num_verify_tokens), -1, dtype=torch.long, device=device
             ),
             retrive_next_token=torch.full(
-                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+                (0, num_verify_tokens), -1, dtype=torch.long, device=device
             ),
             retrive_next_sibling=torch.full(
-                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+                (0, num_verify_tokens), -1, dtype=torch.long, device=device
             ),
             retrive_cum_len=None,
             topk=topk,
@@ -134,14 +138,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
         self.last_loc = last_loc

         bs = batch.batch_size()
-        assign_req_to_token_pool[(bs,)](
+        assign_req_to_token_pool_func(
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
             end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
+            bs,
             next_power_of_2(bs),
         )

     def generate_attn_arg_prefill(
@@ -151,16 +154,17 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
         paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
+        device = req_pool_indices.device
         batch_size = len(req_pool_indices)
         qo_indptr = torch.arange(
             0,
             (1 + batch_size) * self.draft_token_num,
             step=self.draft_token_num,
             dtype=torch.int32,
-            device="cuda",
+            device=device,
         )
         cum_kv_seq_len = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
+            (batch_size + 1,), dtype=torch.int32, device=device
         )

         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
@@ -169,7 +173,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
         kv_indices = torch.empty(
             paged_kernel_lens_sum + self.draft_token_num * batch_size,
             dtype=torch.int32,
-            device="cuda",
+            device=device,
         )
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
@@ -226,11 +230,11 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
         predict_shape = list(logits_output.next_token_logits.shape)[:-1]
         predict_shape[-1] += 1
-        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+        predict = torch.empty(predict_shape, dtype=torch.int32, device=batch.device)
         accept_index = torch.full(
-            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device
         )
-        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+        accept_length = torch.empty((bs,), dtype=torch.int32, device=batch.device)

         if bs != len(sampling_info):
             sampling_info = copy.deepcopy(sampling_info)
@@ -254,7 +258,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
             linear_penalty = torch.zeros(
                 (bs, logits_output.next_token_logits.shape[1]),
                 dtype=torch.float32,
-                device="cuda",
+                device=batch.device,
             )
             sampling_info.apply_logits_bias(linear_penalty)
             logits_output.next_token_logits.add_(
@@ -276,11 +280,10 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
                 "Falling back to greedy verification."
             )
-        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE or _is_npu:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
-            verify_tree_greedy(
+            predict, accept_index, accept_length = verify_tree_greedy_func(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
@@ -289,7 +292,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
                 retrive_next_token=self.retrive_next_token,
                 retrive_next_sibling=self.retrive_next_sibling,
                 target_predict=target_predict,
+                topk=self.topk,
             )
         else:
             # apply temperature and get target probs
             expanded_temperature = torch.repeat_interleave(
@@ -315,14 +320,16 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
             target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
             draft_probs = torch.zeros(
-                target_probs.shape, dtype=torch.float32, device="cuda"
+                target_probs.shape, dtype=torch.float32, device=batch.device
             )
             # coins for rejection sampling
-            coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+            coins = torch.rand_like(
+                candidates, dtype=torch.float32, device=batch.device
+            )
             # coins for final sampling
             coins_for_final_sampling = torch.rand(
-                (bs,), dtype=torch.float32, device="cuda"
+                (bs,), dtype=torch.float32, device=batch.device
             )
             tree_speculative_sampling_target_only(
                 predicts=predict,  # mutable
@@ -468,14 +475,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
         if not has_finished:
             if page_size == 1 or self.topk == 1:
                 batch.out_cache_loc = batch.out_cache_loc[accept_index]
-                assign_req_to_token_pool[(bs,)](
+                assign_req_to_token_pool_func(
                     batch.req_pool_indices,
                     batch.req_to_token_pool.req_to_token,
                     batch.seq_lens,
                     batch.seq_lens + accept_length + 1,
                     batch.out_cache_loc,
                     batch.req_to_token_pool.req_to_token.shape[1],
+                    bs,
                     next_power_of_2(bs),
                 )
             else:
                 batch.out_cache_loc = tgt_cache_loc
@@ -501,14 +507,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
             )
         else:
             if page_size == 1 or self.topk == 1:
-                assign_req_to_token_pool[(bs,)](
+                assign_req_to_token_pool_func(
                     batch.req_pool_indices,
                     batch.req_to_token_pool.req_to_token,
                     batch.seq_lens,
                     batch.seq_lens + accept_length + 1,
                     batch.out_cache_loc[accept_index],
                     batch.req_to_token_pool.req_to_token.shape[1],
+                    bs,
                     next_power_of_2(bs),
                 )
                 batch.seq_lens.add_(accept_length + 1)
                 batch.seq_lens_cpu.add_(accept_length_cpu + 1)
@@ -695,17 +700,18 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
+        device = req_pool_indices.device
         bs = self.accept_length.numel()
-        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

         if paged_kernel_lens_sum is None:
             paged_kernel_lens_sum = cum_kv_seq_len[-1]

         kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            paged_kernel_lens_sum, dtype=torch.int32, device=device
         )
         create_flashinfer_kv_indices_triton[(bs,)](
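Most of the eagle_info.py changes follow one pattern: derive the device from a tensor that is already on the right accelerator instead of hard-coding "cuda", so the identical code path serves NPU builds. A self-contained illustration of that pattern (the function name is hypothetical):

import torch

def make_qo_indptr(req_pool_indices: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Inherit the device (cuda, npu, or cpu) from the input tensor.
    device = req_pool_indices.device
    return torch.zeros((batch_size + 1,), dtype=torch.int32, device=device)

indices = torch.arange(3)                  # cpu tensor here; npu/cuda in production
print(make_qo_indptr(indices, 3).device)   # cpu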
python/sglang/srt/speculative/eagle_info_v2.py

@@ -23,11 +23,16 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
 from sglang.srt.speculative.spec_utils import (
     SIMULATE_ACC_LEN,
     generate_simulated_accept_index,
 )
-from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()

 if TYPE_CHECKING:
     from sglang.srt.managers.tp_worker import TpModelWorker
@@ -41,11 +46,8 @@ if is_cuda():
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
-        verify_tree_greedy,
     )
     from sgl_kernel.top_k import fast_topk
-elif is_hip():
-    from sgl_kernel import verify_tree_greedy


 @triton.jit
@@ -78,7 +80,7 @@ def assign_draft_cache_locs_page_size_1(
 @dataclass
 class EagleDraftInputV2Mixin:
     def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
-        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
+        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func

         bs = batch.batch_size()
@@ -112,15 +114,15 @@ class EagleDraftInputV2Mixin:
             extend_num_tokens,
         )
-        assign_req_to_token_pool[(bs,)](
+        assign_req_to_token_pool_func(
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             self.allocate_lens,
             new_allocate_lens,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
+            bs,
             next_power_of_2(bs),
         )
         self.allocate_lens = new_allocate_lens

         # FIXME(lsyin): make this sync optional
@@ -199,22 +201,16 @@ class EagleVerifyInputV2Mixin:
         bs = len(batch.req_pool_indices)
         batch.input_ids = self.draft_token
         device = batch.input_ids.device
-        batch.out_cache_loc = torch.empty(
-            (bs * self.draft_token_num,),
-            dtype=torch.int64,
-            device=device,
-        )
-        assign_extend_cache_locs[(bs,)](
-            batch.req_pool_indices,
-            req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
-            batch.out_cache_loc,
-            req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
+        batch.out_cache_loc = assign_extend_cache_locs_func(
+            req_pool_indices=batch.req_pool_indices,
+            req_to_token=req_to_token_pool.req_to_token,
+            start_offset=batch.seq_lens,
+            end_offset=batch.seq_lens + self.draft_token_num,
+            batch_size=bs,
+            draft_token_num=self.draft_token_num,
+            device=device,
+        )

         # Get a forward batch
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.capture_hidden_mode = CaptureHiddenMode.FULL
@@ -258,11 +254,10 @@ class EagleVerifyInputV2Mixin:
         accept_length = torch.empty((bs,), dtype=torch.int32, device=device)

         # Sample tokens
-        if sampling_info.is_all_greedy:
+        if sampling_info.is_all_greedy or _is_npu:
             target_predict = torch.argmax(next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
-            verify_tree_greedy(
+            predict, accept_index, accept_length = verify_tree_greedy_func(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
@@ -271,6 +266,7 @@ class EagleVerifyInputV2Mixin:
                 retrive_next_token=self.retrive_next_token,
                 retrive_next_sibling=self.retrive_next_sibling,
                 target_predict=target_predict,
+                topk=self.topk,
             )
         else:
             # Apply temperature and get target probs
@@ -338,7 +334,7 @@ class EagleVerifyInputV2Mixin:
     return predict, accept_length, accept_index


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def select_top_k_tokens_tmp(
     i: int,
     topk_p: torch.Tensor,
@@ -456,3 +452,50 @@ def assign_extend_cache_locs(
         tl.store(out_cache_ptr + save_offset, data, mask=mask)
         load_offset += BLOCK_SIZE
         save_offset += BLOCK_SIZE
+
+
+def assign_extend_cache_locs_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    batch_size: int,
+    draft_token_num: int,
+    device,
+) -> torch.Tensor:
+    if _is_cuda or _is_hip:
+        out_cache_loc = torch.empty(
+            (batch_size * draft_token_num,),
+            dtype=torch.int64,
+            device=device,
+        )
+        assign_extend_cache_locs[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+        return out_cache_loc
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401

+        out_cache_loc = torch.empty(
+            (batch_size * draft_token_num,),
+            dtype=torch.int32,
+            device=device,
+        )
+        torch.ops.npu.cache_loc_update(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
+        out_cache_loc = out_cache_loc.to(dtype=torch.int64)
+        return out_cache_loc
python/sglang/srt/speculative/eagle_utils.py

@@ -4,14 +4,128 @@ from typing import List, Optional
 import torch

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu

-if is_cuda() or is_hip():
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()
+
+if _is_cuda or _is_hip:
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
+
+
+def build_tree_efficient_native(
+    parent_list: torch.Tensor,
+    selected_index: torch.Tensor,
+    verified_seq_len: torch.Tensor,
+    tree_mask: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    topk: int,
+    draft_token_num: int,
+    tree_mask_mode: int,
+    bs: int,
+):
+    # Generate batch and token index ranges
+    bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
+    draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
+
+    # Optimized common case for performance.
+    if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
+        positions = verified_seq_len.repeat_interleave(draft_token_num)
+        positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
+
+        retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+        retrive_next_token[:, 0] = 1
+        retrive_next_token[:, 1] = -1
+        return (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        )
+
+    # Precompute sequence tree indices
+    draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
+    cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
+    cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
+    cum_seq_len = cum_seq_len[:-1]
+    seq_tree_idx = (
+        draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
+        + cum_seq_len
+    )
+
+    # Batch processing for tree mask
+    if tree_mask_mode == TreeMaskMode.FULL_MASK:
+        token_tree_base = (
+            seq_tree_idx.view(-1, 1)
+            + (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
+        )
+        token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
+    else:
+        token_tree_indices = (
+            bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
+        )
+    tree_mask[token_tree_indices.flatten() - 1] = True
+    indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
+    tree_mask[indices.view(-1)] = False
+
+    positions = verified_seq_len.repeat_interleave(draft_token_num)
+    parent_tb_indices = selected_index // topk
+    retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+    tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
+
+    for bid in range(bs):
+        for tid in range(draft_token_num):
+            position = 0
+            if tid == 0:
+                # Process root node
+                for i in range(draft_token_num - 1, 0, -1):
+                    parent_position = 0
+                    parent_tb_idx = parent_tb_indices[bid][i - 1]
+                    if parent_tb_idx > 0:
+                        parent_token_idx = parent_list[bid][parent_tb_idx]
+                        loop_num = draft_token_num - parent_position
+                        for _ in range(loop_num):
+                            if selected_index[bid][parent_position] == parent_token_idx:
+                                parent_position += 1
+                                break
+                            parent_position += 1
+                    if parent_position == draft_token_num:
+                        continue
+                    if retrive_next_token[bid][parent_position] != -1:
+                        retrive_next_sibling[bid][i] = retrive_next_token[bid][
+                            parent_position
+                        ]
+                    retrive_next_token[bid][parent_position] = i
+            else:
+                # Process no-root nodes
+                cur_position = tid - 1
+                while True:
+                    position += 1
+                    if cur_position >= draft_token_num:
+                        tree_mask[token_tree_indices + cur_position] = True
+                        parent_tb_idx = selected_index[bid][cur_position] // topk
+                    else:
+                        parent_tb_idx = parent_tb_indices[bid][cur_position]
+                    if parent_tb_idx == 0:
+                        break
+                    token_idx = parent_list[bid][parent_tb_idx]
+                    cur_position = 0
+                    for _ in range(draft_token_num):
+                        if selected_index[bid][cur_position] == token_idx:
+                            break
+                        cur_position += 1
+            positions[bid * draft_token_num + tid] += position
+    return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask


 def organize_draft_results(
     score_list: List[torch.Tensor],
     token_list: List[torch.Tensor],
@@ -114,6 +228,27 @@ def build_tree_kernel_efficient(
         (bs * num_verify_tokens,), device=device, dtype=torch.long
     )

+    if _is_npu:
+        (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        ) = build_tree_efficient_native(
+            parent_list,
+            top_scores_index,
+            seq_lens,
+            tree_mask,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            topk,
+            num_verify_tokens,
+            tree_mask_mode,
+            bs,
+        )
+    else:
-    sgl_build_tree_kernel_efficient(
+        sgl_build_tree_kernel_efficient(
             parent_list,
             top_scores_index,
@@ -136,3 +271,113 @@ def build_tree_kernel_efficient(
         retrive_next_sibling,
         draft_tokens,
     )
+
+
+def verify_tree_greedy_native(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    batch_size, num_draft_tokens = candidates.shape
+
+    # Optimized common case for performance.
+    if num_draft_tokens == 2 and accept_index.shape[1] == 2 and topk == 1:
+        comparison_result = candidates[:, 1] == target_predict[:, 0]
+        predicts = target_predict.flatten()
+        accept_index = torch.arange(
+            0,
+            num_draft_tokens * batch_size,
+            device=candidates.device,
+            dtype=torch.long,
+        ).reshape(batch_size, num_draft_tokens)
+        comparison_result = comparison_result.to(torch.int64)
+        accept_index_mask = accept_index[:, 1] * comparison_result
+        accept_index[:, 1] = accept_index_mask - (1 - comparison_result)
+        accept_token_num = comparison_result.int()
+        return predicts, accept_index, accept_token_num
+
+    # BFS
+    for bx in range(batch_size):
+        cur_candidates = candidates[bx]
+        cur_retrive_index = retrive_index[bx]
+        cur_next_token = retrive_next_token[bx]
+        cur_next_sibling = retrive_next_sibling[bx]
+        cur_target = target_predict[bx]
+
+        last_accepted_idx = cur_retrive_index[0]
+        accept_index[bx, 0] = last_accepted_idx
+        num_accepted = 0
+        cur_node = 0
+
+        for _ in range(1, num_draft_tokens):
+            cur_node = cur_next_token[cur_node]
+            found = False
+            while cur_node != -1:
+                draft_idx = cur_retrive_index[cur_node]
+                draft_token = cur_candidates[cur_node]
+                target_token = cur_target[last_accepted_idx - num_draft_tokens * bx]
+
+                if draft_token == target_token:
+                    predicts[last_accepted_idx] = target_token
+                    num_accepted += 1
+                    accept_index[bx, num_accepted] = draft_idx
+                    last_accepted_idx = draft_idx
+                    found = True
+                    break
+                else:
+                    cur_node = cur_next_sibling[cur_node]
+            if not found:
+                break
+
+        accept_token_num[bx] = num_accepted
+        predicts[last_accepted_idx] = cur_target[
+            last_accepted_idx - num_draft_tokens * bx
+        ]
+    return predicts, accept_index, accept_token_num
+
+
+def verify_tree_greedy_func(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    if _is_cuda or _is_hip:
+        from sgl_kernel import verify_tree_greedy
+
+        verify_tree_greedy(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+        )
+    elif _is_npu:
+        predicts, accept_index, accept_token_num = verify_tree_greedy_native(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+            topk=topk,
+        )
+    return predicts, accept_index, accept_token_num
python/sglang/srt/speculative/eagle_worker.py

@@ -53,9 +53,12 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     is_cuda,
+    is_npu,
     next_power_of_2,
 )

+_is_npu = is_npu()
+
 if is_cuda():
     from sgl_kernel import segment_packbits  # noqa: F401
@@ -205,7 +208,7 @@ class EAGLEWorker(TpModelWorker):
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None

-        if self.server_args.disable_cuda_graph:
+        if self.server_args.disable_cuda_graph or _is_npu:
             return

         # Capture draft
@@ -945,7 +948,7 @@ class EAGLEWorker(TpModelWorker):
         draft_input.hidden_states = logits_output.hidden_states


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def get_last_loc_large_page_size_top_k_1(
     req_to_token: torch.Tensor,
     req_pool_indices: torch.Tensor,
python/sglang/srt/speculative/eagle_worker_v2.py
View file @
ce6b17c0
...
@@ -4,7 +4,6 @@ import time
...
@@ -4,7 +4,6 @@ import time
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch.cuda
import
Stream
as
CudaStream
from
sglang.srt.environ
import
envs
from
sglang.srt.environ
import
envs
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
...
@@ -38,18 +37,21 @@ from sglang.srt.utils.common import (
...
@@ -38,18 +37,21 @@ from sglang.srt.utils.common import (
empty_context
,
empty_context
,
fast_topk
,
fast_topk
,
get_available_gpu_memory
,
get_available_gpu_memory
,
is_npu
,
next_power_of_2
,
next_power_of_2
,
)
)
_is_npu
=
is_npu
()
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
_get_plan_stream
(
def
_get_plan_stream
(
device
:
str
,
device
:
str
,
)
->
Tuple
[
Optional
[
CudaStream
]
,
contextlib
.
AbstractContextManager
]:
)
->
Tuple
[
any
,
contextlib
.
AbstractContextManager
]:
if
envs
.
SGLANG_ENABLE_OVERLAP_PLAN_STREAM
.
get
():
if
envs
.
SGLANG_ENABLE_OVERLAP_PLAN_STREAM
.
get
():
plan_stream
:
CudaStream
=
torch
.
get_device_module
(
device
).
Stream
()
plan_stream
=
torch
.
get_device_module
(
device
).
Stream
()
plan_stream_ctx
=
torch
.
cuda
.
stream
(
plan_stream
)
plan_stream_ctx
=
torch
.
get_device_module
(
device
)
.
stream
(
plan_stream
)
return
plan_stream
,
plan_stream_ctx
return
plan_stream
,
plan_stream_ctx
else
:
else
:
return
None
,
contextlib
.
nullcontext
()
return
None
,
contextlib
.
nullcontext
()
...
@@ -206,7 +208,7 @@ class EagleDraftWorker(BaseDraftWorker):
...
@@ -206,7 +208,7 @@ class EagleDraftWorker(BaseDraftWorker):
self
.
cuda_graph_runner
=
None
self
.
cuda_graph_runner
=
None
self
.
cuda_graph_runner_for_draft_extend
=
None
self
.
cuda_graph_runner_for_draft_extend
=
None
if
self
.
server_args
.
disable_cuda_graph
:
if
self
.
server_args
.
disable_cuda_graph
or
_is_npu
:
return
return
# Capture draft
# Capture draft
...
@@ -456,7 +458,9 @@ class EagleDraftWorker(BaseDraftWorker):
         )

         if self.plan_stream:
-            torch.cuda.current_stream().wait_stream(self.plan_stream)
+            torch.get_device_module(self.device).current_stream().wait_stream(
+                self.plan_stream
+            )

         # Run draft extend batch in the main compute stream
         draft_logits_output = self.draft_runner.model.forward(
...
@@ -577,7 +581,9 @@ class EAGLEWorkerV2(BaseSpecWorker):
         # Since batch.seq_lens is allocated in another stream, we need
         # record_stream() to prevent pytorch gc and reuse the gpu memory
         # while forward_stream is still running.
-        batch.seq_lens.record_stream(torch.cuda.current_stream())
+        batch.seq_lens.record_stream(
+            torch.get_device_module(self.device).current_stream()
+        )

         # Parse args
         verify_input: EagleVerifyInput = batch.spec_info
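Why `record_stream` matters, as a self-contained CUDA sketch (guarded so it only runs where a GPU is present): a tensor freed on one stream must not have its memory recycled while another stream may still be using it.

    import torch

    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        t = torch.empty(1 << 20, device="cuda")
        with torch.cuda.stream(side):
            t.add_(1)  # asynchronous work on the side stream
        # Mark t as in use by `side`; without this, `del t` lets the caching
        # allocator hand the memory to a new tensor while side still writes it.
        t.record_stream(side)
        del t
        torch.cuda.synchronize()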
...
@@ -596,7 +602,7 @@ class EAGLEWorkerV2(BaseSpecWorker):
         # Correct some buffers due to the overlap plan
         if self.plan_stream:
-            torch.cuda.current_stream().wait_stream(self.plan_stream)
+            torch.get_device_module().current_stream().wait_stream(self.plan_stream)

         # Some values such as custom_mask and position depend on the output of draft,
         # so the previous plan step used the wrong values. Here, we need to run the related
...
@@ -628,7 +634,7 @@ class EAGLEWorkerV2(BaseSpecWorker):
             accept_index,
         ) = verify_input.sample(batch, logits_output)
         new_seq_lens = batch.seq_lens + accept_length
-        verify_done = torch.cuda.Event()
+        verify_done = torch.get_device_module(self.device).Event()
         verify_done.record()

         all_verified_id = predict[accept_index]
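Events resolve through `torch.get_device_module` the same way streams do, so the wait on verification needs no direct `torch.cuda` reference. A guarded sketch:

    import torch

    if torch.cuda.is_available():
        device = "cuda"  # would be "npu" on Ascend builds
        verify_done = torch.get_device_module(device).Event()
        verify_done.record()       # marks this point on the current stream
        verify_done.synchronize()  # host-side wait until the mark is reached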
...
python/sglang/srt/speculative/spec_utils.py View file @ ce6b17c0
...
@@ -19,16 +19,22 @@ from sglang.srt.distributed.parallel_state import (
 from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_info import EagleVerifyInput

-if is_cuda():
+if _is_cuda:
     from sgl_kernel import fast_topk
-elif is_hip():
+elif _is_hip:
     from sgl_kernel import fast_topk
+else:
+    from sglang.srt.utils.common import fast_topk

 logger = logging.getLogger(__name__)
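The portable fallback imported from `sglang.srt.utils.common` plausibly looks like the following sketch (an assumption about its shape, not a copy of the shipped code); both branches return `(values, indices)` like `torch.topk`:

    import torch

    def fast_topk_sketch(values: torch.Tensor, topk: int, dim: int):
        if topk == 1:
            # max is cheaper than a general top-k for the common k == 1 path
            return torch.max(values, dim=dim, keepdim=True)
        return torch.topk(values, topk, dim=dim)

    probs = torch.tensor([[0.1, 0.7, 0.2]])
    print(fast_topk_sketch(probs, 1, dim=-1))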
...
@@ -39,7 +45,7 @@ SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
 SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()

 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
-TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
+TREE_SPEC_KERNEL_AVAILABLE = _is_cuda  # This kernel is only available for CUDA now


 @triton.jit
...
@@ -103,6 +109,36 @@ def assign_req_to_token_pool(
         load_offset += BLOCK_SIZE


+def assign_req_to_token_pool_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    batch_size: int,
+):
+    if _is_cuda or _is_hip:
+        assign_req_to_token_pool[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401
+
+        torch.ops.npu.cache_loc_assign(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
+
+
 @triton.jit
 def assign_draft_cache_locs(
     req_pool_indices,
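A pure-PyTorch reference for what both backends compute, as a sketch for illustration (the production paths are the Triton kernel and the `sgl_kernel_npu` custom op): each request writes its newly allocated KV-cache slots into its row of the request-to-token map, consuming `out_cache_loc` in batch order.

    import torch

    def assign_req_to_token_pool_ref(
        req_pool_indices, req_to_token, start_offset, end_offset, out_cache_loc
    ):
        cursor = 0
        for i in range(req_pool_indices.numel()):
            row = req_pool_indices[i].item()
            s, e = int(start_offset[i]), int(end_offset[i])
            req_to_token[row, s:e] = out_cache_loc[cursor : cursor + (e - s)]
            cursor += e - s

    req_to_token = torch.zeros(4, 16, dtype=torch.int64)
    assign_req_to_token_pool_ref(
        torch.tensor([2, 0]),      # pool rows for the two requests
        req_to_token,
        torch.tensor([0, 3]),      # current sequence lengths
        torch.tensor([5, 6]),      # lengths after this step
        torch.arange(100, 108),    # 8 freshly allocated cache slots
    )
    print(req_to_token[2, :5], req_to_token[0, 3:6])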
...
@@ -331,7 +367,7 @@ def get_target_cache_loc(
 )


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def get_src_tgt_cache_loc(
     seq_lens: torch.Tensor,
     out_cache_loc: torch.Tensor,
...
@@ -381,7 +417,7 @@ def filter_finished_cache_loc_kernel(
 )


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def create_accept_length_filter(
     accept_length: torch.Tensor,
     unfinished_index_device: torch.Tensor,
...
@@ -395,7 +431,7 @@ def create_accept_length_filter(
     return accept_length_filter


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
...
@@ -413,7 +449,7 @@ def select_top_k_tokens(
         tree_info = (
             topk_p.unsqueeze(1),  # shape: (b, 1, topk)
             topk_index,  # shape: (b, topk)
-            torch.arange(-1, topk, dtype=torch.long, device="cuda")
+            torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
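Deriving the device from a tensor already on the accelerator is the standard way to keep such helpers portable; a minimal sketch:

    import torch

    hidden_states = torch.zeros(2, 8)  # cpu here; cuda or npu in production
    topk = 4
    offsets = (
        torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
        .unsqueeze(0)
        .repeat(hidden_states.shape[0], 1)
    )  # shape: (b, topk + 1), allocated wherever hidden_states lives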
...
python/sglang/srt/utils/common.py View file @ ce6b17c0
...
@@ -3106,12 +3106,16 @@ def apply_module_patch(target_module, target_function, wrappers):
     setattr(original_module, target_function, candidate)

     for key, value in sys.modules.copy().items():
-        if (
-            target_function is not None
-            and hasattr(value, target_function)
-            and id(getattr(value, target_function)) == original_function_id
-        ):
-            setattr(value, target_function, candidate)
+        try:
+            if (
+                target_function is not None
+                and hasattr(value, target_function)
+                and id(getattr(value, target_function)) == original_function_id
+            ):
+                setattr(value, target_function, candidate)
+        except ImportError as e:
+            # Ignore some modules reporting ImportError when calling hasattr
+            logger.warning(f"Ignore {value} reports ImportError with:\n{str(e)}")


 def parse_module_path(module_path, function_name, create_dummy):
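The guard exists because `hasattr` invokes a module's `__getattr__`, and modules that lazily import optional dependencies can raise `ImportError` from there, aborting the whole scan. A self-contained sketch of the pattern:

    import logging
    import sys

    logger = logging.getLogger(__name__)

    def modules_defining(attr: str):
        hits = []
        for name, mod in sys.modules.copy().items():
            try:
                # hasattr may run a lazy import inside mod.__getattr__
                if mod is not None and hasattr(mod, attr):
                    hits.append(name)
            except ImportError as e:
                logger.warning("Ignore %s reports ImportError with:\n%s", name, e)
        return hits

    print(len(modules_defining("getLogger")))  # logging and its re-exporters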
...
test/srt/ascend/test_ascend_deepseek_mtp.py 0 → 100644 View file @ ce6b17c0
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
        "accuracy": 0.95,
        "latency": 1000,
        "output_throughput": 6,
    },
}


class TestAscendDeepSeekMTP(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
            "--mem-fraction-static",
            0.8,
            "--disable-radix-cache",
            "--chunked-prefill-size",
            32768,
            "--tp-size",
            16,
            "--speculative-algorithm",
            "NEXTN",
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
        ]
        cls.extra_envs = {
            "SGLANG_NPU_USE_MLAPO": "1",
        }
        os.environ.update(cls.extra_envs)

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=1500,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(args)

                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py View file @ ce6b17c0
...
@@ -359,6 +359,7 @@ suite_ascend = {
     ],
     "per-commit-16-ascend-a3": [
         TestFile("ascend/test_ascend_deepep.py", 400),
+        TestFile("ascend/test_ascend_deepseek_mtp.py", 400),
     ],
 }
...