sglang · commit 7577f0e4 (unverified)

Add graph runner support with torch compile on CPU (#7843)

Authored Sep 08, 2025 by Cao E · committed via GitHub Sep 07, 2025
Parent: 8cda5a62

Showing 16 changed files with 820 additions and 48 deletions (+820 −48)
.github/workflows/pr-test-xeon.yml                        +1    −1
docs/platforms/cpu_server.md                              +6    −1
python/sglang/srt/distributed/parallel_state.py           +4    −3
python/sglang/srt/layers/attention/intel_amx_backend.py   +3    −0
python/sglang/srt/layers/quantization/fp8.py              +3    −0
python/sglang/srt/layers/quantization/w8a8_int8.py        +5    −7
python/sglang/srt/managers/scheduler.py                   +4    −5
python/sglang/srt/managers/scheduler_metrics_mixin.py     +1    −1
python/sglang/srt/model_executor/cpu_graph_runner.py      +640  −0
python/sglang/srt/model_executor/forward_batch_info.py    +3    −0
python/sglang/srt/model_executor/model_runner.py          +38   −16
python/sglang/srt/utils.py                                +9    −1
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp               +5    −5
test/srt/run_suite.py                                     +1    −0
test/srt/test_cpu_graph.py                                +87   −0
test/srt/test_intel_amx_attention_backend.py              +10   −8
.github/workflows/pr-test-xeon.yml

@@ -70,7 +70,7 @@ jobs:
       - name: Run unit tests
         if: steps.check_amx.outcome == 'success'
-        timeout-minutes: 30
+        timeout-minutes: 36
         run: |
           docker exec -w /sglang-checkout/ ci_sglang_xeon \
             bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu"
docs/platforms/cpu_server.md

@@ -134,7 +134,12 @@ Notes:
    export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
    ```

-3. A warmup step is automatically triggered when the service is started.
+3. To optimize decoding with torch.compile, add the flag `--enable-torch-compile`.
+   To cap the batch size that gets compiled, set the flag `--torch-compile-max-bs`.
+   For example, `--enable-torch-compile --torch-compile-max-bs 4` enables torch compile
+   with a maximum batch size of 4.
+
+4. A warmup step is automatically triggered when the service is started.
    The server is ready when you see the log `The server is fired up and ready to roll!`.

 ## Benchmarking with Requests
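For reference, a complete CPU launch that exercises the new graph runner might look like `python3 -m sglang.launch_server --model-path <model> --device cpu --attention-backend intel_amx --enable-torch-compile --torch-compile-max-bs 4`. Only the last two flags are introduced by this change; the remaining flags are the usual CPU-server options and are shown here purely for illustration.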
python/sglang/srt/distributed/parallel_state.py

@@ -64,6 +64,9 @@ class GraphCaptureContext:

 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]

@@ -489,9 +492,7 @@ class GroupCoordinator:
         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
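A note on the new constant: the in-code comment says the int value replaces `ReduceOp.SUM` to support torch compile. The `shm_allreduce` schema (see `torch_extension_cpu.cpp` below) takes a plain `int reduce_op`, so hoisting the enum-to-int conversion to module import time means the compiled region only ever sees a constant Python int. A minimal, runnable sketch of that pattern, with an ordinary tensor expression standing in for the custom op call:

```python
import torch
import torch.distributed as dist

# Hoisted to module scope, exactly like REDUCE_OP_SUM above: the enum-to-int
# conversion happens once at import time, outside any compiled region.
REDUCE_OP_SUM = int(dist.ReduceOp.SUM)


def allreduce_like(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for torch.ops.sgl_kernel.shm_allreduce(x, REDUCE_OP_SUM): the
    # second argument is already a plain Python int, so Dynamo treats it as a
    # trace-time constant instead of stepping through the ReduceOp enum.
    return x * (REDUCE_OP_SUM + 1)


compiled = torch.compile(allreduce_like, dynamic=False)
print(compiled(torch.ones(4)))
```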
python/sglang/srt/layers/attention/intel_amx_backend.py

@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
             self.forward_metadata = (attn_logits, max_extend_len)

+    def get_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
python/sglang/srt/layers/quantization/fp8.py

@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                     _is_cpu_amx_available
                 ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                 _amx_process_weight_after_loading(layer, ["weight"])
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -343,8 +343,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

@@ -486,8 +485,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
-        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
         ...
python/sglang/srt/managers/scheduler.py

@@ -414,7 +414,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache

@@ -2252,9 +2252,8 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-        if not _is_cpu:
-            ret["memory_usage"]["cuda_graph"] = round(
-                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-            )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
python/sglang/srt/managers/scheduler_metrics_mixin.py

@@ -214,7 +214,7 @@ class SchedulerMetricsMixin:
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
         )
python/sglang/srt/model_executor/cpu_graph_runner.py  (new file, +640)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cpu torch compile."""

# The implementation of CPUGraphRunner follows the CudaGraphRunner

from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union

import psutil
import torch
import tqdm

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    log_info_on_rank0,
    require_attn_tp_gather,
    require_gathered_buffer,
    require_mlp_sync,
    require_mlp_tp_gather,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile"""
    backup_ca_comm = None

    try:
        if enable_compile:
            backup_ca_comm = tp_group.ca_comm
            # Use custom-allreduce here.
            # We found the custom allreduce is much faster than the built-in allreduce in torch,
            # even with ENABLE_INTRA_NODE_COMM=1.
            # tp_group.ca_comm = None
            yield torch.compile(
                torch.no_grad()(model.forward),
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            tp_group.ca_comm = backup_ca_comm


def set_torch_compile_config():
    import torch._dynamo.config
    import torch._inductor.config

    torch._inductor.config.fx_graph_cache = True
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.freezing = True

    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024

    monkey_patch_torch_compile()


def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    # cpu torch compile only speeds up decoding by
    # reducing python overhead when bs is small
    capture_bs = list(range(1, 17))
    capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
    capture_bs = list(sorted(set(capture_bs)))
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    return capture_bs


def register_fake_ops():
    """
    Registers fake/meta implementations for all custom sgl_kernel CPU operators
    using torch.library.register_fake to support torch.compile
    """
    none_return_ops = [
        "shm_allreduce",
        "bmm_cpu",
        "fused_add_rmsnorm_cpu",
        "decode_attention_cpu",
        "extend_attention_cpu",
    ]
    for op in none_return_ops:

        @torch.library.register_fake(f"sgl_kernel::{op}")
        def _(*args, **kwargs):
            return

    for op in [
        "rmsnorm_cpu",
        "l2norm_cpu",
        "fused_experts_cpu",
        "shared_expert_cpu",
    ]:

        @torch.library.register_fake(f"sgl_kernel::{op}")
        def _(input, *args, **kwargs):
            return torch.empty_like(input)

    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope")
    def _(
        hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc,
        q_a_layernorm_weight, kv_a_layernorm_weight, positions, cos_sin_cache, eps,
        use_int8_w8a8, use_fp8_w8a16, q_a_proj_scale, q_b_proj_scale, kv_a_proj_scale,
        is_vnni, block_size,
    ):
        num_seqs = hidden_states.shape[0]
        num_heads = w_kc.shape[0]
        kv_lora_rank = w_kc.shape[1]
        qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank
        q_input = torch.empty(
            num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype, device=hidden_states.device,
        )
        k_input = torch.empty(
            num_seqs, 1, kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype, device=hidden_states.device,
        )
        v_input = k_input.narrow(-1, 0, kv_lora_rank)
        return q_input, k_input, v_input

    @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu")
    def _(positions, query, key, head_size, cos_sin_cache, is_neox):
        if query.ndim == 2:
            return query, key
        else:
            return torch.empty_like(query), torch.empty_like(key)

    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight")
    def _(
        hidden_states, q_a_proj_weight, q_b_proj_weight, w_kc, q_a_layernorm_weight,
        kv_a_layernorm_weight, positions, cos_sin_cache, eps, use_int8_w8a8,
        use_fp8_w8a16, qkv_a_proj_scale, q_b_proj_scale, is_vnni, block_size,
        q_lora_rank, kv_lora_rank, qk_rope_head_dim,
    ):
        num_seqs = hidden_states.shape[0]
        num_heads = w_kc.shape[0]
        kv_lora_rank = w_kc.shape[1]
        weight_chunks = torch.split(
            q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0
        )
        qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank
        q_input = torch.empty(
            num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype, device=hidden_states.device,
        )
        k_input = torch.empty(
            num_seqs, 1, kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype, device=hidden_states.device,
        )
        v_input = k_input.narrow(-1, 0, kv_lora_rank)
        return q_input, k_input, v_input

    @torch.library.register_fake("sgl_kernel::weight_packed_linear")
    def _(x, weight, bias, is_vnni):
        return x.new_empty(x.shape[0], weight.shape[0])

    @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu")
    def _(input):
        M = input.shape[0]
        K = input.shape[1]
        Aq = input.new_empty(M, K, dtype=torch.int8)
        As = input.new_empty(M, dtype=torch.float32)
        return Aq, As

    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu")
    def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni):
        M = mat1.shape[0]
        N = mat2.shape[0]
        out = mat1.new_empty(M, N, dtype=out_dtype)
        return out

    @torch.library.register_fake("sgl_kernel::grouped_topk_cpu")
    def _(
        hidden_states, gating_output, topk, renormalize, num_expert_group,
        topk_group, num_fused_shared_experts, routed_scaling_factor,
        num_token_non_padded,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        device = hidden_states.device
        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
        return topk_weights, topk_ids

    @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu")
    def _(
        hidden_states, gating_output, correction_bias, topk, renormalize,
        num_expert_group, topk_group, num_fused_shared_experts,
        routed_scaling_factor, num_token_non_padded,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        device = hidden_states.device
        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
        return topk_weights, topk_ids

    @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu")
    def _(hidden_states, gating_output, topk, renormalize):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        return (
            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
        )

    @torch.library.register_fake("sgl_kernel::topk_softmax_cpu")
    def _(hidden_states, gating_output, topk, renormalize):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        return (
            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
        )

    @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu")
    def _(input):
        return input.new_empty(input.shape[0], input.shape[1] // 2)

    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant")
    def _(mat1, mat2, scales2, bias, out_dtype, is_vnni):
        M = mat1.shape[0]
        N = mat2.shape[0]
        return mat1.new_empty(M, N, dtype=out_dtype)

    @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu")
    def _(mat1, mat2, scales2, block_size, bias, out_dtype, is_vnni):
        M = mat1.shape[0]
        N = mat2.shape[0]
        return mat1.new_empty(M, N, dtype=out_dtype)


# TODO Remove unnecessary settings for CPUGraphRunner.
# Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic.
class CPUGraphRunner:
    """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1

        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        assert (
            not self.model_runner.server_args.enable_lora
        ), "CPUGraphRunner does not support LoRA yet."
        assert (
            not self.enable_two_batch_overlap
        ), "CPUGraphRunner does not support two batch overlap yet."
        assert (
            not self.require_mlp_tp_gather
        ), "CPUGraphRunner does not support MLP TP gather yet."
        assert not self.require_mlp_sync, "CPUGraphRunner does not support MLP sync yet."
        assert (
            not self.require_gathered_buffer
        ), "CPUGraphRunner does not support gathered buffer yet."
        assert (
            model_runner.spec_algorithm == SpeculativeAlgorithm.NONE
        ), "CPUGraphRunner does not support speculative inference yet."
        # TODO add compile support for encoder-decoder models
        assert (
            not self.is_encoder_decoder
        ), "CPUGraphRunner does not support encoder-decoder models yet."
        assert self.dp_size == 1, "CPUGraphRunner does not support DP yet."
        assert self.pp_size == 1, "CPUGraphRunner does not support PP yet."

        # Batch sizes to capture
        self.capture_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}")

        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_graph_seq_len_fill_value()
        )

        if self.enable_torch_compile:
            register_fake_ops()
            set_torch_compile_config()

        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64
            )
            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64)

            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )

        # Capture
        try:
            self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}"
            )

    def can_run(self, forward_batch: ForwardBatch):
        is_bs_supported = forward_batch.batch_size in self.graphs
        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        return is_bs_supported and capture_hidden_mode_matches

    def capture(self) -> None:
        capture_range = (
            tqdm.tqdm(list(reversed(self.capture_bs)))
            if get_tensor_model_parallel_rank() == 0
            else reversed(self.capture_bs)
        )
        for bs in capture_range:
            if get_tensor_model_parallel_rank() == 0:
                avail_mem = psutil.virtual_memory().available / (1 << 30)
                capture_range.set_description(
                    f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                )

            with patch_model(
                self.model_runner.model,
                bs in self.capture_bs,
                num_tokens=bs * self.num_tokens_per_bs,
                tp_group=self.model_runner.tp_group,
            ) as forward:
                (
                    graph,
                    output_buffers,
                ) = self.capture_one_batch_size(bs, forward)
                self.graphs[bs] = graph
                self.output_buffers[bs] = output_buffers

    def capture_one_batch_size(self, bs: int, forward: Callable):
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        mrope_positions = self.mrope_positions[:, :bs]
        self.num_token_non_padded[...] = num_tokens

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
        )

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
        # Do inference to avoid setting attr at runtime, e.g.,
        # self.attn_mha.kv_b_proj = self.kv_b_proj for full graph compile on CPU
        self.model_runner.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
            )
            return logits_output_or_pp_proxy_tensors

        with torch.no_grad():
            for _ in range(2):
                self.model_runner.tp_group.barrier()
                out = run_once()

        return forward, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):
        # If the required capture_hidden_mode changes, we need to recapture the graph
        # These are the different factors that can influence the capture_hidden_mode
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    # TODO add padding support for CPUGraphRunner
    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        assert (
            pp_proxy_tensors is None
        ), "PPProxyTensors is not supported in CPUGraphRunner yet."
        self.recapture_if_needed(forward_batch)
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
        output = self.graphs[forward_batch.batch_size](
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )
        return output

    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )
        return spec_info


CPU_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
)
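The key piece that makes the rest of this file work is `register_fake_ops`: every opaque `sgl_kernel` CPU operator is given a fake (meta) implementation that only reports output shapes and dtypes, which is what torch.compile needs to trace through a custom kernel. Below is a minimal, self-contained sketch of the same pattern with a hypothetical `toylib::scale` op (not part of `sgl_kernel`), assuming a PyTorch version that provides `torch.library.register_fake`:

```python
import torch

# A hypothetical custom op with a real CPU kernel...
torch.library.define("toylib::scale", "(Tensor x, float s) -> Tensor")


@torch.library.impl("toylib::scale", "cpu")
def _scale_cpu(x, s):
    return x * s


# ...and a fake/meta implementation: no math, it only describes the output
# shape and dtype, which is all torch.compile needs to trace the opaque op.
@torch.library.register_fake("toylib::scale")
def _(x, s):
    return torch.empty_like(x)


@torch.compile(dynamic=False)
def f(x):
    return torch.ops.toylib.scale(x, 2.0)


print(f(torch.ones(3)))
```

The fake never runs real math; it only tells the compiler what the op produces, so the compiled decode graph can include the AMX kernels without graph breaks.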
python/sglang/srt/model_executor/forward_batch_info.py

@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
python/sglang/srt/model_executor/model_runner.py

@@ -20,6 +20,7 @@ import json
 import logging
 import os
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -89,6 +90,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner

@@ -360,12 +362,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?

@@ -608,6 +610,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"

@@ -1619,30 +1626,39 @@ class ModelRunner:
     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return

         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
-        )
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
+        )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def init_threads_binding(self):

@@ -1787,18 +1803,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:

@@ -1833,7 +1855,7 @@ class ModelRunner:
         ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret, can_run_cuda_graph
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
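A small note on the dispatch added in `init_device_graphs`: `defaultdict` takes a default factory plus an initial mapping, so any device string other than "cpu" or "npu" falls back to `CudaGraphRunner`, and the factory ignores the lookup key. A toy stand-alone illustration with stub class names (not the sglang classes):

```python
from collections import defaultdict


class CudaGraphRunnerStub: ...
class CPUGraphRunnerStub: ...
class NPUGraphRunnerStub: ...


# The default factory ignores the key, so every unlisted device string maps
# to the CUDA runner; "cpu" and "npu" get their dedicated runners.
graph_runners = defaultdict(
    lambda: CudaGraphRunnerStub,
    {"cpu": CPUGraphRunnerStub, "npu": NPUGraphRunnerStub},
)
print(graph_runners["cpu"].__name__, graph_runners["cuda"].__name__)
```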
python/sglang/srt/utils.py

@@ -230,8 +230,16 @@ except:
     is_intel_amx_backend_available = False

+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+

 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp

@@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
   m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
   m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
-  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
   m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

   // topk

@@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // decode
   m.def(
-      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
       "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
       "float logit_cap) -> ()");
   m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

   // extend
   m.def(
-      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
       "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
       "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
   m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);

@@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

   // bmm
-  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
   m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

   // moe

@@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
   m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
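All five schema edits in this file add the `(a!)` alias annotation, which declares that the kernel mutates that tensor argument in place; without it, torch.compile's functionalization would treat the op as pure and could miss the side effect. A hedged, self-contained sketch of the same idea with a hypothetical op registered from Python (not part of `sgl_kernel`):

```python
import torch

# "(a!)" in the schema marks x as mutated in place, mirroring the
# shm_allreduce / bmm_cpu / attention schemas changed above.
torch.library.define("toylib::inplace_add_one", "(Tensor(a!) x) -> ()")


@torch.library.impl("toylib::inplace_add_one", "cpu")
def _impl(x):
    x.add_(1.0)


# Mutating ops that return nothing still need a fake so they can be traced.
@torch.library.register_fake("toylib::inplace_add_one")
def _(x):
    return None


@torch.compile(dynamic=False)
def bump(x):
    torch.ops.toylib.inplace_add_one(x)
    return x


print(bump(torch.zeros(2)))  # expected: tensor([1., 1.])
```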
test/srt/run_suite.py

@@ -276,6 +276,7 @@ suite_xeon = {
         TestFile("cpu/test_shared_expert.py"),
         TestFile("cpu/test_topk.py"),
         TestFile("test_intel_amx_attention_backend.py"),
+        TestFile("test_cpu_graph.py"),
     ],
 }
test/srt/test_cpu_graph.py  (new file, +87)

"""
Usage:
python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu
"""

import copy
import os
import unittest
from types import SimpleNamespace

from test_intel_amx_attention_backend import intel_amx_benchmark

from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


class TestCPUGraph(CustomTestCase):
    @intel_amx_benchmark(
        extra_args=[
            "--batch-size", "1",
            "--mem-fraction-static", "0.05",
            "--enable-torch-compile",
            "--torch-compile-max-bs", "1",
        ],
        min_throughput=10,
    )
    def test_latency_torch_compile_cpu(self):
        return DEFAULT_MLA_MODEL_NAME_FOR_TEST

    def test_mmlu_torch_compile_cpu(self):
        model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)

        env = copy.deepcopy(os.environ)
        env["SGLANG_CPU_OMP_THREADS_BIND"] = "all"

        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--attention-backend", "intel_amx",
                "--mem-fraction-static", "0.05",
                "--disable-radix",
                "--trust-remote-code",
                "--disable-overlap-schedule",
                "--enable-torch-compile",
                "--torch-compile-max-bs", "1",
                "--tp", f"{n_numa_node}",
            ],
            env=env,
        )

        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )
            metrics = run_eval(args)
            if is_in_ci():
                self.assertGreater(metrics["score"], 0.45)
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()
View file @
7577f0e4
...
@@ -3,7 +3,6 @@ Usage:
...
@@ -3,7 +3,6 @@ Usage:
python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
"""
"""
import
os
import
unittest
import
unittest
from
functools
import
wraps
from
functools
import
wraps
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
...
@@ -35,8 +34,6 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
...
@@ -35,8 +34,6 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
"intel_amx"
,
"intel_amx"
,
"--disable-radix"
,
"--disable-radix"
,
"--trust-remote-code"
,
"--trust-remote-code"
,
"--batch-size"
,
"4"
,
]
]
full_args
=
common_args
+
(
extra_args
or
[])
full_args
=
common_args
+
(
extra_args
or
[])
...
@@ -60,28 +57,33 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
...
@@ -60,28 +57,33 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
class
TestIntelAMXAttnBackend
(
CustomTestCase
):
class
TestIntelAMXAttnBackend
(
CustomTestCase
):
@
intel_amx_benchmark
(
min_throughput
=
10
)
@
intel_amx_benchmark
(
extra_args
=
[
"--batch-size"
,
"4"
],
min_throughput
=
10
)
def
test_latency_mla_model
(
self
):
def
test_latency_mla_model
(
self
):
return
DEFAULT_MLA_MODEL_NAME_FOR_TEST
return
DEFAULT_MLA_MODEL_NAME_FOR_TEST
@
intel_amx_benchmark
(
min_throughput
=
40
)
@
intel_amx_benchmark
(
extra_args
=
[
"--batch-size"
,
"4"
],
min_throughput
=
40
)
def
test_latency_default_model
(
self
):
def
test_latency_default_model
(
self
):
return
DEFAULT_MODEL_NAME_FOR_TEST
return
DEFAULT_MODEL_NAME_FOR_TEST
@
intel_amx_benchmark
(
min_throughput
=
150
)
@
intel_amx_benchmark
(
extra_args
=
[
"--batch-size"
,
"4"
],
min_throughput
=
150
)
def
test_latency_fp8_qwen
(
self
):
def
test_latency_fp8_qwen
(
self
):
return
DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8
return
DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8
@
intel_amx_benchmark
(
min_throughput
=
50
)
@
intel_amx_benchmark
(
extra_args
=
[
"--batch-size"
,
"4"
],
min_throughput
=
50
)
def
test_latency_fp8_moe_model
(
self
):
def
test_latency_fp8_moe_model
(
self
):
return
DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE
return
DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE
@
intel_amx_benchmark
(
extra_args
=
[
"--quantization"
,
"w8a8_int8"
],
min_throughput
=
100
)
@
intel_amx_benchmark
(
extra_args
=
[
"--batch-size"
,
"4"
,
"--quantization"
,
"w8a8_int8"
],
min_throughput
=
100
,
)
def
test_latency_w8a8_default_model
(
self
):
def
test_latency_w8a8_default_model
(
self
):
return
DEFAULT_MODEL_NAME_FOR_TEST_W8A8
return
DEFAULT_MODEL_NAME_FOR_TEST_W8A8
@
intel_amx_benchmark
(
@
intel_amx_benchmark
(
extra_args
=
[
extra_args
=
[
"--batch-size"
,
"4"
,
"--quantization"
,
"--quantization"
,
"w8a8_int8"
,
"w8a8_int8"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
...
...