sglang · Commits · de2dd738

Unverified commit de2dd738, authored Aug 20, 2025 by Even Zhou; committed by GitHub on Aug 20, 2025
Parent: 1ec97697

Revert "[feature] Rework Ascend NPU graph support" (#9385)

Showing 18 changed files with 81 additions and 546 deletions (+81, −546)
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py    +1   −1
python/sglang/srt/distributed/parallel_state.py                            +4   −10
python/sglang/srt/layers/attention/ascend_backend.py                       +22  −132
python/sglang/srt/mem_cache/memory_pool.py                                 +1   −1
python/sglang/srt/model_executor/cuda_graph_runner.py                      +19  −32
python/sglang/srt/model_executor/cuda_graph_runner_impl.py                 +0   −36
python/sglang/srt/model_executor/model_runner.py                           +11  −19
python/sglang/srt/model_executor/npu_graph_runner.py                       +0   −94
python/sglang/srt/models/deepseek_v2.py                                    +1   −1
python/sglang/srt/models/glm4_moe.py                                       +1   −1
python/sglang/srt/models/mllama.py                                         +1   −1
python/sglang/srt/models/qwen3.py                                          +1   −1
python/sglang/srt/models/qwen3_moe.py                                      +1   −1
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py             +9   −11
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py      +9   −11
test/srt/ascend/test_ascend_graph_tp1_bf16.py                              +0   −95
test/srt/ascend/test_ascend_graph_tp2_bf16.py                              +0   −97
test/srt/run_suite.py                                                      +0   −2
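Note for readers: the feature being reverted had split graph capture into a device-neutral GraphRunner base class with thin CudaGraphRunner and NPUGraphRunner subclasses (the deleted cuda_graph_runner_impl.py and npu_graph_runner.py below); this commit folds everything back into a single, CUDA-only CudaGraphRunner. A condensed sketch of the shape of the abstraction being removed, simplified from the deleted files (not the full implementation):

import torch


class GraphRunner:
    """Device-neutral base (sketch): resolve the device module once and use it
    for streams, synchronization, memory pools, and graph capture."""

    def __init__(self, device: str):
        self.device = device
        self.device_module = torch.get_device_module(device)  # torch.cuda or torch.npu

    def _create_device_graph(self):
        raise NotImplementedError

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        # torch.cuda.graph and torch_npu's torch.npu.graph accept graph/pool/stream.
        with self.device_module.graph(graph, pool=pool, stream=stream):
            return run_once_fn()


class CudaGraphRunner(GraphRunner):
    def _create_device_graph(self):
        return torch.cuda.CUDAGraph()


class NPUGraphRunner(GraphRunner):
    def _create_device_graph(self):
        return torch.npu.NPUGraph()  # requires the torch_npu (Ascend) extension

After the revert, only the CUDA path remains: the subclass hooks (_create_device_graph, _capture_graph, _cache_loc_dtype) disappear and torch.cuda.* calls are inlined again, as the per-file diffs below show.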
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py

@@ -9,7 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe as fused_moe_triton,
 )
-from sglang.srt.model_executor.graph_runner import set_torch_compile_config
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config


 def get_model_config(model_name: str, tp_size: int):
python/sglang/srt/distributed/parallel_state.py

@@ -55,7 +55,7 @@ _is_npu = is_npu()
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+    stream: torch.cuda.Stream


 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

@@ -252,13 +252,9 @@ class GroupCoordinator:
         if is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
-        elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
         else:
             self.device = torch.device("cpu")
-        self.device_module = torch.get_device_module(self.device)

         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce

@@ -406,7 +402,7 @@ class GroupCoordinator:
         self, graph_capture_context: Optional[GraphCaptureContext] = None
     ):
         if graph_capture_context is None:
-            stream = self.device_module.Stream()
+            stream = torch.cuda.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream

@@ -417,11 +413,11 @@ class GroupCoordinator:
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
-        curr_stream = self.device_module.current_stream()
+        curr_stream = torch.cuda.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)

-        with self.device_module.stream(stream), maybe_ca_context:
+        with torch.cuda.stream(stream), maybe_ca_context:
             # In graph mode, we have to be very careful about the collective
             # operations. The current status is:
             # allreduce \ Mode   |  Eager  |  Graph  |

@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         )
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         torch.xpu.empty_cache()
-    elif hasattr(torch, "npu") and torch.npu.is_available():
-        torch.npu.empty_cache()


 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
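For context, the lines removed above are the device-portability shims in GroupCoordinator: device selection and stream handling went through a cached device module instead of torch.cuda directly. A minimal sketch of that pattern (hypothetical helper, assuming a PyTorch version that provides torch.get_device_module):

import torch


def capture_stream_for(device: torch.device):
    """Return a fresh stream on the right backend for graph capture (sketch)."""
    device_module = torch.get_device_module(device)  # e.g. torch.cuda for "cuda:0"
    stream = device_module.Stream()
    # Let the capture stream wait for any pending work on the current stream.
    stream.wait_stream(device_module.current_stream())
    return stream

After this revert the equivalent call sites are CUDA-only again: torch.cuda.Stream(), torch.cuda.current_stream(), and torch.cuda.stream(stream).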
python/sglang/srt/layers/attention/ascend_backend.py

 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 import torch
 import torch_npu

@@ -27,7 +27,6 @@ class ForwardMetadata:
     # seq len inputs
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
-    seq_lens_cpu_list: Optional[List[int]] = None


 class AscendAttnBackend(AttentionBackend):

@@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.forward_metadata = None
+        self.forward_metadata = ForwardMetadata()
         self.device = model_runner.device
         self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size

@@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend):
         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.native_attn = TorchNativeAttnBackend(model_runner)
-        self.graph_metadata = {}
-        self.max_context_len = model_runner.model_config.context_len
-        self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.graph_mode = False

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
-        self.forward_metadata = ForwardMetadata()
         self.forward_metadata.block_tables = (
             forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : forward_batch.seq_lens.max()

@@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

-        self.graph_mode = False
-
-    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        self.graph_metadata = {
-            "block_tables": torch.empty(
-                (max_bs, self.max_context_len // self.page_size),
-                dtype=torch.int32,
-                device=self.device,
-            ),
-        }
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        num_tokens: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
-        forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
-    ):
-        metadata = ForwardMetadata()
-
-        metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
-        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
-
-        self.graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
-
-        self.graph_mode = True
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
-        forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
-        seq_lens_cpu: Optional[torch.Tensor],
-    ):
-        metadata = self.graph_metadata[bs]
-        max_len = seq_lens_cpu[:bs].max().item()
-        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-
-        metadata.block_tables[:bs, :max_seq_pages].copy_(
-            self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
-            // self.page_size
-        )
-        metadata.block_tables[:bs, max_seq_pages:].fill_(0)
-        metadata.block_tables[bs:, :].fill_(0)
-
-        self.forward_metadata = metadata
-
-        self.graph_mode = True
-
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

@@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend):
             layer, forward_batch.out_cache_loc, k, v
         )
         if not self.use_mla:
-            if self.graph_mode:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                workspace = (
-                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        query,
-                        k_cache,
-                        v_cache,
-                        block_table=self.forward_metadata.block_tables,
-                        block_size=self.page_size,
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSH",
-                        scale=layer.scaling,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    )
-                )
-                output = torch.empty(
-                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query,
-                    k_cache,
-                    v_cache,
-                    block_table=self.forward_metadata.block_tables,
-                    block_size=self.page_size,
-                    num_heads=layer.tp_q_head_num,
-                    num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSH",
-                    scale=layer.scaling,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    workspace=workspace,
-                    out=[output, softmax_lse],
-                )
-            else:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                output = torch.empty(
-                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    scale_value=layer.scaling,
-                    block_table=self.forward_metadata.block_tables,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    out=output,
-                )
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
             return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
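The bulk of the deletion above is the graph capture/replay metadata path: a preallocated block table per captured batch size, refreshed in place at replay time from the request-to-token map. A condensed sketch of that replay-time update, extracted from the deleted init_forward_metadata_replay_cuda_graph (tensor names follow the original; simplified):

import torch


def refresh_block_tables(
    block_tables: torch.Tensor,   # (max_bs, max_context_len // page_size), int32
    req_to_token: torch.Tensor,   # (num_reqs, max_context_len) token slot indices
    req_pool_indices: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    bs: int,
    page_size: int,
) -> None:
    max_len = int(seq_lens_cpu[:bs].max().item())
    max_seq_pages = (max_len + page_size - 1) // page_size
    # One entry per page: take every page_size-th token slot and convert it to a page id.
    pages = req_to_token[req_pool_indices[:bs], :max_len][:, ::page_size] // page_size
    block_tables[:bs, :max_seq_pages].copy_(pages)
    # Zero the unused tail so the captured graph always sees stable padding.
    block_tables[:bs, max_seq_pages:].fill_(0)
    block_tables[bs:, :].fill_(0)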
python/sglang/srt/mem_cache/memory_pool.py

@@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache):
         v_scale: Optional[float] = None,
         layer_id_override: Optional[int] = None,
     ):
-        from sglang.srt.model_executor.graph_runner import get_is_capture_mode
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if layer_id_override is not None:
             layer_id = layer_id_override
python/sglang/srt/model_executor/graph_runner.py → python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Run the model with device graph and torch.compile."""
+"""Run the model with cuda graph and torch.compile."""
 from __future__ import annotations

@@ -221,7 +221,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     return capture_bs, compile_bs


-# Reuse this memory pool across all device graph runners.
+# Reuse this memory pool across all cuda graph runners.
 global_graph_memory_pool = None

@@ -234,14 +234,12 @@ def set_global_graph_memory_pool(val):
     global_graph_memory_pool = val


-class GraphRunner:
-    """A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile."""
+class CudaGraphRunner:
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

     def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
-        self.device = model_runner.device
-        self.device_module = torch.get_device_module(self.device)
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile

@@ -267,7 +265,7 @@ class GraphRunner:
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-        rank0_log(f"Capture graph bs {self.capture_bs}")
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1

@@ -307,15 +305,13 @@ class GraphRunner:
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

         # Graph inputs
-        with torch.device(self.device):
+        with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros(
-                (self.max_num_token,), dtype=self._cache_loc_dtype()
-            )
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
             self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)

@@ -370,12 +366,12 @@ class GraphRunner:
                     * self.num_tokens_per_bs
                 ),
                 dtype=torch.bool,
-                device=self.device,
+                device="cuda",
             )
             self.next_token_logits_buffer = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.vocab_size),
                 dtype=torch.float,
-                device=self.device,
+                device="cuda",
             )

         # Capture

@@ -384,12 +380,9 @@ class GraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

-    def _cache_loc_dtype(self):
-        return torch.int64
-
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (

@@ -509,16 +502,8 @@ class GraphRunner:
             )
         logger.info(log_message)

-    def _capture_graph(self, graph, pool, stream, run_once_fn):
-        with self.device_module.graph(graph, pool=pool, stream=stream):
-            out = run_once_fn()
-        return out
-
-    def _create_device_graph(self):
-        pass
-
     def capture_one_batch_size(self, bs: int, forward: Callable):
-        graph = self._create_device_graph()
+        graph = torch.cuda.CUDAGraph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs

@@ -658,17 +643,19 @@ class GraphRunner:
             return logits_output_or_pp_proxy_tensors

         for _ in range(2):
-            self.device_module.synchronize()
+            torch.cuda.synchronize()
             self.model_runner.tp_group.barrier()
             run_once()

         if get_global_graph_memory_pool() is None:
-            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
+            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
         # Set graph pool id globally to be able to use symmetric memory
         set_graph_pool_id(get_global_graph_memory_pool())
-        out = self._capture_graph(
-            graph, get_global_graph_memory_pool(), stream, run_once
-        )
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
+            out = run_once()

         return graph, out

@@ -850,7 +837,7 @@ class GraphRunner:
         return spec_info


-GRAPH_CAPTURE_FAILED_MSG = (
+CUDA_GRAPH_CAPTURE_FAILED_MSG = (
     "Possible solutions:\n"
     "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
     "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
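The net effect of this hunk set is that the runner is CUDA-specific again: the _create_device_graph and _capture_graph hooks are gone and capture happens inline with torch.cuda. A minimal sketch of the restored capture sequence (warm up eagerly, share one memory pool, then record), simplified from the code above:

import torch


def capture_once(run_once, stream=None, pool=None):
    """Warm up eagerly, then record one replayable CUDA graph (sketch)."""
    graph = torch.cuda.CUDAGraph()
    for _ in range(2):            # warm-up iterations, outside the graph
        torch.cuda.synchronize()
        run_once()
    if pool is None:
        pool = torch.cuda.graph_pool_handle()   # share one pool across runners
    with torch.cuda.graph(graph, pool=pool, stream=stream):
        out = run_once()          # recorded into `graph`; outputs are static buffers
    return graph, pool, out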
python/sglang/srt/model_executor/cuda_graph_runner_impl.py  (deleted, 100644 → 0)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.graph_runner import GraphRunner

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class CudaGraphRunner(GraphRunner):
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        super().__init__(model_runner)

    def _create_device_graph(self):
        return torch.cuda.CUDAGraph()
python/sglang/srt/model_executor/model_runner.py

@@ -89,11 +89,8 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
-# TODO(iforgetmyname): Renaming on the way
-from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype

@@ -344,12 +341,9 @@ class ModelRunner:
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
-            self.init_device_graphs()
-        elif self.device == "npu":
-            self.init_attention_backend()
-            self.init_device_graphs()
+            self.init_cuda_graphs()
         else:
-            self.graph_runner = None
+            self.cuda_graph_runner = None
             self.cuda_graph_mem_usage = 0
             self.init_attention_backend()

@@ -923,8 +917,7 @@ class ModelRunner:
         )
         # We need to get device after patch otherwise the device would be wrong
-        self.device_module = torch.get_device_module(self.device)
-        infered_device = self.device_module.current_device()
+        infered_device = torch.cuda.current_device()
         named_tensors = [
             (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))

@@ -1592,9 +1585,9 @@ class ModelRunner:
             .cuda()
         )

-    def init_device_graphs(self):
+    def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        self.graph_runner = None
+        self.cuda_graph_runner = None
         self.cuda_graph_mem_usage = 0
         if not self.is_generation:

@@ -1609,9 +1602,8 @@ class ModelRunner:
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
-        )
+        self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(

@@ -1763,11 +1755,11 @@ class ModelRunner:
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
-            and self.graph_runner
-            and self.graph_runner.can_run(forward_batch)
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            ret = self.graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
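What changes conceptually in ModelRunner: the reverted code had a single device-agnostic entry point that picked CudaGraphRunner or NPUGraphRunner, while the restored code only ever builds a CudaGraphRunner and NPU falls back to the eager path. The forward-time dispatch itself is unchanged apart from the attribute rename; a small self-contained sketch of that dispatch, under the assumption that the runner exposes can_run and replay as shown in the hunks above:

from typing import Callable, Optional


def forward_with_optional_graph(
    forward_batch,
    graph_runner: Optional[object],
    eager_forward: Callable,
):
    """Sketch of ModelRunner.forward dispatch: replay a captured graph when the
    runner exists and accepts this batch, otherwise fall back to eager mode."""
    can_run_graph = bool(
        graph_runner is not None
        and forward_batch.forward_mode.is_cuda_graph()
        and graph_runner.can_run(forward_batch)
    )
    if can_run_graph:
        return graph_runner.replay(forward_batch), True
    return eager_forward(forward_batch), False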
python/sglang/srt/model_executor/npu_graph_runner.py  (deleted, 100644 → 0)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with npu graph and torch.compile."""

from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.graph_runner import GraphRunner

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors


class NPUGraphRunner(GraphRunner):
    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__(model_runner)

    def _create_device_graph(self):
        return torch.npu.NPUGraph()

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        with torch.npu.graph(
            graph,
            pool=pool,
            stream=stream,
            auto_dispatch_capture=True,
        ):
            out = run_once_fn()
        return out

    def _update_inputs(self, seq_lens):
        self.graphs[self.bs].update(
            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
        )

    def _cache_loc_dtype(self):
        return torch.int32

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
        thread.start()
        self.graphs[self.bs].replay()
        thread.join()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
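One detail worth noting in the deleted NPUGraphRunner above: unlike the CUDA path, the captured NPU graph needs its actual_seq_lengths_kv refreshed from the CPU on every replay, and the deleted code overlaps that refresh with the replay itself on a helper thread. A condensed sketch of that overlap (graph.update(cpu_update_input=...) is the torch_npu call used in the deleted file; the graph argument here is duck-typed):

import threading


def replay_with_cpu_update(graph, seq_lens):
    """Overlap the CPU-side metadata refresh with kernel replay (sketch)."""
    def _update():
        # Only the KV sequence lengths change between replays of the same graph.
        graph.update(cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}])

    t = threading.Thread(target=_update)
    t.start()
    graph.replay()
    t.join()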
python/sglang/srt/models/deepseek_v2.py

@@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        from sglang.srt.model_executor.graph_runner import get_is_capture_mode
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
             if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
python/sglang/srt/models/glm4_moe.py

@@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_executor.graph_runner import get_is_capture_mode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
python/sglang/srt/models/mllama.py

@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        from sglang.srt.model_executor.graph_runner import get_is_capture_mode
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
python/sglang/srt/models/qwen3.py

@@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_executor.graph_runner import get_is_capture_mode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
python/sglang/srt/models/qwen3_moe.py

@@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_executor.graph_runner import get_is_capture_mode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -6,22 +6,20 @@ from typing import TYPE_CHECKING, Callable
 import torch

 from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
-# TODO(iforgetmyname): Renaming on the way
-from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.model_executor.graph_runner import (
-    GRAPH_CAPTURE_FAILED_MSG,
+from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
+    CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,

@@ -123,7 +121,7 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     def can_run(self, forward_batch: ForwardBatch):
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -6,16 +6,9 @@ from typing import TYPE_CHECKING, Callable
 import torch

 from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
-# TODO(iforgetmyname): Renaming on the way
-from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.model_executor.graph_runner import (
-    GRAPH_CAPTURE_FAILED_MSG,
+from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
+    CudaGraphRunner,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,

@@ -23,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import (
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,

@@ -151,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

     def can_run(self, forward_batch: ForwardBatch):
test/srt/ascend/test_ascend_graph_tp1_bf16.py  (deleted, 100644 → 0)

import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 150,
        "output_throughput": 30,
    },
}


class TestAscendGraphTp1Bf16(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
        ]

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )

                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()
test/srt/ascend/test_ascend_graph_tp2_bf16.py  (deleted, 100644 → 0)

The deleted TP2 test mirrors test_ascend_graph_tp1_bf16.py above line for line (GSM8K accuracy plus offline throughput for "Qwen/Qwen2.5-7B-Instruct"), differing only in the class name, the expected thresholds, and the added tensor-parallel size:

TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}


class TestAscendGraphTp2Bf16(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
        ]

    # test_a_gsm8k and test_b_throughput are identical to the TP1 file above.


if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py

@@ -269,11 +269,9 @@ suite_xeon = {
 suite_ascend = {
     "per-commit-1-ascend-npu": [
         TestFile("ascend/test_ascend_tp1_bf16.py", 400),
-        TestFile("ascend/test_ascend_graph_tp1_bf16.py", 400),
     ],
     "per-commit-2-ascend-npu": [
         TestFile("ascend/test_ascend_tp2_bf16.py", 400),
-        TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
     ],
     "per-commit-4-ascend-npu": [
         TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),