sglang · Commits · 4ac8e09d

Unverified commit 4ac8e09d, authored Oct 11, 2025 by Yuwei An, committed via GitHub on Oct 12, 2025.
Parent commit: 20a6c0a6

Piecewise CUDA Graph Support & Torch Compile Backend (#10062)

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>

Showing 20 changed files with 2647 additions and 19 deletions (+2647, -19).
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py    +2   -2
python/sglang/srt/distributed/parallel_state.py                            +14  -2
python/sglang/srt/layers/layernorm.py                                      +10  -5
python/sglang/srt/layers/radix_attention.py                                +61  -9
python/sglang/srt/model_executor/compilation/backend.py                    +435 -0
python/sglang/srt/model_executor/compilation/compilation_config.py         +19  -0
python/sglang/srt/model_executor/compilation/compilation_counter.py        +47  -0
python/sglang/srt/model_executor/compilation/compile.py                    +210 -0
python/sglang/srt/model_executor/compilation/compiler_interface.py         +479 -0
python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py     +230 -0
python/sglang/srt/model_executor/compilation/fix_functionalization.py      +134 -0
python/sglang/srt/model_executor/compilation/fx_utils.py                   +83  -0
python/sglang/srt/model_executor/compilation/inductor_pass.py              +140 -0
python/sglang/srt/model_executor/compilation/pass_manager.py               +68  -0
python/sglang/srt/model_executor/compilation/piecewise_context_manager.py  +40  -0
python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp           +28  -0
python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py        +16  -0
python/sglang/srt/model_executor/model_runner.py                           +55  -1
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py            +532 -0
python/sglang/srt/server_args.py                                           +44  -0
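The new backend in backend.py plugs into torch.compile as a custom Dynamo backend: it receives the traced fx.GraphModule plus example inputs, splits it at the attention op, and returns a stitched callable whose pieces are compiled and CUDA-graph captured. A minimal sketch of that backend contract, assuming only stock PyTorch (the toy backend and function names are illustrative, not part of this PR):

# Minimal sketch: a Dynamo backend is any callable taking (fx.GraphModule,
# example_inputs) and returning the callable Dynamo should run. SGLangBackend
# follows this contract, but additionally splits the graph and wraps the pieces.
import torch


def toy_backend(gm: torch.fx.GraphModule, example_inputs):
    print(f"captured {len(list(gm.graph.nodes))} fx nodes")
    return gm.forward  # run the captured graph unchanged


@torch.compile(backend=toy_backend)
def f(x):
    return torch.relu(x) + 1


f(torch.randn(4))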
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (view file @ 4ac8e09d)

@@ -185,7 +185,7 @@ class CustomAllreduce:
         # is enough for 131072 such tuples. The largest model I've seen only
         # needs less than 10000 of registered tuples.
         self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            max_size, dtype=torch.uint8, device=self.device
         )
         self._ptr = ops.init_custom_ar(
             self.meta_ptrs, self.rank_data, rank, self.full_nvlink
@@ -202,7 +202,7 @@ class CustomAllreduce:
         )
         handles, offsets = self._gather_ipc_meta(shard_data)
         self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            max_size, dtype=torch.uint8, device=self.device
         )
         self._ptr = ops.init_custom_ar(
             self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
python/sglang/srt/distributed/parallel_state.py (view file @ 4ac8e09d)

@@ -239,6 +239,7 @@ class GroupCoordinator:
         use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
+        torch_compile: Optional[bool] = None,
     ):
         # Set group info
         group_name = group_name or "anonymous"
@@ -326,10 +327,18 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
+            if torch_compile is not None and torch_compile:
+                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
+                # avoid illegal cuda memory access.
+                ca_max_size = 256 * 1024 * 1024
+            else:
+                ca_max_size = 8 * 1024 * 1024
             try:
+                # print(f"ca_max_size: {ca_max_size}")
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
+                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
@@ -1260,6 +1269,7 @@ def init_model_parallel_group(
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
     use_symm_mem_allreduce: Optional[bool] = None,
+    torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -1280,6 +1290,7 @@ def init_model_parallel_group(
         use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
+        torch_compile=torch_compile,
     )
@@ -1439,6 +1450,7 @@ def initialize_model_parallel(
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
     duplicate_tp_group: bool = False,
+    torch_compile: Optional[bool] = None,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1494,6 +1506,7 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="tp",
+        torch_compile=torch_compile,
     )
     if duplicate_tp_group:
@@ -1509,6 +1522,7 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
        ),
         group_name="pdmux_prefill_tp",
+        torch_compile=torch_compile,
     )
     _TP.pynccl_comm.disabled = False
     _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
@@ -1518,7 +1532,6 @@ def initialize_model_parallel(
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
     if moe_ep_size == tensor_model_parallel_size:
         _MOE_EP = _TP
     else:
@@ -1539,7 +1552,6 @@ def initialize_model_parallel(
     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
     if moe_tp_size == tensor_model_parallel_size:
         _MOE_TP = _TP
     else:
python/sglang/srt/layers/layernorm.py (view file @ 4ac8e09d)

@@ -43,11 +43,16 @@ _is_cpu = is_cpu()
 _is_xpu = is_xpu()

 if _is_cuda:
-    if _is_flashinfer_available:
-        from flashinfer.norm import fused_add_rmsnorm
-    else:
-        from sgl_kernel import fused_add_rmsnorm
-    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
+    # if _is_flashinfer_available:
+    #     from flashinfer.norm import fused_add_rmsnorm
+    # else:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )

 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
python/sglang/srt/layers/radix_attention.py (view file @ 4ac8e09d)

@@ -17,12 +17,18 @@ from __future__ import annotations
 from enum import Enum
 from typing import TYPE_CHECKING, Optional

 import torch
 from torch import nn

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch

+from sglang.srt.model_executor.compilation.piecewise_context_manager import (
+    get_forward_context,
+)
+from sglang.srt.utils import direct_register_custom_op
+

 class AttentionType(Enum):
     """
@@ -105,12 +111,58 @@ class RadixAttention(nn.Module):
         else:
             k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

-        return forward_batch.attn_backend.forward(
-            q, k, v, self, forward_batch, save_kv_cache, **kwargs
-        )
+        if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
+            output = torch.zeros_like(q)
+            torch.ops.sglang.unified_attention_with_output(
+                q, k, v, output, save_kv_cache, self.layer_id
+            )
+            return output
+        else:
+            return forward_batch.attn_backend.forward(
+                q, k, v, self, forward_batch, save_kv_cache, **kwargs
+            )
+
+
+def unified_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    save_kv_cache: bool,
+    layer_id: int,
+) -> None:
+    context = get_forward_context()
+    forward_batch = context.forward_batch
+    attention_layers = context.attention_layers
+    attention_layer = attention_layers[layer_id]
+    ret = forward_batch.attn_backend.forward(
+        query, key, value, attention_layer, forward_batch, save_kv_cache
+    )
+    assert output.shape == ret.shape
+    output.copy_(ret)
+    return
+
+
+def unified_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    save_kv_cache: bool,
+    layer_id: int,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_attention_with_output",
+    op_func=unified_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_attention_with_output_fake,
+)
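The registration above exposes attention as an opaque custom op that mutates its output buffer and has a fake implementation, so Dynamo/Inductor can trace through it without seeing the attention backend, and split_graph can later cut the graph at torch.ops.sglang.unified_attention_with_output. A hedged sketch of the same pattern using the public torch.library API (torch >= 2.4); the op name "mylib::attention_with_output" and its body are illustrative, not sglang's helper:

# Sketch of the registration pattern behind direct_register_custom_op.
import torch


@torch.library.custom_op("mylib::attention_with_output", mutates_args=("output",))
def attention_with_output(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, output: torch.Tensor
) -> None:
    # a real kernel would call the attention backend; this stand-in computes softmax attention
    output.copy_(torch.softmax(query @ key.transpose(-1, -2), dim=-1) @ value)


@attention_with_output.register_fake
def _(query, key, value, output) -> None:
    # fake impl: only shapes/dtypes matter during tracing, no computation
    return


q = k = v = torch.randn(2, 4, 8)
out = torch.empty(2, 4, 8)
torch.ops.mylib.attention_with_output(q, k, v, out)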
python/sglang/srt/model_executor/compilation/backend.py (new file, 0 → 100644, view file @ 4ac8e09d)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
import ast
import dataclasses
import logging
import os
import pprint
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional

import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher

from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
from sglang.srt.model_executor.compilation.compilation_counter import (
    compilation_counter,
)
from sglang.srt.model_executor.compilation.compiler_interface import InductorAdaptor
from sglang.srt.model_executor.compilation.cuda_piecewise_backend import (
    CUDAPiecewiseBackend,
)
from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager

logger = logging.getLogger(__name__)


def make_compiler():
    return InductorAdaptor()


class CompilerManager:
    def __init__(self):
        self.cache = dict()
        self.is_cache_updated = False
        self.compiler = make_compiler()

    def compute_hash(self):
        return self.compiler.compute_hash()

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Optional[Callable]:
        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        inductor_config: dict[str, Any],
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: Optional[int] = None,
    ) -> Any:
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        compiled_graph = None

        # TODO(Yuwei): support cache loading
        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, inductor_config, runtime_shape, maybe_key
        )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True

            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )

            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph


@dataclasses.dataclass
class SplitItem:
    submod_name: str
    graph_id: int
    is_splitting_graph: bool
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and str(node.target) in ops:
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []

    names = [name for (name, module) in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs


# we share the global graph pool among all the backends
global_graph_pool = None

compilation_start_time = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        inductor_config: dict[str, Any],
        graph_pool,
        compile_config: CompilationConfig,
        sglang_backend: "SGLangBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.graph_pool = graph_pool
        self.sglang_backend = sglang_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False
        self.inductor_config = inductor_config
        self.compile_config = compile_config

    def run(self, *args):
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            global compilation_start_time
            compiled_graph_for_dynamic_shape = (
                self.sglang_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.inductor_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    runtime_shape=None,
                )
            )
            self.module.__dict__[target] = CUDAPiecewiseBackend(
                submod,
                self.compile_config,
                self.inductor_config,
                self.graph_pool,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                compiled_graph_for_dynamic_shape,
                self.sglang_backend,
            )

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Context manager to set the model tag."""
    global model_tag
    assert (
        tag != model_tag
    ), f"Model tag {tag} is the same as the current tag {model_tag}."
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        model_tag = old_tag


class SGLangBackend:
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        config: CompilationConfig,
        graph_pool: Any,
    ):
        assert graph_pool is not None
        self.graph_pool = graph_pool
        self.post_grad_pass_manager = PostGradPassManager()
        self.sym_tensor_indices = []
        self.input_buffers = []
        self.compiler_manager = CompilerManager()
        self.inductor_config = {
            "enable_auto_functionalized_v2": False,
        }
        self.compile_config = config

    def configure_post_pass(self):
        self.post_grad_pass_manager.configure()
        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        base_cache_dir = os.path.expanduser(
            os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
        )
        cache_hash = self.compiler_manager.compute_hash()
        cache_dir = os.path.join(
            base_cache_dir,
            "torch_compile_cache",
            cache_hash,
        )
        os.makedirs(cache_dir, exist_ok=True)
        rank = 0
        dp_rank = 0
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache=False, prefix=""
        )

        compilation_counter.num_graphs_seen += 1

        assert not self._called, "SGLangBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, ["sglang.unified_attention_with_output"]
        )

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)

        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        PiecewiseCompileInterpreter(
            self.split_gm,
            submod_names_to_compile,
            self.inductor_config,
            self.graph_pool,
            self.compile_config,
            self,
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True
        return self.split_gm
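split_graph above assigns every node a partition id, gives each splitting op its own single-node partition, and lets torch.fx.passes.split_module cut the graph there with keep_original_order=True. A toy sketch of that primitive on a hand-written module, with torch.relu standing in for the attention op (all names are illustrative):

# Assign each fx node a partition id and let split_module cut the graph there.
import torch
from torch.fx.passes.split_module import split_module


class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = torch.relu(a)  # pretend this is the "splitting op"
        return b * 2


gm = torch.fx.symbolic_trace(Toy())

part_id = 0
node_to_part = {}
for node in gm.graph.nodes:
    if node.op in ("placeholder", "output"):
        continue
    if node.op == "call_function" and node.target is torch.relu:
        part_id += 1                 # the splitting op gets its own piece
        node_to_part[node] = part_id
        part_id += 1
    else:
        node_to_part[node] = part_id

split_gm = split_module(gm, None, lambda n: node_to_part[n], keep_original_order=True)
print([name for name, _ in split_gm.named_children()])  # submod_0, submod_1, submod_2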
python/sglang/srt/model_executor/compilation/compilation_config.py (new file, 0 → 100644, view file @ 4ac8e09d)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
from typing import List


# TODO(Yuwei): support a richer compile config
class CompilationConfig:
    def __init__(self, capture_sizes: List[int]):
        self.traced_files = set()
        self.capture_sizes = capture_sizes

    def add_traced_file(self, file_path: str):
        self.traced_files.add(file_path)

    def get_traced_files(self):
        return self.traced_files

    def get_capture_sizes(self):
        return self.capture_sizes
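CompilationConfig currently only carries the CUDA-graph capture sizes and the set of traced files. A minimal usage sketch (the capture sizes below are illustrative; in the server they would come from the runner/server arguments):

from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig

config = CompilationConfig(capture_sizes=[1, 2, 4, 8])
config.add_traced_file("/tmp/model.py")  # illustrative path
assert config.get_capture_sizes() == [1, 2, 4, 8]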
python/sglang/srt/model_executor/compilation/compilation_counter.py (new file, 0 → 100644, view file @ 4ac8e09d)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
import copy
import dataclasses
from contextlib import contextmanager


@dataclasses.dataclass
class CompilationCounter:
    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # including the splitting ops
    num_piecewise_graphs_seen: int = 0
    # not including the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # Number of CUDAGraphs captured
    num_cudagraph_captured: int = 0
    # InductorAdapter.compile calls
    num_inductor_compiles: int = 0
    # EagerAdapter.compile calls
    num_eager_compiles: int = 0
    # The number of times vLLM's compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # The number of standalone_compile compiled artifacts saved
    num_compiled_artifacts_saved: int = 0
    # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
    dynamo_as_is_count: int = 0

    def clone(self) -> "CompilationCounter":
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        old = self.clone()
        yield
        for k, v in kwargs.items():
            assert getattr(self, k) - getattr(old, k) == v, (
                f"{k} not as expected, before it is {getattr(old, k)}"
                f", after it is {getattr(self, k)}, "
                f"expected diff is {v}"
            )


compilation_counter = CompilationCounter()
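The expect() context manager snapshots the counter on entry and asserts the exact per-field deltas on exit, which is how a test can pin down how many graphs were compiled or captured inside a block. A minimal usage sketch (the manual increments stand in for a real compilation run):

from sglang.srt.model_executor.compilation.compilation_counter import (
    compilation_counter,
)

# assert that exactly one graph and two piecewise subgraphs are seen in the block
with compilation_counter.expect(num_graphs_seen=1, num_piecewise_graphs_seen=2):
    compilation_counter.num_graphs_seen += 1           # stand-in for a real compile
    compilation_counter.num_piecewise_graphs_seen += 2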
python/sglang/srt/model_executor/compilation/compile.py (new file, 0 → 100644, view file @ 4ac8e09d)

import contextvars
import inspect
import logging
import os
import sys
import types
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch

from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig

logger = logging.getLogger(__name__)

_COMPILE_ENABLED = contextvars.ContextVar("_COMPILE_ENABLED", default=False)


@contextmanager
def set_compiled(enabled: bool = True):
    token = _COMPILE_ENABLED.set(enabled)
    try:
        yield
    finally:
        _COMPILE_ENABLED.reset(token)


@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    Each stage also needs to handle its own finished_sending and
    finished_recving in case of kv transfer.
    """

    tensors: dict[str, torch.Tensor]
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    def __init__(self, tensors):
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors

    def __getitem__(self, key: Union[str, slice]):
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

    def __setitem__(self, key: str, value: torch.Tensor):
        self.tensors[key] = value

    def items(self):
        return self.tensors.items()

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        return isinstance(other, self.__class__) and self

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"


def _normalize_dims(dims, ndim: int):
    dims = [dims] if isinstance(dims, int) else list(dims)
    return [d if d >= 0 else ndim + d for d in dims]


class _MaybeIntermediateTensors:
    """Duck-typed check to support your IntermediateTensors without importing."""

    def __init__(self, obj):
        self.is_intermediate = hasattr(obj, "tensors") and isinstance(
            getattr(obj, "tensors"), dict
        )
        self.obj = obj


def _mark_dynamic_on_value(val, dims):
    if isinstance(val, torch.Tensor):
        torch._dynamo.mark_dynamic(val, _normalize_dims(dims, val.ndim))
    else:
        mit = _MaybeIntermediateTensors(val)
        if mit.is_intermediate:
            for t in mit.obj.tensors.values():
                torch._dynamo.mark_dynamic(t, _normalize_dims(dims, t.ndim))
        # else: ignore (None or non-tensor)


def _infer_dynamic_arg_dims_from_annotations(forward_fn):
    sig = inspect.signature(forward_fn)
    dyn = {}
    for name, p in sig.parameters.items():
        ann = p.annotation
        # Accept torch.Tensor / Optional[torch.Tensor] / your IntermediateTensors types by name
        if (
            ann is torch.Tensor
            or getattr(getattr(ann, "__args__", [None])[0], "__name__", "") == "Tensor"
        ):
            dyn[name] = 0
        elif getattr(ann, "__name__", "") in ("IntermediateTensors",) or any(
            getattr(a, "__name__", "") == "IntermediateTensors"
            for a in getattr(ann, "__args__", [])
        ):
            dyn[name] = 0
    if not dyn:
        raise ValueError("No dynamic dims inferred; pass dynamic_arg_dims explicitly.")
    return dyn


def install_torch_compiled(
    module: torch.nn.Module,
    *,
    dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None,
    backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None,
    compile_config: CompilationConfig = None,
    fullgraph: bool = True,
    graph_pool: Any = None,
):
    unbound_fwd = module.__class__.forward
    if not callable(unbound_fwd):
        raise TypeError("module.__class__.forward must be callable")
    original_code = unbound_fwd.__code__

    dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)

    if backend_factory is None:
        from sglang.srt.model_executor.compilation.backend import SGLangBackend

        backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(gm, ex)

    compiled_codes: list[type(original_code)] = []
    state = {"compiled": False, "compiled_callable": None}

    def bytecode_hook(old_code, new_code):
        if old_code is not original_code:
            return
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            if (
                frame.f_code.co_name == "_compile"
                and os.path.basename(frame.f_code.co_filename) == "convert_frame.py"
            ):
                break
        try:
            dynamo_frame = frame.f_locals["frame"]
        except Exception:
            return
        if dynamo_frame.f_code is not old_code:
            return
        if dynamo_frame.f_locals.get("self") is not module:
            return
        compiled_codes.append(new_code)

    torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)

    def _ensure_compiled(self, *args, **kwargs):
        """Compile on first use (with flag ON)."""
        if state["compiled"]:
            return
        # Mark dynamic dims only when we are about to compile
        sig = inspect.signature(unbound_fwd)
        ba = sig.bind(self, *args, **kwargs)
        ba.apply_defaults()
        for name, dims in (dyn_map or {}).items():
            if name in ba.arguments:
                val = ba.arguments[name]
                if val is not None:
                    _mark_dynamic_on_value(val, dims)

        # Avoid cross-instance cache reuse
        torch._dynamo.eval_frame.remove_from_cache(unbound_fwd.__code__)
        bound = types.MethodType(unbound_fwd, self)
        compiled_callable = torch.compile(
            bound, fullgraph=fullgraph, backend=backend_factory
        )

        # Trigger Dynamo so bytecode hook can capture
        compiled_callable(*args, **kwargs)
        state["compiled"] = True
        state["compiled_callable"] = compiled_callable

    def trampoline(self, *args, **kwargs):
        use_compiled = _COMPILE_ENABLED.get()
        if use_compiled:
            if not state["compiled"]:
                _ensure_compiled(self, *args, **kwargs)
            compiled_callable = state["compiled_callable"]
            return compiled_callable(*args, **kwargs)
        else:
            # Explicitly run the original uncompiled forward
            return unbound_fwd(self, *args, **kwargs)

    module.forward = types.MethodType(trampoline, module)
    return module
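install_torch_compiled replaces module.forward with a trampoline that only routes through the compiled callable while the set_compiled() flag is active, compiling lazily on first compiled call. A hedged usage sketch: the toy module and the trivial backend_factory (plain eager execution) are illustrative stand-ins so the example does not depend on SGLangBackend or a GPU:

import torch
from sglang.srt.model_executor.compilation.compile import (
    install_torch_compiled,
    set_compiled,
)


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


model = TinyModel()
install_torch_compiled(
    model,
    dynamic_arg_dims={"x": 0},                      # batch dim is dynamic
    backend_factory=lambda gm, inputs: gm.forward,  # stand-in for SGLangBackend
)

x = torch.randn(4, 16)
y_eager = model(x)            # flag off: the original forward runs
with set_compiled(True):
    y_compiled = model(x)     # flag on: compiles on first use, then reuses
assert torch.allclose(y_eager, y_compiled)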
python/sglang/srt/model_executor/compilation/compiler_interface.py (new file, 0 → 100644, view file @ 4ac8e09d)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compiler_interface.py
import contextlib
import copy
import hashlib
import os
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch
import torch._inductor.compile_fx
import torch.fx as fx

from sglang.srt.model_executor.compilation.compilation_counter import (
    compilation_counter,
)
from sglang.srt.model_executor.compilation.inductor_pass import pass_context


class CompilerInterface:
    """
    The interface for a compiler that can be used by vLLM.
    """

    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        when the vLLM process uses `cache_dir` as the cache directory,
        the compiler should initialize itself with the cache directory,
        e.g. by re-directing its own cache directory to a sub-directory.

        prefix can be used in combination with cache_dir to figure out the base
        cache directory, e.g. there're multiple parts of model being compiled,
        but we want to share the same cache directory for all of them.

        e.g.
        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
        """
        pass

    def compute_hash(self) -> str:
        """
        Gather all the relevant information from the vLLM config,
        to compute a hash so that we can cache the compiled model.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        to check what information
        is already considered by default. This function should only
        consider the information that is specific to the compiler.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile the graph with the given example inputs and compiler config,
        with a runtime shape. If the `runtime_shape` is None, it means
        the `example_inputs` have a dynamic shape. Otherwise, the
        `runtime_shape` specifies the shape of the inputs. Right now we only
        support one variable shape for all inputs, which is the batchsize
        (number of tokens) during inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        The function should return a compiled callable function, as well as
        a handle that can be used to directly load the compiled function.

        The handle should be a plain Python object, preferably a string or a
        file path for readability.

        If the compiler doesn't support caching, it should return None for the
        handle. If the compiler fails to compile the graph, it should return
        None for the compiled function as well.

        `key` is required for StandaloneInductorAdapter, it specifies where to
        save the compiled artifact. The compiled artifact gets saved to
        `cache_dir/key`.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """
        Load the compiled function from the handle.
        Raises an error if the handle is invalid.

        The handle is the second return value of the `compile` function.
        """
        raise NotImplementedError("caching is not supported")


def get_inductor_factors() -> list[Any]:
    factors: list[Any] = []
    # summarize system state
    from torch._inductor.codecache import CacheBase

    system_factors = CacheBase.get_system()
    factors.append(system_factors)

    # summarize pytorch state
    from torch._inductor.codecache import torch_key

    torch_factors = torch_key()
    factors.append(torch_factors)
    return factors


class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have shape environment to provide to Inductor, and
    it will fail the Inductor code cache lookup.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods are obtained by trial-and-error
    until it works.
    """

    def __init__(self) -> None:
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        return True

    def get_pruned_guards(self, *args, **kwargs):
        return []

    def produce_guards_expression(self, *args, **kwargs):
        return ""


class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
    """

    name = "inductor"

    def compute_hash(self) -> str:
        factors = get_inductor_factors()
        hash_str = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        self.cache_dir = cache_dir
        self.prefix = prefix
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # redirect the cache directory to a sub-directory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # disable remote cache
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, runtime_shape)

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str, file_path = None, None
        from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash

        if torch.__version__.startswith("2.5"):
            original_load = FxGraphCache.load
            original_load_name = "torch._inductor.codecache.FxGraphCache.load"

            def hijack_load(*args, **kwargs):
                inductor_compiled_graph = original_load(*args, **kwargs)
                nonlocal file_path
                compiled_fn = inductor_compiled_graph.current_callable
                file_path = compiled_fn.__code__.co_filename  # noqa
                if not file_path.startswith(self.base_cache_dir):
                    # hooked in the align_inputs_from_check_idxs function
                    # in torch/_inductor/utils.py
                    for cell in compiled_fn.__closure__:
                        if not callable(cell.cell_contents):
                            continue
                        if cell.cell_contents.__code__.co_filename.startswith(
                            self.base_cache_dir
                        ):
                            # this is the real file path compiled from Inductor
                            file_path = cell.cell_contents.__code__.co_filename
                            break
                return inductor_compiled_graph

            hijacked_compile_fx_inner = (
                torch._inductor.compile_fx.compile_fx_inner
            )  # noqa
        elif torch.__version__ >= "2.6":
            # function renamed in 2.6
            original_load_name = None

            def hijacked_compile_fx_inner(*args, **kwargs):
                output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
                nonlocal hash_str
                inductor_compiled_graph = output
                if inductor_compiled_graph is not None:
                    nonlocal file_path
                    compiled_fn = inductor_compiled_graph.current_callable
                    file_path = compiled_fn.__code__.co_filename  # noqa
                    if not file_path.startswith(self.base_cache_dir):
                        # hooked in the align_inputs_from_check_idxs function
                        # in torch/_inductor/utils.py
                        for cell in compiled_fn.__closure__:
                            if not callable(cell.cell_contents):
                                continue
                            code = cell.cell_contents.__code__
                            if code.co_filename.startswith(self.base_cache_dir):
                                # this is the real file path
                                # compiled from Inductor
                                file_path = code.co_filename
                                break
                    hash_str = inductor_compiled_graph._fx_graph_cache_key
                return output

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # hijack to get the compiled graph itself
            if original_load_name is not None:
                stack.enter_context(patch(original_load_name, hijack_load))

            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False)
            )

            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_remote_autograd_cache=False)
            )

            with pass_context(runtime_shape):
                compiled_graph = compile_fx(
                    graph,
                    example_inputs,
                    inner_compile=hijacked_compile_fx_inner,
                    config_patches=current_config,
                )

        return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            if torch.__version__.startswith("2.5"):
                inductor_compiled_graph = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, False
                )
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove"
                    f"the cache directory and try again."  # noqa
                )
            elif torch.__version__ >= "2.6":
                from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

                constants = CompiledFxGraphConstantsWithGm(graph)
                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, None, constants
                )
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove"
                    f"the cache directory and try again."  # noqa
                )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any
        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.

        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different level of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        import torch._dynamo.utils

        return torch._dynamo.utils.get_metrics_context()


def set_inductor_config(config, runtime_shape):
    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        config["max_autotune"] = True
        config["coordinate_descent_tuning"] = True
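CompilerInterface is the seam that lets CompilerManager swap compilers: compile() returns a callable plus an optional cache handle, and load() restores from that handle. A hypothetical minimal adaptor (not part of this PR; the counter already anticipates an "EagerAdapter") that runs the graph eagerly and opts out of caching, just to show the contract:

from typing import Any, Callable, Optional

import torch.fx as fx

from sglang.srt.model_executor.compilation.compiler_interface import CompilerInterface


class EagerAdaptor(CompilerInterface):
    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        # "compile" by returning the FX graph's own forward; None handle means
        # CompilerManager will not try to cache or reload this artifact
        return graph.forward, None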
python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py (new file, 0 → 100644, view file @ 4ac8e09d)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/cuda_piecewise_backend.py
import dataclasses
import logging
from contextlib import ExitStack
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.fx as fx

import sglang.srt.model_executor.compilation.weak_ref_tensor_jit
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
from sglang.srt.model_executor.compilation.compilation_counter import (
    compilation_counter,
)

logger = logging.getLogger(__name__)


def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.
    """
    if isinstance(tensor, torch.Tensor):
        # TODO(yuwei): introduce weak_ref_tensor from sgl_kernel
        return torch.ops.jit_weak_ref_tensor.weak_ref_tensor(tensor)
    return tensor


def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)
    raise ValueError("Invalid type for tensors")


@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None


class CUDAPiecewiseBackend:
    def __init__(
        self,
        graph: fx.GraphModule,
        compile_config: CompilationConfig,
        inductor_config: dict[str, Any],
        graph_pool: Any,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        sglang_backend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.
        """
        self.graph = graph
        self.inductor_config = inductor_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.sglang_backend = sglang_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.compile_sizes: set[int] = set([])
        self.compile_config = compile_config
        self.cudagraph_capture_sizes: set[int] = set(compile_config.get_capture_sizes())

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = True

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.sglang_backend.compiler_manager.save_to_file()

    def __call__(self, *args) -> Any:
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]

        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.sglang_backend.compiler_manager.compile(
                self.graph,
                args,
                self.inductor_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
            )

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        # Skip CUDA graphs if this entry doesn't use them OR
        # if we're supposed to skip them globally
        # skip_cuda_graphs = get_forward_context().skip_cuda_graphs
        # if not entry.use_cudagraph or skip_cuda_graphs:
        #     return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < 1:  # noqa
                entry.num_finished_warmup += 1
                return entry.runnable(*args)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output
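The capture/replay core here is plain torch.cuda CUDA Graph machinery with a shared memory pool: warm up, capture a piece into a graph, then replay it with new data written into the same input buffers (which is why the backend asserts that input addresses match during replay). A standalone sketch of that pattern, assuming a CUDA device is available; the function and buffer names are illustrative:

import torch

assert torch.cuda.is_available()
pool = torch.cuda.graph_pool_handle()          # shared pool, like self.graph_pool
static_x = torch.randn(8, 16, device="cuda")   # fixed input buffer


def piece(x):
    return torch.relu(x) + 1


piece(static_x)                                 # warmup run (num_finished_warmup)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, pool=pool):
    static_out = piece(static_x)                # capture; output owned by the pool

static_x.copy_(torch.randn(8, 16, device="cuda"))  # write new data in place
g.replay()                                      # replays the captured kernels
print(static_out[0, :4])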
python/sglang/srt/model_executor/compilation/fix_functionalization.py
0 → 100644
View file @
4ac8e09d
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fix_functionalization.py
import logging
import operator
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

from sglang.srt.model_executor.compilation.fx_utils import is_func
from sglang.srt.model_executor.compilation.inductor_pass import SGLangInductorPass

logger = logging.getLogger(__name__)


class FixFunctionalizationPass(SGLangInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.
    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        self.nodes_to_remove: list[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting
            count += 1

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
        """
        Stage a node (or nodes) for removal at the end of the pass.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, Union[torch.fx.Node, str]],
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
    ):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
    ):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
    ):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(
            node, auto_functionalized
        ), f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                graph.call_function(function, args=args)
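A minimal sketch (not part of this commit): getitem_users and _remove operate on plain fx nodes, so they can be exercised on a tiny hand-built graph where a tuple-producing node is consumed through operator.getitem, mirroring the shape of an auto_functionalized node and its users. The lambda-based node is a toy stand-in, not a real functionalized op.

import operator

from torch import fx

from sglang.srt.model_executor.compilation.fix_functionalization import (
    FixFunctionalizationPass,
)

graph = fx.Graph()
x = graph.placeholder("x")
pair = graph.call_function(lambda t: (t, t + 1), args=(x,))
first = graph.call_function(operator.getitem, args=(pair, 0))
second = graph.call_function(operator.getitem, args=(pair, 1))
graph.output((first, second))

pass_ = FixFunctionalizationPass()
print(pass_.getitem_users(pair))  # {0: first, 1: second}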
python/sglang/srt/model_executor/compilation/fx_utils.py
0 → 100644
View file @
4ac8e09d
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/fx_utils.py
import operator
from collections.abc import Iterable, Iterator
from typing import Optional

from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload


def is_func(node: fx.Node, target) -> bool:
    return node.op == "call_function" and node.target == target


def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    return is_func(node, auto_functionalized) and node.args[0] == op


# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(
    nodes: Iterable[fx.Node], op: OpOverload
) -> Optional[fx.Node]:
    for node in nodes:
        if node.target == op:
            return node
    return None


# Returns the first specified node with the given op
def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    node = find_specified_fn_maybe(nodes, op)
    assert node is not None, f"Could not find {op} in nodes {nodes}"
    return node


# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
    for node in nodes:
        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa
            return node
    return None


# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    node = find_auto_fn_maybe(nodes, op)
    assert node is not None, f"Could not find {op} in nodes {nodes}"
    return node


# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    for user in node.users:
        if is_func(user, operator.getitem) and user.args[1] == idx:
            return user
    return None


# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    ret = find_getitem_maybe(node, idx)
    assert ret is not None, f"Could not find getitem {idx} in node {node}"
    return ret


# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
    if not op._schema.is_mutable:
        yield from graph.find_nodes(op="call_function", target=op)

    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
        if n.args[0] == op:
            yield n


# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
    assert len(node.users) == 1
    return next(iter(node.users))
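A minimal sketch (not part of this commit): the helpers above are ordinary graph queries, so they can be tried on a symbolically traced toy module. The module and the aten overload used here are illustrative choices.

import torch
from torch import fx

from sglang.srt.model_executor.compilation.fx_utils import (
    find_getitem_maybe,
    find_specified_fn_maybe,
    is_func,
)


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.relu.default(x) + 1


graph = fx.symbolic_trace(Toy()).graph
relu = find_specified_fn_maybe(graph.nodes, torch.ops.aten.relu.default)
print(is_func(relu, torch.ops.aten.relu.default))  # True
print(find_getitem_maybe(relu, 0))  # None: relu produces no tuple output here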
python/sglang/srt/model_executor/compilation/inductor_pass.py
0 → 100644
View file @
4ac8e09d
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py
import hashlib
import inspect
import json
import logging
import time
import types
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union

import torch
from torch import fx
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.custom_graph_pass import CustomGraphPass

logger = logging.getLogger(__name__)

_pass_context = None


class PassContext:
    def __init__(self, runtime_shape: Optional[int]):
        self.runtime_shape = runtime_shape


def get_pass_context() -> PassContext:
    """Get the current pass context."""
    assert _pass_context is not None
    return _pass_context


@contextmanager
def pass_context(runtime_shape: Optional[int]):
    """A context manager that stores the current pass context,
    usually it is a list of sizes to specialize.
    """
    global _pass_context
    prev_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        _pass_context = prev_context


class InductorPass(CustomGraphPass):
    """
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
    """

    def uuid(self) -> Any:
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
    def hash_source(*srcs: Union[str, Any]):
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
        Objects and functions have their source inspected.
        :return:
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
            elif isinstance(src, types.FunctionType):
                src_str = inspect.getsource(src)
            else:
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
        return hasher.hexdigest()

    @staticmethod
    def hash_dict(dict_: dict[Any, Any]):
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    def is_applicable_for_shape(self, shape: Optional[int]):
        return True


class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the UUID.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
    ):
        self.callable = callable
        self._uuid = self.hash_source(callable) if uuid is None else uuid

    def __call__(self, graph: torch.fx.Graph):
        self.callable(graph)

    def uuid(self) -> Any:
        return self._uuid


class SGLangInductorPass(InductorPass):
    def __init__(self):
        self.pass_name = self.__class__.__name__

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        lazy_format_graph_code(stage, graph.owning_module)

    def begin(self):
        self._start_time = time.perf_counter_ns()

    def end_and_log(self):
        self._end_time = time.perf_counter_ns()
        duration_ms = float(self._end_time - self._start_time) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)


class PrinterInductorPass(SGLangInductorPass):
    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, self.name)
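A minimal sketch (not part of this commit): a plain function becomes a pass with a stable UUID, and pass_context exposes the runtime shape that is_applicable_for_shape and the pass manager consult while compiling.

from torch import fx

from sglang.srt.model_executor.compilation.inductor_pass import (
    CallableInductorPass,
    get_pass_context,
    pass_context,
)


def log_node_count(graph: fx.Graph) -> None:
    print("graph has", len(list(graph.nodes)), "nodes")


wrapped = CallableInductorPass(log_node_count)
print(wrapped.uuid())  # sha256 of the callable's source, used by the code cache

with pass_context(runtime_shape=256):
    print(get_pass_context().runtime_shape)  # 256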
python/sglang/srt/model_executor/compilation/pass_manager.py
0 → 100644
View file @
4ac8e09d
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/pass_manager.py
import logging

from torch import fx as fx

from sglang.srt.model_executor.compilation.fix_functionalization import (
    FixFunctionalizationPass,
)
from sglang.srt.model_executor.compilation.inductor_pass import (
    CustomGraphPass,
    InductorPass,
    SGLangInductorPass,
    get_pass_context,
)

logger = logging.getLogger(__name__)


class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. passes (constructor parameter)
    2. default passes (NoopEliminationPass, FusionPass)
    3. config["post_grad_custom_post_pass"] (if it exists)
    4. fix_functionalization
    This way, all passes operate on a functionalized graph.
    """

    def __init__(self):
        self.passes: list[SGLangInductorPass] = []

    def __call__(self, graph: fx.Graph):
        shape = get_pass_context().runtime_shape
        for pass_ in self.passes:
            if pass_.is_applicable_for_shape(shape):
                pass_(graph)
        # always run fix_functionalization last
        self.fix_functionalization(graph)

    def configure(self):
        self.pass_config = dict()
        self.fix_functionalization = FixFunctionalizationPass()

    def add(self, pass_: InductorPass):
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        pass_manager_uuid = "fshdakhsa"
        state = {"pass_config": pass_manager_uuid, "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        state["passes"].append(self.fix_functionalization.uuid())
        return InductorPass.hash_dict(state)
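A minimal sketch (not part of this commit) of assembling the manager by hand; in this commit the actual wiring into the Inductor backend lives in compilation/backend.py shown earlier.

import torch
from torch import fx

from sglang.srt.model_executor.compilation.inductor_pass import (
    CallableInductorPass,
    pass_context,
)
from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager


def noop(graph: fx.Graph) -> None:
    return None


pm = PostGradPassManager()
pm.configure()  # installs the trailing FixFunctionalizationPass
pm.add(CallableInductorPass(noop))  # any InductorPass subclass works here
print(pm.uuid())  # combined hash that feeds Inductor's compilation cache

gm = fx.symbolic_trace(torch.nn.ReLU())
with pass_context(runtime_shape=None):  # __call__ reads the shape from here
    pm(gm.graph)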
python/sglang/srt/model_executor/compilation/piecewise_context_manager.py
0 → 100644
View file @
4ac8e09d
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional

from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@dataclass
class ForwardContext:
    def __init__(self):
        self.forward_batch = None
        self.attention_layer = None

    def set_forward_batch(self, forward_batch: ForwardBatch):
        self.forward_batch = forward_batch

    def set_attention_layers(self, layers: List[Any]):
        self.attention_layers = layers


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> Optional[ForwardContext]:
    if _forward_context is None:
        return None
    return _forward_context


@contextmanager
def set_forward_context(forward_batch: ForwardBatch, attention_layers: List[Any]):
    global _forward_context
    prev_forward_context = _forward_context
    _forward_context = ForwardContext()
    _forward_context.set_forward_batch(forward_batch)
    _forward_context.set_attention_layers(attention_layers)
    try:
        yield
    finally:
        _forward_context = prev_forward_context
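A minimal sketch (not part of this commit): during a piecewise-compiled forward pass, attention layers pull the live batch and the layer list from this module-global context instead of receiving them as extra arguments. Here, `model`, `forward_batch`, and `attention_layers` are assumed to be prepared elsewhere (see PiecewiseCudaGraphRunner below).

from sglang.srt.model_executor.compilation.piecewise_context_manager import (
    get_forward_context,
    set_forward_context,
)


def forward_with_context(model, forward_batch, attention_layers):
    with set_forward_context(forward_batch, attention_layers):
        assert get_forward_context().forward_batch is forward_batch
        return model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )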
python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp
0 → 100644
View file @
4ac8e09d
// Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h
#include <torch/extension.h>

#include <vector>

static at::Tensor weak_ref_tensor(at::Tensor& tensor) {
  TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor");

  void* data_ptr = tensor.data_ptr();
  std::vector<int64_t> sizes = tensor.sizes().vec();
  std::vector<int64_t> strides = tensor.strides().vec();
  auto options = tensor.options();
  auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
  return new_tensor;
}

TORCH_LIBRARY(jit_weak_ref_tensor, ops) {
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
}

TORCH_LIBRARY_IMPL(jit_weak_ref_tensor, CUDA, ops) {
  ops.impl("weak_ref_tensor", weak_ref_tensor);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py
0 → 100644
View file @
4ac8e09d
import os

import torch
from torch.utils.cpp_extension import load

_abs_path = os.path.dirname(os.path.abspath(__file__))
load(
    name="weak_ref_tensor_ext",
    sources=[f"{_abs_path}/weak_ref_tensor.cpp"],
    extra_cflags=["-O3"],
)

x = torch.arange(12, device="cuda").reshape(3, 4)
y = torch.ops.jit_weak_ref_tensor.weak_ref_tensor(x)
print("alias:", x.data_ptr() == y.data_ptr())
python/sglang/srt/model_executor/model_runner.py
View file @
4ac8e09d
...
@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
    PiecewiseCudaGraphRunner,
)
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
...
@@ -307,6 +314,26 @@ class ModelRunner:
        self._model_update_group = {}
        self._weights_send_group = {}

        if (
            self.server_args.enable_piecewise_cuda_graph
            and self.can_run_piecewise_cuda_graph()
        ):
            self.attention_layers = []
            for layer in self.model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
                    self.attention_layers.append(layer.self_attn.attn)
            if len(self.attention_layers) < self.model_config.num_hidden_layers:
                # TODO(yuwei): support Non-Standard GQA
                log_info_on_rank0(
                    logger,
                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
                )
                self.piecewise_cuda_graph_runner = None
            else:
                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
        else:
            self.piecewise_cuda_graph_runner = None

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args
...
@@ -692,6 +719,7 @@ class ModelRunner:
            pipeline_model_parallel_size=self.pp_size,
            expert_model_parallel_size=self.moe_ep_size,
            duplicate_tp_group=self.server_args.enable_pdmux,
            torch_compile=self.server_args.enable_piecewise_cuda_graph,
        )
        initialize_dp_attention(
            server_args=self.server_args,
...
@@ -1411,6 +1439,27 @@ class ModelRunner:
            f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
        )

    def can_run_piecewise_cuda_graph(self):
        if self.server_args.disable_cuda_graph:
            log_info_on_rank0(
                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
            )
            return False
        if self.server_args.enable_torch_compile:
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
            )
            return False
        if self.pp_size > 1:
            # TODO(yuwei): support PP
            log_info_on_rank0(
                logger,
                "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
            )
            return False
        return True

    def init_memory_pool(
        self,
        total_gpu_memory: int,
...
@@ -1932,6 +1981,11 @@ class ModelRunner:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True

        if self.piecewise_cuda_graph_runner is not None:
            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)

        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
...
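A condensed sketch (not part of the commit) of the dispatch added above: the piecewise runner is only constructed when the server-level checks and the attention-layer scan pass, and each batch still goes through the per-batch can_run check before replay; everything else falls back to the eager model forward.

def forward_extend(model_runner, forward_batch, **kwargs):
    runner = model_runner.piecewise_cuda_graph_runner
    if runner is not None and runner.can_run(forward_batch):
        return runner.replay(forward_batch, **kwargs)
    return model_runner.model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
    )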
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
0 → 100644
View file @
4ac8e09d
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""

from __future__ import annotations

import bisect
import gc
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union

import torch
import tqdm

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    set_graph_pool_id,
)
from sglang.srt.distributed.parallel_state import graph_capture
from sglang.srt.layers.dp_attention import (
    DpPaddingMode,
    get_attention_tp_rank,
    get_attention_tp_size,
    set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig
from sglang.srt.model_executor.compilation.compile import (
    install_torch_compiled,
    set_compiled,
)
from sglang.srt.model_executor.compilation.piecewise_context_manager import (
    set_forward_context,
)
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

# Detect whether the current forward pass is in capture mode
is_capture_mode = False


def get_is_capture_mode():
    return is_capture_mode


@contextmanager
def model_capture_mode():
    global is_capture_mode
    is_capture_mode = True
    yield
    is_capture_mode = False


@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Optimize garbage collection during CUDA graph capture.
    Clean up, then freeze all remaining objects from being included
    in future collections if GC is disabled during capture.
    """
    gc.collect()
    should_freeze = not enable_cudagraph_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()


def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            if reverse:
                sub.leave_torch_compile()
            else:
                sub.enter_torch_compile(num_tokens=num_tokens)
        if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens)


@contextmanager
def patch_model(model: torch.nn.Module):
    try:
        _to_torch(model, reverse=False, num_tokens=16)
        yield model
    finally:
        _to_torch(model, reverse=True, num_tokens=16)


# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    global global_graph_memory_pool
    global_graph_memory_pool = val


class PiecewiseCudaGraphRunner:
    """A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.device_module = torch.get_device_module(self.device)
        self.graphs = {}
        self.output_buffers = {}
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        assert (
            self.model_runner.server_args.piecewise_cuda_graph_tokens is not None
        ), "piecewise_cuda_graph_tokens is not set"
        self.compile_config = CompilationConfig(
            self.model_runner.server_args.piecewise_cuda_graph_tokens
        )

        # Batch sizes to capture
        self.capture_num_tokens = self.compile_config.get_capture_sizes()
        log_info_on_rank0(
            logger, f"Capture cuda graph num tokens {self.capture_num_tokens}"
        )
        self.capture_forward_mode = ForwardMode.EXTEND
        self.capture_hidden_mode = CaptureHiddenMode.NULL

        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        # Attention backend
        self.max_num_tokens = max(self.capture_num_tokens)

        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
            self.out_cache_loc = torch.zeros(
                (self.max_num_tokens,), dtype=self._cache_loc_dtype()
            )
            self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
            self.tbo_plugin = TboCudaGraphRunnerPlugin()

        self.attention_layers = self.model_runner.attention_layers

        if get_global_graph_memory_pool() is None:
            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
        # Set graph pool id globally to be able to use symmetric memory
        set_graph_pool_id(get_global_graph_memory_pool())

        with patch_model(self.model_runner.model.model) as patched_model:
            install_torch_compiled(
                patched_model,
                fullgraph=True,
                dynamic_arg_dims=None,
                compile_config=self.compile_config,
                graph_pool=get_global_graph_memory_pool(),
            )
            with set_compiled(True):
                self.warmup_and_capture()

        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n{PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

        self.raw_num_tokens = 0

    def warmup_and_capture(self):
        num_tokens = 2
        with torch.device(self.device):
            forward_batch = ForwardBatch(
                forward_mode=ForwardMode.EXTEND,
                batch_size=1,
                input_ids=torch.randint(0, 100, (num_tokens,), device=self.device),
                req_pool_indices=torch.arange(1, device=self.device),
                seq_lens=torch.tensor([num_tokens], device=self.device),
                next_token_logits_buffer=None,
                orig_seq_lens=torch.tensor([num_tokens], device=self.device),
                seq_lens_cpu=torch.tensor([num_tokens]),
                req_to_token_pool=self.model_runner.req_to_token_pool,
                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                attn_backend=self.model_runner.attn_backend,
                out_cache_loc=torch.randint(0, 100, (num_tokens,), device=self.device),
                seq_lens_sum=num_tokens,
                encoder_lens=None,
                return_logprob=False,
                extend_seq_lens=torch.tensor([num_tokens], device=self.device),
                extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
                extend_start_loc=torch.tensor([0], device=self.device),
                extend_prefix_lens_cpu=torch.tensor([num_tokens]),
                extend_seq_lens_cpu=torch.tensor([num_tokens]),
                extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
                positions=torch.arange(num_tokens, device=self.device),
                global_num_tokens_gpu=None,
                global_num_tokens_for_logprob_gpu=None,
                dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
                global_dp_buffer_len=None,
                mrope_positions=None,
                spec_algorithm=None,
                spec_info=None,
                capture_hidden_mode=CaptureHiddenMode.NULL,
                num_token_non_padded=None,
                global_forward_mode=ForwardMode.EXTEND,
                lora_ids=None,
            )

        with set_forward_context(forward_batch, self.attention_layers):
            _ = self.model_runner.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
            )

    def _cache_loc_dtype(self):
        return torch.int64

    def can_run(self, forward_batch: ForwardBatch):
        num_tokens = len(forward_batch.input_ids)
        # TODO(yuwei): support return logprob
        if forward_batch.return_logprob:
            return False
        if num_tokens <= self.max_num_tokens:
            return True
        return False

    def capture(self) -> None:
        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with freeze_gc(
            self.model_runner.server_args.enable_cudagraph_gc
        ), graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            avail_mem = get_available_gpu_memory(
                self.model_runner.device,
                self.model_runner.gpu_id,
                empty_cache=False,
            )
            # Reverse the order to enable better memory sharing across cuda graphs.
            capture_range = (
                tqdm.tqdm(list(reversed(self.capture_num_tokens)))
                if get_tensor_model_parallel_rank() == 0
                else reversed(self.capture_num_tokens)
            )
            for i, num_tokens in enumerate(capture_range):
                if get_tensor_model_parallel_rank() == 0:
                    avail_mem = get_available_gpu_memory(
                        self.model_runner.device,
                        self.model_runner.gpu_id,
                        empty_cache=False,
                    )
                    capture_range.set_description(
                        f"Capturing num tokens ({num_tokens=} {avail_mem=:.2f} GB)"
                    )

                with set_compiled(True):
                    self.capture_one_batch_size(num_tokens)

                # Save gemlite cache after each capture
                save_gemlite_cache()

    def capture_one_batch_size(self, num_tokens: int):
        stream = self.stream
        bs = 1

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]

        # pipeline parallelism
        if self.pp_size > 1:
            pp_proxy_tensors = PPProxyTensors(
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )

        global_dp_buffer_len = None

        if self.model_runner.server_args.enable_lora:
            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
            lora_ids = [None] * bs
        else:
            lora_ids = None

        with torch.device(self.device):
            forward_batch = ForwardBatch(
                forward_mode=ForwardMode.EXTEND,
                batch_size=bs,
                input_ids=input_ids,
                req_pool_indices=torch.arange(bs, device=self.device),
                seq_lens=torch.tensor([num_tokens], device=self.device),
                next_token_logits_buffer=None,
                orig_seq_lens=torch.tensor([num_tokens], device=self.device),
                seq_lens_cpu=torch.tensor([num_tokens]),
                req_to_token_pool=self.model_runner.req_to_token_pool,
                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                attn_backend=self.model_runner.attn_backend,
                out_cache_loc=out_cache_loc,
                seq_lens_sum=num_tokens,
                encoder_lens=None,
                return_logprob=False,
                extend_seq_lens=torch.tensor([num_tokens], device=self.device),
                extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
                extend_start_loc=torch.tensor([0], device=self.device),
                extend_prefix_lens_cpu=torch.tensor([num_tokens]),
                extend_seq_lens_cpu=torch.tensor([num_tokens]),
                extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
                positions=positions,
                global_num_tokens_gpu=None,
                global_num_tokens_for_logprob_gpu=None,
                dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
                global_dp_buffer_len=None,
                mrope_positions=None,
                spec_algorithm=None,
                spec_info=None,
                capture_hidden_mode=CaptureHiddenMode.NULL,
                num_token_non_padded=None,
                global_forward_mode=ForwardMode.EXTEND,
                lora_ids=None,
            )
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

        if lora_ids is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

        # # Attention backend
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

            kwargs = {}

            with set_forward_context(forward_batch, self.attention_layers):
                self.model_runner.model.forward(
                    forward_batch.input_ids,
                    forward_batch.positions,
                    forward_batch,
                    **kwargs,
                )
            return

        for _ in range(2):
            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
            run_once()

        return

    def replay_prepare(
        self,
        forward_batch: ForwardBatch,
        **kwargs,
    ):
        num_tokens = len(forward_batch.input_ids)
        index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
        static_num_tokens = self.capture_num_tokens[index]
        self.raw_num_tokens = num_tokens
        if static_num_tokens != num_tokens:
            self.out_cache_loc.zero_()

        bs = forward_batch.batch_size
        self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
        self.positions[:num_tokens].copy_(forward_batch.positions)
        self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)

        input_ids = self.input_ids[:static_num_tokens]
        positions = self.positions[:static_num_tokens]
        out_cache_loc = self.out_cache_loc[:static_num_tokens]
        next_token_logits_buffer = None
        mrope_positions = None

        static_forward_batch = ForwardBatch(
            forward_mode=forward_batch.forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=forward_batch.req_pool_indices,
            seq_lens=forward_batch.seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=forward_batch.orig_seq_lens,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=forward_batch.seq_lens_sum,
            encoder_lens=forward_batch.encoder_lens,
            return_logprob=forward_batch.return_logprob,
            extend_seq_lens=forward_batch.extend_seq_lens,
            extend_prefix_lens=forward_batch.extend_prefix_lens,
            extend_start_loc=forward_batch.extend_start_loc,
            extend_prefix_lens_cpu=forward_batch.extend_prefix_lens_cpu,
            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
            extend_num_tokens=forward_batch.extend_num_tokens,
            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
            positions=positions,
            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=forward_batch.dp_padding_mode,
            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=forward_batch.spec_algorithm,
            spec_info=forward_batch.spec_info,
            capture_hidden_mode=forward_batch.capture_hidden_mode,
            num_token_non_padded=forward_batch.num_token_non_padded,
            global_forward_mode=forward_batch.global_forward_mode,
            lora_ids=forward_batch.lora_ids,
            sampling_info=forward_batch.sampling_info,
            mm_inputs=forward_batch.mm_inputs,
            temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
            temperature=forward_batch.temperature,
            top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
            top_p=forward_batch.top_p,
        )
        return static_forward_batch

    def replay(
        self,
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
        # Replay
        with set_forward_context(static_forward_batch, self.attention_layers):
            with set_compiled(True):
                output = self.model_runner.model.forward(
                    static_forward_batch.input_ids,
                    static_forward_batch.positions,
                    static_forward_batch,
                    **kwargs,
                )
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_tokens],
                hidden_states=(
                    output.hidden_states[: self.raw_num_tokens]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            # TODO(Yuwei): support PP Support
            raise NotImplementedError(
                "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
            )

    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if (
            self.model_runner.spec_algorithm.is_eagle()
            or self.model_runner.spec_algorithm.is_standalone()
        ):
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info


PIECEWISE_CUDA_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --piecewise-cuda-graph-max-tokens to a smaller value (e.g., 512)\n"
    "3. disable Piecewise CUDA graph by unset --enable-piecewise-cuda-graph\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
)
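A minimal sketch (not part of this commit) of the size bucketing inside replay_prepare: the incoming token count is rounded up via bisect to the nearest captured size, and the static input buffers are sliced to that length, so a slightly larger pre-captured graph serves the request with padding.

import bisect

capture_num_tokens = [4, 8, 16, 32]  # illustrative capture sizes
num_tokens = 11
static_num_tokens = capture_num_tokens[
    bisect.bisect_left(capture_num_tokens, num_tokens)
]
print(static_num_tokens)  # 16: the 11-token batch replays the 16-token graph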
python/sglang/srt/server_args.py
View file @
4ac8e09d
...
@@ -417,7 +417,10 @@ class ServerArgs:
    enable_single_batch_overlap: bool = False
    tbo_token_distribution_threshold: float = 0.48
    enable_torch_compile: bool = False
    enable_piecewise_cuda_graph: bool = False
    torch_compile_max_bs: int = 32
    piecewise_cuda_graph_max_tokens: int = 4096
    piecewise_cuda_graph_tokens: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
...
@@ -675,6 +678,11 @@ class ServerArgs:
        else:
            self.cuda_graph_max_bs = max(self.cuda_graph_bs)

        if self.piecewise_cuda_graph_tokens is None:
            self.piecewise_cuda_graph_tokens = (
                self._generate_piecewise_cuda_graph_tokens()
            )

        if self.mem_fraction_static is None:
            # Constant meta data (e.g., from attention backend)
            reserved_mem = 512
...
@@ -753,6 +761,25 @@ class ServerArgs:
        return capture_bs

    def _generate_piecewise_cuda_graph_tokens(self):
        """
        Generate the list of batch sizes for piecewise CUDA graph capture
        based on piecewise_cuda_graph_max_tokens.
        """
        capture_sizes = (
            list(range(4, 33, 4))
            + list(range(48, 257, 16))
            + list(range(288, 513, 32))
            + list(range(640, 4096 + 1, 128))
            + list(range(4352, self.piecewise_cuda_graph_max_tokens + 1, 256))
        )
        capture_sizes = [
            s for s in capture_sizes if s <= self.piecewise_cuda_graph_max_tokens
        ]
        return capture_sizes

    def _handle_hpu_backends(self):
        if self.device == "hpu":
            self.attention_backend = "torch_native"
...
@@ -2649,12 +2676,29 @@ class ServerArgs:
            action="store_true",
            help="Optimize the model with torch.compile. Experimental feature.",
        )
        parser.add_argument(
            "--enable-piecewise-cuda-graph",
            action="store_true",
            help="Optimize the model with piecewise cuda graph for extend/prefill only. Experimental feature.",
        )
        parser.add_argument(
            "--piecewise-cuda-graph-tokens",
            type=json_list_type,
            default=ServerArgs.piecewise_cuda_graph_tokens,
            help="Set the list of tokens when using piecewise cuda graph.",
        )
        parser.add_argument(
            "--torch-compile-max-bs",
            type=int,
            default=ServerArgs.torch_compile_max_bs,
            help="Set the maximum batch size when using torch compile.",
        )
        parser.add_argument(
            "--piecewise-cuda-graph-max-tokens",
            type=int,
            default=ServerArgs.piecewise_cuda_graph_max_tokens,
            help="Set the maximum tokens when using piecewise cuda graph.",
        )
        parser.add_argument(
            "--torchao-config",
            type=str,
...
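A standalone sketch (not part of the commit) that mirrors _generate_piecewise_cuda_graph_tokens above; it is handy for checking which sizes a given --piecewise-cuda-graph-max-tokens budget would capture.

def piecewise_capture_sizes(max_tokens: int = 4096) -> list:
    sizes = (
        list(range(4, 33, 4))
        + list(range(48, 257, 16))
        + list(range(288, 513, 32))
        + list(range(640, 4096 + 1, 128))
        + list(range(4352, max_tokens + 1, 256))
    )
    return [s for s in sizes if s <= max_tokens]


print(piecewise_capture_sizes(1024)[-4:])  # [640, 768, 896, 1024]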