change / sglang · Commits · Commit f4fafacc (unverified)
Authored Aug 20, 2025 by Even Zhou; committed by GitHub on Aug 19, 2025
Parent commit: 01d47a27

    Revert "[feature] Ascend NPU graph support (#8027)" (#9348)
Showing 18 changed files with 878 additions and 1349 deletions (+878 -1349):

  benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py  +1 -1
  python/sglang/srt/distributed/parallel_state.py  +4 -10
  python/sglang/srt/layers/attention/ascend_backend.py  +22 -135
  python/sglang/srt/mem_cache/memory_pool.py  +1 -1
  python/sglang/srt/model_executor/cuda_graph_runner.py  +817 -6
  python/sglang/srt/model_executor/graph_runner.py  +0 -860
  python/sglang/srt/model_executor/model_runner.py  +10 -16
  python/sglang/srt/model_executor/npu_graph_runner.py  +0 -94
  python/sglang/srt/models/deepseek_v2.py  +1 -1
  python/sglang/srt/models/glm4_moe.py  +1 -1
  python/sglang/srt/models/mllama.py  +1 -1
  python/sglang/srt/models/qwen3.py  +1 -1
  python/sglang/srt/models/qwen3_moe.py  +1 -1
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py  +9 -9
  python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py  +9 -9
  test/srt/run_suite.py  +0 -11
  test/srt/test_ascend_graph_tp1_bf16.py  +0 -95
  test/srt/test_ascend_graph_tp2_bf16.py  +0 -97
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py  (View file @ f4fafacc)

@@ -9,7 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe as fused_moe_triton,
 )
-from sglang.srt.model_executor.graph_runner import set_torch_compile_config
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config


 def get_model_config(model_name: str, tp_size: int):
...
python/sglang/srt/distributed/parallel_state.py  (View file @ f4fafacc)

@@ -55,7 +55,7 @@ _is_npu = is_npu()
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+    stream: torch.cuda.Stream


 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
...
@@ -252,13 +252,9 @@ class GroupCoordinator:
         if is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
-        elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
         else:
             self.device = torch.device("cpu")
-        self.device_module = torch.get_device_module(self.device)

         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
...
@@ -406,7 +402,7 @@ class GroupCoordinator:
         self, graph_capture_context: Optional[GraphCaptureContext] = None
     ):
         if graph_capture_context is None:
-            stream = self.device_module.Stream()
+            stream = torch.cuda.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
...
@@ -417,11 +413,11 @@ class GroupCoordinator:
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
-        curr_stream = self.device_module.current_stream()
+        curr_stream = torch.cuda.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)

-        with self.device_module.stream(stream), maybe_ca_context:
+        with torch.cuda.stream(stream), maybe_ca_context:
             # In graph mode, we have to be very careful about the collective
             # operations. The current status is:
             # allreduce \ Mode   |  Eager  |  Graph  |
...
@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         )
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         torch.xpu.empty_cache()
-    elif hasattr(torch, "npu") and torch.npu.is_available():
-        torch.npu.empty_cache()


 def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
...
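As context for the stream hunks above: the reverted code routed stream creation through torch.get_device_module so the same GroupCoordinator worked on CUDA and NPU, while the restored code hard-codes torch.cuda. A minimal sketch of the two styles (illustrative only; make_capture_stream is a made-up helper name, and the snippet assumes a CUDA device is available):

import torch


def make_capture_stream(device: torch.device):
    # Device-agnostic style used by the reverted NPU patch:
    # torch.get_device_module("cuda") is torch.cuda; with torch_npu imported,
    # torch.get_device_module("npu") would resolve to torch.npu.
    device_module = torch.get_device_module(device)
    return device_module.Stream()


# CUDA-only style that this commit restores:
stream = torch.cuda.Stream()
curr_stream = torch.cuda.current_stream()
stream.wait_stream(curr_stream)  # finish pending work before switching streams
with torch.cuda.stream(stream):
    pass  # graph capture would run here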
python/sglang/srt/layers/attention/ascend_backend.py  (View file @ f4fafacc)

 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 import torch
 import torch_npu
...
@@ -27,7 +27,6 @@ class ForwardMetadata:
     # seq len inputs
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
-    seq_lens_cpu_list: Optional[List[int]] = None


 class AscendAttnBackend(AttentionBackend):
...
@@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.forward_metadata = None
+        self.forward_metadata = ForwardMetadata()
         self.device = model_runner.device
         self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
...
@@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend):
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
         self.native_attn = TorchNativeAttnBackend(model_runner)
-        self.graph_metadata = {}
-        self.max_context_len = model_runner.model_config.context_len
-        self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.graph_mode = False

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
-        self.forward_metadata = ForwardMetadata()
-
         self.forward_metadata.block_tables = (
             forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
...
@@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

-        self.graph_mode = False
-
-    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        self.graph_metadata = {
-            "block_tables": torch.empty(
-                (max_bs, self.max_context_len // self.page_size),
-                dtype=torch.int32,
-                device=self.device,
-            ),
-        }
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        num_tokens: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
-        forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
-    ):
-        metadata = ForwardMetadata()
-
-        metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
-        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
-
-        self.graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
-
-        self.graph_mode = True
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
-        forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
-        seq_lens_cpu: Optional[torch.Tensor],
-    ):
-        metadata = self.graph_metadata[bs]
-        max_len = seq_lens_cpu[:bs].max().item()
-        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-
-        metadata.block_tables[:bs, :max_seq_pages].copy_(
-            self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
-            // self.page_size
-        )
-        metadata.block_tables[:bs, max_seq_pages:].fill_(0)
-        metadata.block_tables[bs:, :].fill_(0)
-
-        self.forward_metadata = metadata
-
-        self.graph_mode = True
-
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
...
@@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend):
             layer, forward_batch.out_cache_loc, k, v
         )
         if not self.use_mla:
-            if self.graph_mode:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                workspace = (
-                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        query,
-                        k_cache,
-                        v_cache,
-                        block_table=self.forward_metadata.block_tables,
-                        block_size=self.page_size,
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSH",
-                        scale=layer.scaling,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    )
-                )
-                output = torch.empty(
-                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query,
-                    k_cache,
-                    v_cache,
-                    block_table=self.forward_metadata.block_tables,
-                    block_size=self.page_size,
-                    num_heads=layer.tp_q_head_num,
-                    num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSH",
-                    scale=layer.scaling,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    workspace=workspace,
-                    out=[output, softmax_lse],
-                )
-            else:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                )
-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                output = torch.empty(
-                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    scale_value=layer.scaling,
-                    block_table=self.forward_metadata.block_tables,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    out=output,
-                )
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
             return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
...
@@ -330,6 +220,3 @@ class AscendAttnBackend(AttentionBackend):
             out=attn_output,
         )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 0
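The removed init_forward_metadata_replay_cuda_graph rebuilds a per-request page table by striding req_to_token with the page size. A standalone sketch of just that indexing, with made-up shapes and values (not part of the diff):

import torch

page_size = 4
max_context_len = 32
num_reqs, bs, max_len = 8, 2, 10

# req_to_token[r, t] = global KV-cache slot of token t of request r
req_to_token = torch.arange(num_reqs * max_context_len).reshape(num_reqs, max_context_len)
req_pool_indices = torch.tensor([3, 5])
block_tables = torch.zeros((bs, max_context_len // page_size), dtype=torch.int64)

max_seq_pages = (max_len + page_size - 1) // page_size
# Take every page_size-th token slot and divide by page_size to get the page id,
# mirroring the arithmetic in the removed replay path.
block_tables[:bs, :max_seq_pages].copy_(
    req_to_token[req_pool_indices[:bs], :max_len][:, ::page_size] // page_size
)
block_tables[:bs, max_seq_pages:].fill_(0)
print(block_tables)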
python/sglang/srt/mem_cache/memory_pool.py  (View file @ f4fafacc)

@@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache):
         v_scale: Optional[float] = None,
         layer_id_override: Optional[int] = None,
     ):
-        from sglang.srt.model_executor.graph_runner import get_is_capture_mode
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if layer_id_override is not None:
             layer_id = layer_id_override
...
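get_is_capture_mode is a plain module-level flag toggled by the model_capture_mode context manager defined in the runner module (shown in the next file). A minimal sketch of the pattern:

from contextlib import contextmanager

is_capture_mode = False


def get_is_capture_mode():
    return is_capture_mode


@contextmanager
def model_capture_mode():
    # Flip the flag for the duration of graph capture so that code paths such as
    # the KV-cache setter can branch on capture vs. normal execution.
    global is_capture_mode
    is_capture_mode = True
    yield
    is_capture_mode = False


# Usage:
with model_capture_mode():
    assert get_is_capture_mode()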
python/sglang/srt/model_executor/cuda_graph_runner.py  (View file @ f4fafacc)

@@ -15,22 +15,833 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
+import bisect
+import gc
+import inspect
+import logging
+import os
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, Union

 import torch
+import tqdm
+from torch.profiler import ProfilerActivity, profile

-from sglang.srt.model_executor.graph_runner import GraphRunner
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import (
+    DpPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    set_dp_buffer_len,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+    enable_num_token_non_padded,
+)
+from sglang.srt.patch_torch import monkey_patch_torch_compile
+from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
+from sglang.srt.utils import (
+    empty_context,
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
+
+logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
+
+@contextmanager
+def model_capture_mode():
+    global is_capture_mode
+    is_capture_mode = True
+    yield
+    is_capture_mode = False
+
+
+@contextmanager
+def freeze_gc(enable_cudagraph_gc: bool):
+    """
+    Optimize garbage collection during CUDA graph capture.
+    Clean up, then freeze all remaining objects from being included
+    in future collections if GC is disabled during capture.
+    """
+    gc.collect()
+    should_freeze = not enable_cudagraph_gc
+    if should_freeze:
+        gc.freeze()
+    try:
+        yield
+    finally:
+        if should_freeze:
+            gc.unfreeze()
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub.leave_torch_compile()
+            else:
+                sub.enter_torch_compile(num_tokens=num_tokens)
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse, num_tokens)
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module,
+    enable_compile: bool,
+    num_tokens: int,
+    tp_group: GroupCoordinator,
+):
+    """Patch the model to make it compatible with torch.compile"""
+    backup_ca_comm = None
+
+    try:
+        if enable_compile:
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
+            backup_ca_comm = tp_group.ca_comm
+            # Use custom-allreduce here.
+            # We found the custom allreduce is much faster than the built-in allreduce in torch,
+            # even with ENABLE_INTRA_NODE_COMM=1.
+            # tp_group.ca_comm = None
+            yield torch.compile(
+                torch.no_grad()(model.forward),
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
+                dynamic=False,
+            )
+        else:
+            yield model.forward
+    finally:
+        if enable_compile:
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
+            tp_group.ca_comm = backup_ca_comm
+
+
+def set_torch_compile_config():
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True
+    # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024
+
+    monkey_patch_torch_compile()
+
+
+def get_batch_sizes_to_capture(model_runner: ModelRunner):
+    server_args = model_runner.server_args
+    capture_bs = server_args.cuda_graph_bs
+
+    if capture_bs is None:
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
+            else:
+                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
+        else:
+            # Since speculative decoding requires more cuda graph memory, we
+            # capture less.
+            capture_bs = (
+                list(range(1, 9))
+                + list(range(10, 33, 2))
+                + list(range(40, 64, 8))
+                + list(range(80, 161, 16))
+            )
+
+        gpu_mem = get_device_memory_capacity()
+        if gpu_mem is not None:
+            if gpu_mem > 90 * 1024:  # H200, H20
+                capture_bs += list(range(160, 257, 8))
+            if gpu_mem > 160 * 1000:  # B200, MI300
+                capture_bs += list(range(256, 513, 16))
+
+    if max(capture_bs) > model_runner.req_to_token_pool.size:
+        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # is very small. We add more values here to make sure we capture the maximum bs.
+        capture_bs += [model_runner.req_to_token_pool.size]
+
+    mul_base = 1
+    if server_args.enable_two_batch_overlap:
+        mul_base *= 2
+    if require_gathered_buffer(server_args):
+        mul_base *= get_attention_tp_size()
+    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
+
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+        if max(capture_bs) < server_args.cuda_graph_max_bs:
+            capture_bs += list(
+                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+            )
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    capture_bs = list(sorted(set(capture_bs)))
+    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
+    compile_bs = (
+        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+        if server_args.enable_torch_compile
+        else []
+    )
+    return capture_bs, compile_bs
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
+
-class CudaGraphRunner(GraphRunner):
+class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

     def __init__(self, model_runner: ModelRunner):
         # Parse args
-        super().__init__(model_runner)
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.output_buffers = {}
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+        self.enable_two_batch_overlap = (
+            model_runner.server_args.enable_two_batch_overlap
+        )
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
+        # Batch sizes to capture
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
+        self.num_tokens_per_bs = 1
+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                raise RuntimeError("This should not happen")
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        # Attention backend
+        self.max_bs = max(self.capture_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+        self.encoder_len_fill_value = 0
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
+
+        if self.enable_torch_compile:
+            set_torch_compile_config()
+
+        if self.model_runner.server_args.enable_lora:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
+        # Graph inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
+            self.tbo_plugin = TboCudaGraphRunnerPlugin()
+
+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
+            # Speculative_inference
+            if model_runner.spec_algorithm.is_eagle3():
+                self.model_runner.model.set_eagle3_layers_to_capture()
+
+            if self.is_encoder_decoder:
+                # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
+                self.encoder_lens = torch.full(
+                    (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
+                )
+            else:
+                self.encoder_lens = None
+
+            if self.require_gathered_buffer:
+                if self.require_mlp_tp_gather:
+                    self.global_num_tokens_gpu = torch.zeros(
+                        (self.dp_size,), dtype=torch.int32
+                    )
+                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                        (self.dp_size,), dtype=torch.int32
+                    )
+                else:
+                    assert self.require_attn_tp_gather
+                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                        (1,), dtype=torch.int32
+                    )
+            else:
+                self.global_num_tokens_gpu = None
+                self.global_num_tokens_for_logprob_gpu = None
+
+            self.custom_mask = torch.ones(
+                (
+                    (self.seq_lens.sum().item() + self.max_num_token)
+                    * self.num_tokens_per_bs
+                ),
+                dtype=torch.bool,
+                device="cuda",
+            )
+            self.next_token_logits_buffer = torch.zeros(
+                (self.max_num_token, self.model_runner.model_config.vocab_size),
+                dtype=torch.float,
+                device="cuda",
+            )
+
+        # Capture
+        try:
+            with model_capture_mode():
+                self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
+            )
+
+    def can_run(self, forward_batch: ForwardBatch):
...
+    def capture(self) -> None:
+        profile_context = empty_context()
+        if self.enable_profile_cuda_graph:
+            profile_context = profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=True,
+            )
+
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with freeze_gc(
+            self.model_runner.server_args.enable_cudagraph_gc
+        ), graph_capture() as graph_capture_context:
+            with profile_context as prof:
+                self.stream = graph_capture_context.stream
+                avail_mem = get_available_gpu_memory(
+                    self.model_runner.device,
+                    self.model_runner.gpu_id,
+                    empty_cache=False,
+                )
+                # Reverse the order to enable better memory sharing across cuda graphs.
+                capture_range = (
+                    tqdm.tqdm(list(reversed(self.capture_bs)))
+                    if get_tensor_model_parallel_rank() == 0
+                    else reversed(self.capture_bs)
+                )
+                for i, bs in enumerate(capture_range):
+                    if get_tensor_model_parallel_rank() == 0:
+                        avail_mem = get_available_gpu_memory(
+                            self.model_runner.device,
+                            self.model_runner.gpu_id,
+                            empty_cache=False,
+                        )
+                        capture_range.set_description(
+                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
+                        )
+
+                    with patch_model(
+                        self.model_runner.model,
+                        bs in self.compile_bs,
+                        num_tokens=bs * self.num_tokens_per_bs,
+                        tp_group=self.model_runner.tp_group,
+                    ) as forward:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward)
+                        self.graphs[bs] = graph
+                        self.output_buffers[bs] = output_buffers
+
+                    # Save gemlite cache after each capture
+                    save_gemlite_cache()
+
+            if self.enable_profile_cuda_graph:
+                log_message = (
+                    "Sorted by CUDA Time:\n"
+                    + prof.key_averages(group_by_input_shape=True).table(
+                        sort_by="cuda_time_total", row_limit=10
+                    )
+                    + "\n\nSorted by CPU Time:\n"
+                    + prof.key_averages(group_by_input_shape=True).table(
+                        sort_by="cpu_time_total", row_limit=10
+                    )
+                )
+                logger.info(log_message)
+
+    def capture_one_batch_size(self, bs: int, forward: Callable):
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+        num_tokens = bs * self.num_tokens_per_bs
+
+        # Graph inputs
+        input_ids = self.input_ids[:num_tokens]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
+        if self.is_encoder_decoder:
+            encoder_lens = self.encoder_lens[:bs]
+        else:
+            encoder_lens = None
+        mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
+        self.num_token_non_padded[...] = num_tokens
+
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
...
+        spec_info = self.get_spec_info(num_tokens)
+        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+            self.capture_hidden_mode = (
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            )
+
+        if self.model_runner.server_args.enable_lora:
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
+        else:
+            lora_ids = None
+
+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum().item(),
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            positions=positions,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
+            mrope_positions=mrope_positions,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=spec_info,
+            capture_hidden_mode=self.capture_hidden_mode,
+            num_token_non_padded=self.num_token_non_padded,
+            global_forward_mode=self.capture_forward_mode,
+            lora_ids=lora_ids,
+        )
+        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
+
+        if lora_ids is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            num_tokens,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
+        )
+
+        # Run and capture
+        def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = PPProxyTensors(
+                    {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
+                )
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors
+
+        for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+            run_once()
+
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+        with torch.cuda.graph(graph, pool=get_global_graph_memory_pool(), stream=stream):
+            out = run_once()
+
+        return graph, out
+
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
...
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
...
+    def replay(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
+        # Replay
+        self.graphs[self.bs].replay()
+
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
+
+    def get_spec_info(self, num_tokens: int):
...
-    def _create_device_graph(self):
-        return torch.cuda.CUDAGraph()
+
+CUDA_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
+    "3. disable torch compile by not using --enable-torch-compile\n"
+    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+)
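capture_one_batch_size above follows the standard PyTorch CUDA-graph recipe: warm the model up on a side stream, capture into a shared memory pool, then replay after copying fresh inputs into the static buffers. A self-contained sketch of that recipe (illustrative only; step stands in for the model forward and the snippet requires a CUDA device):

import torch


def step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the captured forward pass.
    return x * 2 + 1


static_in = torch.zeros(8, device="cuda")

# Warm up on a side stream so allocations settle before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        static_out = step(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture into a shared pool, as the runner does with graph_pool_handle().
pool = torch.cuda.graph_pool_handle()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
    static_out = step(static_in)

# Replay: copy new inputs into the static buffer, replay, then read the static output.
static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
graph.replay()
print(static_out)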
python/sglang/srt/model_executor/graph_runner.py
deleted
100644 → 0
View file @
01d47a27
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with device graph and torch.compile."""
from
__future__
import
annotations
import
bisect
import
gc
import
inspect
import
logging
import
os
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
import
torch
import
tqdm
from
torch.profiler
import
ProfilerActivity
,
profile
from
sglang.srt.custom_op
import
CustomOp
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
from
sglang.srt.distributed.device_communicators.pynccl_allocator
import
(
set_graph_pool_id
,
)
from
sglang.srt.distributed.parallel_state
import
GroupCoordinator
,
graph_capture
from
sglang.srt.layers.dp_attention
import
(
DpPaddingMode
,
get_attention_tp_rank
,
get_attention_tp_size
,
set_dp_buffer_len
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.torchao_utils
import
save_gemlite_cache
from
sglang.srt.model_executor.forward_batch_info
import
(
CaptureHiddenMode
,
ForwardBatch
,
ForwardMode
,
PPProxyTensors
,
enable_num_token_non_padded
,
)
from
sglang.srt.patch_torch
import
monkey_patch_torch_compile
from
sglang.srt.two_batch_overlap
import
TboCudaGraphRunnerPlugin
from
sglang.srt.utils
import
(
empty_context
,
get_available_gpu_memory
,
get_device_memory_capacity
,
rank0_log
,
require_attn_tp_gather
,
require_gathered_buffer
,
require_mlp_sync
,
require_mlp_tp_gather
,
)
logger
=
logging
.
getLogger
(
__name__
)
if
TYPE_CHECKING
:
from
sglang.srt.model_executor.model_runner
import
ModelRunner
# Detect whether the current forward pass is in capture mode
is_capture_mode
=
False
def
get_is_capture_mode
():
return
is_capture_mode
@
contextmanager
def
model_capture_mode
():
global
is_capture_mode
is_capture_mode
=
True
yield
is_capture_mode
=
False
@
contextmanager
def
freeze_gc
(
enable_cudagraph_gc
:
bool
):
"""
Optimize garbage collection during CUDA graph capture.
Clean up, then freeze all remaining objects from being included
in future collections if GC is disabled during capture.
"""
gc
.
collect
()
should_freeze
=
not
enable_cudagraph_gc
if
should_freeze
:
gc
.
freeze
()
try
:
yield
finally
:
if
should_freeze
:
gc
.
unfreeze
()
def
_to_torch
(
model
:
torch
.
nn
.
Module
,
reverse
:
bool
,
num_tokens
:
int
):
for
sub
in
model
.
_modules
.
values
():
if
isinstance
(
sub
,
CustomOp
):
if
reverse
:
sub
.
leave_torch_compile
()
else
:
sub
.
enter_torch_compile
(
num_tokens
=
num_tokens
)
if
isinstance
(
sub
,
torch
.
nn
.
Module
):
_to_torch
(
sub
,
reverse
,
num_tokens
)
@
contextmanager
def
patch_model
(
model
:
torch
.
nn
.
Module
,
enable_compile
:
bool
,
num_tokens
:
int
,
tp_group
:
GroupCoordinator
,
):
"""Patch the model to make it compatible with with torch.compile"""
backup_ca_comm
=
None
try
:
if
enable_compile
:
_to_torch
(
model
,
reverse
=
False
,
num_tokens
=
num_tokens
)
backup_ca_comm
=
tp_group
.
ca_comm
# Use custom-allreduce here.
# We found the custom allreduce is much faster than the built-in allreduce in torch,
# even with ENABLE_INTRA_NODE_COMM=1.
# tp_group.ca_comm = None
yield
torch
.
compile
(
torch
.
no_grad
()(
model
.
forward
),
mode
=
os
.
environ
.
get
(
"SGLANG_TORCH_COMPILE_MODE"
,
"max-autotune-no-cudagraphs"
),
dynamic
=
False
,
)
else
:
yield
model
.
forward
finally
:
if
enable_compile
:
_to_torch
(
model
,
reverse
=
True
,
num_tokens
=
num_tokens
)
tp_group
.
ca_comm
=
backup_ca_comm
def
set_torch_compile_config
():
import
torch._dynamo.config
import
torch._inductor.config
torch
.
_inductor
.
config
.
coordinate_descent_tuning
=
True
torch
.
_inductor
.
config
.
triton
.
unique_kernel_names
=
True
torch
.
_inductor
.
config
.
fx_graph_cache
=
True
# Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch
.
_dynamo
.
config
.
accumulated_cache_size_limit
=
1024
if
hasattr
(
torch
.
_dynamo
.
config
,
"cache_size_limit"
):
torch
.
_dynamo
.
config
.
cache_size_limit
=
1024
monkey_patch_torch_compile
()
def
get_batch_sizes_to_capture
(
model_runner
:
ModelRunner
):
server_args
=
model_runner
.
server_args
capture_bs
=
server_args
.
cuda_graph_bs
if
capture_bs
is
None
:
if
server_args
.
speculative_algorithm
is
None
:
if
server_args
.
disable_cuda_graph_padding
:
capture_bs
=
list
(
range
(
1
,
33
))
+
list
(
range
(
48
,
161
,
16
))
else
:
capture_bs
=
[
1
,
2
,
4
,
8
]
+
list
(
range
(
16
,
161
,
8
))
else
:
# Since speculative decoding requires more cuda graph memory, we
# capture less.
capture_bs
=
(
list
(
range
(
1
,
9
))
+
list
(
range
(
10
,
33
,
2
))
+
list
(
range
(
40
,
64
,
8
))
+
list
(
range
(
80
,
161
,
16
))
)
gpu_mem
=
get_device_memory_capacity
()
if
gpu_mem
is
not
None
:
if
gpu_mem
>
90
*
1024
:
# H200, H20
capture_bs
+=
list
(
range
(
160
,
257
,
8
))
if
gpu_mem
>
160
*
1000
:
# B200, MI300
capture_bs
+=
list
(
range
(
256
,
513
,
16
))
if
max
(
capture_bs
)
>
model_runner
.
req_to_token_pool
.
size
:
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs.
capture_bs
+=
[
model_runner
.
req_to_token_pool
.
size
]
mul_base
=
1
if
server_args
.
enable_two_batch_overlap
:
mul_base
*=
2
if
require_gathered_buffer
(
server_args
):
mul_base
*=
get_attention_tp_size
()
capture_bs
=
[
bs
for
bs
in
capture_bs
if
bs
%
mul_base
==
0
]
if
server_args
.
cuda_graph_max_bs
:
capture_bs
=
[
bs
for
bs
in
capture_bs
if
bs
<=
server_args
.
cuda_graph_max_bs
]
if
max
(
capture_bs
)
<
server_args
.
cuda_graph_max_bs
:
capture_bs
+=
list
(
range
(
max
(
capture_bs
),
server_args
.
cuda_graph_max_bs
+
1
,
16
)
)
capture_bs
=
[
bs
for
bs
in
capture_bs
if
bs
<=
model_runner
.
req_to_token_pool
.
size
]
capture_bs
=
list
(
sorted
(
set
(
capture_bs
)))
assert
len
(
capture_bs
)
>
0
and
capture_bs
[
0
]
>
0
,
f
"
{
capture_bs
=
}
"
compile_bs
=
(
[
bs
for
bs
in
capture_bs
if
bs
<=
server_args
.
torch_compile_max_bs
]
if
server_args
.
enable_torch_compile
else
[]
)
return
capture_bs
,
compile_bs
# Reuse this memory pool across all device graph runners.
global_graph_memory_pool
=
None
def
get_global_graph_memory_pool
():
return
global_graph_memory_pool
def
set_global_graph_memory_pool
(
val
):
global
global_graph_memory_pool
global_graph_memory_pool
=
val
class
GraphRunner
:
"""A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile."""
def
__init__
(
self
,
model_runner
:
ModelRunner
):
# Parse args
self
.
model_runner
=
model_runner
self
.
device
=
model_runner
.
device
self
.
device_module
=
torch
.
get_device_module
(
self
.
device
)
self
.
graphs
=
{}
self
.
output_buffers
=
{}
self
.
enable_torch_compile
=
model_runner
.
server_args
.
enable_torch_compile
self
.
disable_padding
=
model_runner
.
server_args
.
disable_cuda_graph_padding
self
.
is_encoder_decoder
=
model_runner
.
model_config
.
is_encoder_decoder
self
.
require_gathered_buffer
=
require_gathered_buffer
(
model_runner
.
server_args
)
self
.
require_mlp_tp_gather
=
require_mlp_tp_gather
(
model_runner
.
server_args
)
self
.
require_mlp_sync
=
require_mlp_sync
(
model_runner
.
server_args
)
self
.
require_attn_tp_gather
=
require_attn_tp_gather
(
model_runner
.
server_args
)
self
.
enable_two_batch_overlap
=
(
model_runner
.
server_args
.
enable_two_batch_overlap
)
self
.
speculative_algorithm
=
model_runner
.
server_args
.
speculative_algorithm
self
.
enable_profile_cuda_graph
=
(
model_runner
.
server_args
.
enable_profile_cuda_graph
)
self
.
tp_size
=
model_runner
.
server_args
.
tp_size
self
.
dp_size
=
model_runner
.
server_args
.
dp_size
self
.
pp_size
=
model_runner
.
server_args
.
pp_size
self
.
attn_tp_size
=
get_attention_tp_size
()
self
.
attn_tp_rank
=
get_attention_tp_rank
()
# Batch sizes to capture
self
.
capture_bs
,
self
.
compile_bs
=
get_batch_sizes_to_capture
(
model_runner
)
rank0_log
(
f
"Capture graph bs
{
self
.
capture_bs
}
"
)
self
.
capture_forward_mode
=
ForwardMode
.
DECODE
self
.
capture_hidden_mode
=
CaptureHiddenMode
.
NULL
self
.
num_tokens_per_bs
=
1
if
model_runner
.
spec_algorithm
.
is_eagle
():
if
self
.
model_runner
.
is_draft_worker
:
raise
RuntimeError
(
"This should not happen"
)
else
:
self
.
capture_forward_mode
=
ForwardMode
.
TARGET_VERIFY
self
.
num_tokens_per_bs
=
(
self
.
model_runner
.
server_args
.
speculative_num_draft_tokens
)
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
if
model_runner
.
server_args
.
enable_return_hidden_states
:
self
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
# Attention backend
self
.
max_bs
=
max
(
self
.
capture_bs
)
self
.
max_num_token
=
self
.
max_bs
*
self
.
num_tokens_per_bs
self
.
model_runner
.
attn_backend
.
init_cuda_graph_state
(
self
.
max_bs
,
self
.
max_num_token
)
self
.
seq_len_fill_value
=
(
self
.
model_runner
.
attn_backend
.
get_cuda_graph_seq_len_fill_value
()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self
.
encoder_len_fill_value
=
0
self
.
seq_lens_cpu
=
torch
.
full
(
(
self
.
max_bs
,),
self
.
seq_len_fill_value
,
dtype
=
torch
.
int32
)
if
self
.
enable_torch_compile
:
set_torch_compile_config
()
if
self
.
model_runner
.
server_args
.
enable_lora
:
self
.
model_runner
.
lora_manager
.
init_cuda_graph_batch_info
(
self
.
max_bs
)
# Graph inputs
with
torch
.
device
(
self
.
device
):
self
.
input_ids
=
torch
.
zeros
((
self
.
max_num_token
,),
dtype
=
torch
.
int64
)
self
.
req_pool_indices
=
torch
.
zeros
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
full
(
(
self
.
max_bs
,),
self
.
seq_len_fill_value
,
dtype
=
torch
.
int32
)
self
.
out_cache_loc
=
torch
.
zeros
(
(
self
.
max_num_token
,),
dtype
=
self
.
_cache_loc_dtype
()
)
self
.
positions
=
torch
.
zeros
((
self
.
max_num_token
,),
dtype
=
torch
.
int64
)
self
.
mrope_positions
=
torch
.
zeros
((
3
,
self
.
max_bs
),
dtype
=
torch
.
int64
)
self
.
num_token_non_padded
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
int32
)
self
.
tbo_plugin
=
TboCudaGraphRunnerPlugin
()
# pipeline parallelism
if
self
.
pp_size
>
1
:
self
.
pp_proxy_tensors
=
{
"hidden_states"
:
torch
.
zeros
(
(
self
.
max_bs
,
self
.
model_runner
.
model_config
.
hidden_size
),
dtype
=
torch
.
bfloat16
,
),
"residual"
:
torch
.
zeros
(
(
self
.
max_bs
,
self
.
model_runner
.
model_config
.
hidden_size
),
dtype
=
torch
.
bfloat16
,
),
}
# Speculative_inference
if
model_runner
.
spec_algorithm
.
is_eagle3
():
self
.
model_runner
.
model
.
set_eagle3_layers_to_capture
()
if
self
.
is_encoder_decoder
:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
self
.
encoder_lens
=
torch
.
full
(
(
self
.
max_bs
,),
self
.
encoder_len_fill_value
,
dtype
=
torch
.
int32
)
else
:
self
.
encoder_lens
=
None
if
self
.
require_gathered_buffer
:
if
self
.
require_mlp_tp_gather
:
self
.
global_num_tokens_gpu
=
torch
.
zeros
(
(
self
.
dp_size
,),
dtype
=
torch
.
int32
)
self
.
global_num_tokens_for_logprob_gpu
=
torch
.
zeros
(
(
self
.
dp_size
,),
dtype
=
torch
.
int32
)
else
:
assert
self
.
require_attn_tp_gather
self
.
global_num_tokens_gpu
=
torch
.
zeros
((
1
,),
dtype
=
torch
.
int32
)
self
.
global_num_tokens_for_logprob_gpu
=
torch
.
zeros
(
(
1
,),
dtype
=
torch
.
int32
)
else
:
self
.
global_num_tokens_gpu
=
None
self
.
global_num_tokens_for_logprob_gpu
=
None
self
.
custom_mask
=
torch
.
ones
(
(
(
self
.
seq_lens
.
sum
().
item
()
+
self
.
max_num_token
)
*
self
.
num_tokens_per_bs
),
dtype
=
torch
.
bool
,
device
=
self
.
device
,
)
self
.
next_token_logits_buffer
=
torch
.
zeros
(
(
self
.
max_num_token
,
self
.
model_runner
.
model_config
.
vocab_size
),
dtype
=
torch
.
float
,
device
=
self
.
device
,
)
# Capture
try
:
with
model_capture_mode
():
self
.
capture
()
except
RuntimeError
as
e
:
raise
Exception
(
f
"Capture device graph failed:
{
e
}
\n
{
GRAPH_CAPTURE_FAILED_MSG
}
"
)
def
_cache_loc_dtype
(
self
):
return
torch
.
int64
    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size

        is_bs_supported = (
            cuda_graph_bs in self.graphs
            if self.disable_padding
            else cuda_graph_bs <= self.max_bs
        )

        if self.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
        # because the full_text_row_masked_out_mask tensor will always be ones
        is_encoder_lens_supported = (
            torch.all(forward_batch.encoder_lens > 0)
            if self.is_encoder_decoder
            else True
        )

        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        is_tbo_supported = (
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )

        return (
            is_bs_supported
            and is_encoder_lens_supported
            and is_tbo_supported
            and capture_hidden_mode_matches
        )
    def capture(self) -> None:
        profile_context = empty_context()
        if self.enable_profile_cuda_graph:
            profile_context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
            )

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with freeze_gc(
            self.model_runner.server_args.enable_cudagraph_gc
        ), graph_capture() as graph_capture_context:
            with profile_context as prof:
                self.stream = graph_capture_context.stream
                avail_mem = get_available_gpu_memory(
                    self.model_runner.device,
                    self.model_runner.gpu_id,
                    empty_cache=False,
                )
                # Reverse the order to enable better memory sharing across cuda graphs.
                capture_range = (
                    tqdm.tqdm(list(reversed(self.capture_bs)))
                    if get_tensor_model_parallel_rank() == 0
                    else reversed(self.capture_bs)
                )
                for i, bs in enumerate(capture_range):
                    if get_tensor_model_parallel_rank() == 0:
                        avail_mem = get_available_gpu_memory(
                            self.model_runner.device,
                            self.model_runner.gpu_id,
                            empty_cache=False,
                        )
                        capture_range.set_description(
                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                        )

                    with patch_model(
                        self.model_runner.model,
                        bs in self.compile_bs,
                        num_tokens=bs * self.num_tokens_per_bs,
                        tp_group=self.model_runner.tp_group,
                    ) as forward:
                        (
                            graph,
                            output_buffers,
                        ) = self.capture_one_batch_size(bs, forward)
                        self.graphs[bs] = graph
                        self.output_buffers[bs] = output_buffers

                    # Save gemlite cache after each capture
                    save_gemlite_cache()

            if self.enable_profile_cuda_graph:
                log_message = (
                    "Sorted by CUDA Time:\n"
                    + prof.key_averages(group_by_input_shape=True).table(
                        sort_by="cuda_time_total", row_limit=10
                    )
                    + "\n\nSorted by CPU Time:\n"
                    + prof.key_averages(group_by_input_shape=True).table(
                        sort_by="cpu_time_total", row_limit=10
                    )
                )
                logger.info(log_message)
    def _capture_graph(self, graph, pool, stream, run_once_fn):
        with self.device_module.graph(graph, pool=pool, stream=stream):
            out = run_once_fn()
        return out

    def _create_device_graph(self):
        pass
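Note: `_create_device_graph` and `_capture_graph` are the device-specific hooks; on CUDA they reduce to `torch.cuda.CUDAGraph` plus the `torch.cuda.graph` context manager. A minimal, self-contained sketch of that capture/replay idiom (the linear layer and buffer names are illustrative only, and a CUDA device is assumed):

import torch

net = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(8, 16, device="cuda")       # static input buffer

# Warm up on a side stream so one-time initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        net(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):                      # roughly what _capture_graph wraps
    static_out = net(static_x)                     # the output buffer is static too

static_x.copy_(torch.randn(8, 16, device="cuda"))  # refresh inputs in place
graph.replay()                                     # roughly graphs[bs].replay()
print(static_out.sum().item())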
    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = self._create_device_graph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        if self.is_encoder_decoder:
            encoder_lens = self.encoder_lens[:bs]
        else:
            encoder_lens = None
        mrope_positions = self.mrope_positions[:, :bs]
        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
        self.num_token_non_padded[...] = num_tokens

        # pipeline parallelism
        if self.pp_size > 1:
            pp_proxy_tensors = PPProxyTensors(
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )

        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens * self.dp_size
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens
        else:
            global_dp_buffer_len = None

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        if self.model_runner.server_args.enable_lora:
            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
            lora_ids = [None] * bs
        else:
            lora_ids = None

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
            lora_ids=lora_ids,
        )
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

        if lora_ids is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

            kwargs = {}
            if (
                self.pp_size > 1
                and "pp_proxy_tensors" in inspect.signature(forward).parameters
            ):
                kwargs["pp_proxy_tensors"] = PPProxyTensors(
                    {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
                )

            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
                **kwargs,
            )
            return logits_output_or_pp_proxy_tensors

        for _ in range(2):
            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
            run_once()

        if get_global_graph_memory_pool() is None:
            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
        # Set graph pool id globally to be able to use symmetric memory
        set_graph_pool_id(get_global_graph_memory_pool())
        out = self._capture_graph(graph, get_global_graph_memory_pool(), stream, run_once)

        return graph, out
    def recapture_if_needed(self, forward_batch: ForwardBatch):

        # If the required capture_hidden_mode changes, we need to recapture the graph

        # These are the different factors that can influence the capture_hidden_mode
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()
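Note: the `max(...)` above assumes the capture-hidden modes are ordered so that the most demanding requirement wins (FULL can emulate LAST and NULL; LAST can emulate NULL). A small sketch of that idea; the enum name and numeric values here are illustrative assumptions, not the actual definition in forward_batch_info.py:

from enum import IntEnum


class HiddenMode(IntEnum):
    # Illustrative ordering only (assumed): higher values can emulate lower ones.
    NULL = 0
    LAST = 1
    FULL = 2


required = max(HiddenMode.NULL, HiddenMode.LAST)
print(required.name)  # LAST -> a graph captured with NULL would be re-captured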
    def replay_prepare(
        self,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad
        if self.require_mlp_tp_gather:
            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
            max_batch_size = (
                max_num_tokens / self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max_num_tokens
            )
            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)

        seq_lens_cpu = None
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            seq_lens_cpu = self.seq_lens_cpu[:bs]

        if pp_proxy_tensors:
            for key in self.pp_proxy_tensors.keys():
                dim = pp_proxy_tensors[key].shape[0]
                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
        if enable_num_token_non_padded(self.model_runner.server_args):
            num_token_non_padded = forward_batch.num_token_non_padded
            if self.require_gathered_buffer:
                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
                num_local_token_non_padded = torch.clamp(
                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
                    min=0,
                    max=tokens_per_rank,
                )
                self.num_token_non_padded.copy_(num_local_token_non_padded)
            else:
                self.num_token_non_padded.copy_(num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=self.capture_forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
                spec_info=forward_batch.spec_info,
            )
        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
            forward_batch.spec_info.custom_mask = self.custom_mask

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
            self.capture_forward_mode,
            forward_batch.spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        # Store fields
        self.raw_bs = raw_bs
        self.raw_num_token = raw_num_token
        self.bs = bs
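Note: the padding step above picks the smallest captured batch size that can hold the incoming batch; `bisect_left` works because `capture_bs` is kept sorted. A standalone illustration with made-up batch sizes:

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]   # sorted captured batch sizes (example values)


def padded_bs(raw_bs: int) -> int:
    # Smallest captured size >= raw_bs; the extra slots are later filled with
    # seq_len_fill_value and zeroed cache locations before replay.
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]


print(padded_bs(3))   # 4
print(padded_bs(8))   # 8 (exact hit, no padding needed)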
    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        self.graphs[self.bs].replay()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info
GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
)
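Note: a caller typically probes `can_run` first and falls back to eager execution when a batch cannot be served by a captured graph, which is exactly the dispatch restored in model_runner.py below. A hedged sketch of that pattern (`model` and the eager fallback are placeholders, not the real ModelRunner code path):

def forward_with_graph(model, cuda_graph_runner, forward_batch):
    """Sketch only: replay a captured graph when possible, else run eagerly."""
    if (
        forward_batch.forward_mode.is_cuda_graph()
        and cuda_graph_runner is not None
        and cuda_graph_runner.can_run(forward_batch)
    ):
        return cuda_graph_runner.replay(forward_batch)
    # Placeholder eager path.
    return model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch
    )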
python/sglang/srt/model_executor/model_runner.py  View file @ f4fafacc
...
@@ -91,7 +91,6 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
...
@@ -342,12 +341,9 @@ class ModelRunner:
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
-            self.init_device_graphs()
-        elif self.device == "npu":
-            self.init_attention_backend()
-            self.init_device_graphs()
+            self.init_cuda_graphs()
         else:
-            self.graph_runner = None
+            self.cuda_graph_runner = None
             self.cuda_graph_mem_usage = 0
             self.init_attention_backend()
...
@@ -921,8 +917,7 @@ class ModelRunner:
         )
         # We need to get device after patch otherwise the device would be wrong
-        self.device_module = torch.get_device_module(self.device)
-        infered_device = self.device_module.current_device()
+        infered_device = torch.cuda.current_device()
         named_tensors = [
             (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
...
@@ -1590,9 +1585,9 @@ class ModelRunner:
             .cuda()
         )

-    def init_device_graphs(self):
+    def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        self.graph_runner = None
+        self.cuda_graph_runner = None
         self.cuda_graph_mem_usage = 0
         if not self.is_generation:
...
@@ -1607,9 +1602,8 @@ class ModelRunner:
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
-        )
+        self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
...
@@ -1761,11 +1755,11 @@ class ModelRunner:
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
-            and self.graph_runner
-            and self.graph_runner.can_run(forward_batch)
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            ret = self.graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
...
python/sglang/srt/model_executor/npu_graph_runner.py  deleted 100644 → 0  View file @ 01d47a27
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with npu graph and torch.compile."""

from __future__ import annotations

import logging
import threading
from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.graph_runner import GraphRunner

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors


class NPUGraphRunner(GraphRunner):
    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__(model_runner)

    def _create_device_graph(self):
        return torch.npu.NPUGraph()

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        with torch.npu.graph(
            graph,
            pool=pool,
            stream=stream,
            auto_dispatch_capture=True,
        ):
            out = run_once_fn()
        return out

    def _update_inputs(self, seq_lens):
        self.graphs[self.bs].update(
            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
        )

    def _cache_loc_dtype(self):
        return torch.int32

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
        thread.start()
        self.graphs[self.bs].replay()
        thread.join()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
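Note: the deleted NPU replay path above overlaps a host-side metadata update (the padded `actual_seq_lengths_kv` list) with graph replay by running the update on a worker thread and joining afterwards. The same pattern in isolation, with stubbed functions standing in for `_update_inputs` and `graphs[bs].replay()`:

import threading
import time


def update_graph_inputs(seq_lens):
    # Stand-in for NPUGraphRunner._update_inputs.
    time.sleep(0.01)
    print("updated actual_seq_lengths_kv:", seq_lens)


def replay_graph():
    # Stand-in for graphs[bs].replay().
    time.sleep(0.01)
    print("graph replayed")


seq_lens = [5, 7, 3] + [0] * 2   # real lengths padded up to the captured batch size
worker = threading.Thread(target=update_graph_inputs, args=(seq_lens,))
worker.start()
replay_graph()
worker.join()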
python/sglang/srt/models/deepseek_v2.py  View file @ f4fafacc
...
@@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        from sglang.srt.model_executor.graph_runner import get_is_capture_mode
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
             if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
...
python/sglang/srt/models/glm4_moe.py  View file @ f4fafacc
...
@@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_executor.graph_runner import get_is_capture_mode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import (
     DeepseekV2DecoderLayer,
...
python/sglang/srt/models/mllama.py  View file @ f4fafacc
...
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        from sglang.srt.model_executor.graph_runner import get_is_capture_mode
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
...
python/sglang/srt/models/qwen3.py  View file @ f4fafacc
...
@@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_executor.graph_runner import get_is_capture_mode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
...
python/sglang/srt/models/qwen3_moe.py  View file @ f4fafacc
...
@@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_executor.graph_runner import get_is_capture_mode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
...
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py  View file @ f4fafacc
...
@@ -6,20 +6,20 @@ from typing import TYPE_CHECKING, Callable

 import torch

 from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
-from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.model_executor.graph_runner import (
-    GRAPH_CAPTURE_FAILED_MSG,
+from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
+    CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
...
@@ -121,7 +121,7 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     def can_run(self, forward_batch: ForwardBatch):
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py  View file @ f4fafacc
...
@@ -6,14 +6,9 @@ from typing import TYPE_CHECKING, Callable

 import torch

 from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
-from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.model_executor.graph_runner import (
-    GRAPH_CAPTURE_FAILED_MSG,
+from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
+    CudaGraphRunner,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
...
@@ -21,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import (
     set_global_graph_memory_pool,
     set_torch_compile_config,
 )
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
...
@@ -149,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     def can_run(self, forward_batch: ForwardBatch):
...
test/srt/run_suite.py  View file @ f4fafacc
...
@@ -229,17 +229,6 @@ suite_amd = {
         TestFile("test_wave_attention_kernels.py", 2),
         TestFile("test_wave_attention_backend.py", 150),
     ],
-    "per-commit-1-ascend-npu": [
-        TestFile("test_ascend_tp1_bf16.py", 400),
-        TestFile("test_ascend_graph_tp1_bf16.py", 400),
-    ],
-    "per-commit-2-ascend-npu": [
-        TestFile("test_ascend_tp2_bf16.py", 400),
-        TestFile("test_ascend_graph_tp2_bf16.py", 400),
-    ],
-    "per-commit-4-ascend-npu": [
-        TestFile("test_ascend_mla_w8a8int8.py", 400),
-    ],
     "per-commit-2-gpu-amd": [
         TestFile("lora/test_lora_tp.py", 116),
         TestFile("rl/test_update_weights_from_distributed.py", 103),
...
test/srt/test_ascend_graph_tp1_bf16.py  deleted 100644 → 0  View file @ 01d47a27
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 150,
        "output_throughput": 30,
    },
}


class TestAscendGraphTp1Bf16(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
        ]

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )

                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()
test/srt/test_ascend_graph_tp2_bf16.py  deleted 100644 → 0  View file @ 01d47a27
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}


class TestAscendGraphTp2Bf16(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
        ]

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )

                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()