change / sglang · Commits · 7577f0e4

Commit 7577f0e4 (unverified), authored Sep 08, 2025 by Cao E, committed by GitHub Sep 07, 2025
Parent: 8cda5a62

Add graph runner support with torch compile on CPU (#7843)

Showing 16 changed files with 820 additions and 48 deletions (+820 -48)
.github/workflows/pr-test-xeon.yml                          +1   -1
docs/platforms/cpu_server.md                                +6   -1
python/sglang/srt/distributed/parallel_state.py             +4   -3
python/sglang/srt/layers/attention/intel_amx_backend.py     +3   -0
python/sglang/srt/layers/quantization/fp8.py                +3   -0
python/sglang/srt/layers/quantization/w8a8_int8.py          +5   -7
python/sglang/srt/managers/scheduler.py                     +4   -5
python/sglang/srt/managers/scheduler_metrics_mixin.py       +1   -1
python/sglang/srt/model_executor/cpu_graph_runner.py        +640 -0
python/sglang/srt/model_executor/forward_batch_info.py      +3   -0
python/sglang/srt/model_executor/model_runner.py            +38  -16
python/sglang/srt/utils.py                                  +9   -1
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp                 +5   -5
test/srt/run_suite.py                                       +1   -0
test/srt/test_cpu_graph.py                                  +87  -0
test/srt/test_intel_amx_attention_backend.py                +10  -8
.github/workflows/pr-test-xeon.yml

@@ -70,7 +70,7 @@ jobs:
       - name: Run unit tests
         if: steps.check_amx.outcome == 'success'
-        timeout-minutes: 30
+        timeout-minutes: 36
         run: |
           docker exec -w /sglang-checkout/ ci_sglang_xeon \
             bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu"
docs/platforms/cpu_server.md

@@ -134,7 +134,12 @@ Notes:
    export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
    ```
-3. A warmup step is automatically triggered when the service is started.
+3. For optimizing decoding with torch.compile, please add the flag `--enable-torch-compile`.
+   To specify the maximum batch size when using torch compile, set the flag `--torch-compile-max-bs`.
+   For example, `--enable-torch-compile --torch-compile-max-bs 4` means using torch compile and setting the
+   maximum batch size to 4.
+
+4. A warmup step is automatically triggered when the service is started.
    The server is ready when you see the log `The server is fired up and ready to roll!`.

 ## Benchmarking with Requests
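For context, here is a hedged sketch of how the newly documented flags might combine with a typical CPU server launch. The model path below is a placeholder, and the surrounding arguments follow the conventions of docs/platforms/cpu_server.md rather than anything added in this diff:

```python
# Launch the SGLang server on CPU with torch.compile enabled for decoding.
# Placeholder model path; other flags follow the CPU server docs, not this commit.
import subprocess

subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--device", "cpu",
        "--attention-backend", "intel_amx",
        "--enable-torch-compile",        # compile the decode forward pass
        "--torch-compile-max-bs", "4",   # compile for batch sizes up to 4
    ],
    check=True,
)
```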
python/sglang/srt/distributed/parallel_state.py

@@ -64,6 +64,9 @@ class GraphCaptureContext:
 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]
@@ -489,9 +492,7 @@ class GroupCoordinator:
         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
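A brief illustration of the constant introduced above: the shm_allreduce op takes a plain `int reduce_op`, and torch.compile can bake an integer constant into the traced call, whereas `ReduceOp.SUM` is an opaque enum-like object. The conversion is therefore done once at module level. The snippet below is a standalone check, not sglang code:

```python
# The enum member is not a plain Python int, but converts to a stable integer
# code that can be passed as a trace-friendly constant.
import torch.distributed as dist

REDUCE_OP_SUM = int(dist.ReduceOp.SUM)

print(dist.ReduceOp.SUM, type(dist.ReduceOp.SUM))  # enum-like pybind object
print(REDUCE_OP_SUM, type(REDUCE_OP_SUM))          # e.g. 0 <class 'int'>
```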
python/sglang/srt/layers/attention/intel_amx_backend.py

@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
             self.forward_metadata = (attn_logits, max_extend_len)

+    def get_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
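The new `get_graph_seq_len_fill_value` hook mirrors the one exposed by the CUDA-graph attention backends: when a decode batch is padded up to a captured batch size, the runner needs a placeholder sequence length for the inactive slots, and this backend uses 1. A small sketch of that padding idea (assumed semantics, not the CPUGraphRunner code):

```python
# Minimal sketch: pad a decode batch's seq_lens up to the captured batch size,
# filling unused slots with the backend-provided fill value.
import torch

def pad_seq_lens(seq_lens: torch.Tensor, capture_bs: int, fill_value: int) -> torch.Tensor:
    padded = torch.full((capture_bs,), fill_value, dtype=seq_lens.dtype)
    padded[: seq_lens.numel()] = seq_lens
    return padded

# IntelAMXAttnBackend.get_graph_seq_len_fill_value() returns 1, so padded
# (inactive) requests look like length-1 sequences during replay.
print(pad_seq_lens(torch.tensor([17, 5, 42]), capture_bs=8, fill_value=1))
```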
python/sglang/srt/layers/quantization/fp8.py

@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                     _is_cpu_amx_available
                 ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                 _amx_process_weight_after_loading(layer, ["weight"])
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -343,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
+        else:
             layer.weight = Parameter(layer.weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

     def create_weights(
@@ -486,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
+        else:
             layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
             layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
             layer.w13_weight_scale = Parameter(
                 layer.w13_weight_scale.data, requires_grad=False
             )
python/sglang/srt/managers/scheduler.py

@@ -414,7 +414,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -2252,10 +2252,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-        if not _is_cpu:
-            ret["memory_usage"]["cuda_graph"] = round(
-                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-            )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
python/sglang/srt/managers/scheduler_metrics_mixin.py

@@ -214,7 +214,7 @@ class SchedulerMetricsMixin:
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
python/sglang/srt/model_executor/cpu_graph_runner.py (new file, 0 → 100644, +640 lines; diff collapsed)
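The new runner's diff is collapsed above. As a rough, assumption-labeled sketch of the general pattern a torch.compile-based runner follows (compile the decode forward once per captured batch size, pad smaller batches up to the nearest captured size, and slice the result), not a reproduction of `cpu_graph_runner.py`:

```python
# Toy sketch of a compile-and-cache-per-batch-size runner; names and structure
# are illustrative only and do not mirror the sglang CPUGraphRunner API.
import torch

class ToyCPUGraphRunner:
    def __init__(self, model, capture_bs=(1, 2, 4)):
        self.capture_bs = sorted(capture_bs)
        # One compiled callable per captured batch size (static shapes).
        self.compiled = {bs: torch.compile(model, dynamic=False) for bs in self.capture_bs}

    def can_run(self, batch_size: int) -> bool:
        return batch_size <= self.capture_bs[-1]

    def replay(self, x: torch.Tensor) -> torch.Tensor:
        bs = x.shape[0]
        target = next(b for b in self.capture_bs if b >= bs)
        padded = torch.cat([x, x.new_zeros(target - bs, *x.shape[1:])])
        return self.compiled[target](padded)[:bs]

runner = ToyCPUGraphRunner(torch.nn.Linear(8, 8))
print(runner.replay(torch.randn(3, 8)).shape)  # torch.Size([3, 8])
```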
python/sglang/srt/model_executor/forward_batch_info.py

@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
python/sglang/srt/model_executor/model_runner.py

@@ -20,6 +20,7 @@ import json
 import logging
 import os
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -89,6 +90,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
@@ -360,12 +362,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -608,6 +610,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -1619,30 +1626,39 @@ class ModelRunner:
         )

     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return

         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
-        )
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
+        )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def init_threads_binding(self):
@@ -1787,18 +1803,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1833,7 +1855,7 @@ class ModelRunner:
         ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret, can_run_cuda_graph
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
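The `register_fake` hook added in the hunk above gives torch.compile a shape-only ("fake") implementation of the shared-memory all-gather, so the custom op can be traced without executing the real kernel. A hedged, self-contained illustration of the same mechanism with a made-up demo op (the real registration targets `sgl_kernel::shm_allgather` and uses the runner's `tp_size`):

```python
# Define a toy custom op, give it a real CPU impl and a fake (meta) impl, then
# trace it with torch.compile. The fake impl only needs to produce the right
# shape/dtype; here tp_size is a hypothetical stand-in for tensor parallelism.
import torch

tp_size = 4
torch.library.define("demo::shm_allgather", "(Tensor data, int dim) -> Tensor")

@torch.library.impl("demo::shm_allgather", "cpu")
def _(data, dim):
    return torch.cat([data] * tp_size, dim=dim)  # single-process stand-in

@torch.library.register_fake("demo::shm_allgather")
def _(data, dim):
    return torch.cat([data] * tp_size, dim=dim)  # runs on fake tensors, shapes only

@torch.compile
def gather(x):
    return torch.ops.demo.shm_allgather(x, 0)

print(gather(torch.ones(2, 3)).shape)  # torch.Size([8, 3])
```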
python/sglang/srt/utils.py

@@ -230,8 +230,16 @@ except:
     is_intel_amx_backend_available = False

+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+

 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
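The probe is hoisted to import time so that, inside compiled code, `cpu_has_amx_support()` reduces to a check of two module-level booleans instead of a call into the private `torch._C` binding during tracing, which is the reason the in-diff comment gives. A hedged illustration of the same pattern with a stand-in flag:

```python
# Module-level capability probe: evaluated once at import, then treated as a
# plain Python constant when torch.compile traces functions that branch on it.
import torch

try:
    amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
except Exception:
    amx_tile_supported = False

@torch.compile
def maybe_fast_path(x):
    if amx_tile_supported:  # resolved during tracing; no runtime branch in the graph
        return torch.relu(x)
    return x - x.min()

print(maybe_fast_path(torch.tensor([-1.0, 2.0])))
```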
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp

@@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
   m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
   m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
-  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
   m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

   // topk
@@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // decode
   m.def(
-      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
       "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
      "float logit_cap) -> ()");
   m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

   // extend
   m.def(
-      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
       "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
       "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
   m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
@@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

   // bmm
-  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
   m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

   // moe
@@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
   m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
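The `Tensor(a!)` annotations added to these schemas declare which argument each kernel mutates in place; torch.compile relies on that alias information (together with a fake implementation) to functionalize calls to mutating custom ops instead of assuming they are pure. A hedged, self-contained Python illustration of the same schema convention with a made-up op (not part of sgl-kernel):

```python
# A mutable custom op whose schema carries the (a!) alias annotation, plus a
# fake impl, so torch.compile can trace and functionalize the in-place call.
import torch

torch.library.define("demo::inplace_scale", "(Tensor(a!) data, float factor) -> ()")

@torch.library.impl("demo::inplace_scale", "cpu")
def _(data, factor):
    data.mul_(factor)  # mutates `data`, exactly as the (a!) annotation promises

@torch.library.register_fake("demo::inplace_scale")
def _(data, factor):
    return None        # returns nothing; the mutation is described by the schema

@torch.compile
def scale(x):
    torch.ops.demo.inplace_scale(x, 2.0)
    return x

print(scale(torch.ones(3)))  # tensor([2., 2., 2.])
```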
test/srt/run_suite.py

@@ -276,6 +276,7 @@ suite_xeon = {
         TestFile("cpu/test_shared_expert.py"),
         TestFile("cpu/test_topk.py"),
         TestFile("test_intel_amx_attention_backend.py"),
+        TestFile("test_cpu_graph.py"),
     ],
 }
test/srt/test_cpu_graph.py (new file, 0 → 100644)

"""
Usage:
python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu
"""

import copy
import os
import unittest
from types import SimpleNamespace

from test_intel_amx_attention_backend import intel_amx_benchmark

from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


class TestCPUGraph(CustomTestCase):
    @intel_amx_benchmark(
        extra_args=[
            "--batch-size",
            "1",
            "--mem-fraction-static",
            "0.05",
            "--enable-torch-compile",
            "--torch-compile-max-bs",
            "1",
        ],
        min_throughput=10,
    )
    def test_latency_torch_compile_cpu(self):
        return DEFAULT_MLA_MODEL_NAME_FOR_TEST

    def test_mmlu_torch_compile_cpu(self):
        model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)

        env = copy.deepcopy(os.environ)
        env["SGLANG_CPU_OMP_THREADS_BIND"] = "all"

        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--attention-backend",
                "intel_amx",
                "--mem-fraction-static",
                "0.05",
                "--disable-radix",
                "--trust-remote-code",
                "--disable-overlap-schedule",
                "--enable-torch-compile",
                "--torch-compile-max-bs",
                "1",
                "--tp",
                f"{n_numa_node}",
            ],
            env=env,
        )

        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )
            metrics = run_eval(args)
            if is_in_ci():
                self.assertGreater(metrics["score"], 0.45)
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()
test/srt/test_intel_amx_attention_backend.py

@@ -3,7 +3,6 @@ Usage:
 python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
 """

-import os
 import unittest
 from functools import wraps
 from types import SimpleNamespace
@@ -35,8 +34,6 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
             "intel_amx",
             "--disable-radix",
             "--trust-remote-code",
-            "--batch-size",
-            "4",
         ]
         full_args = common_args + (extra_args or [])
@@ -60,28 +57,33 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
 class TestIntelAMXAttnBackend(CustomTestCase):
-    @intel_amx_benchmark(min_throughput=10)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=10)
     def test_latency_mla_model(self):
         return DEFAULT_MLA_MODEL_NAME_FOR_TEST

-    @intel_amx_benchmark(min_throughput=40)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=40)
     def test_latency_default_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST

-    @intel_amx_benchmark(min_throughput=150)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=150)
     def test_latency_fp8_qwen(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8

-    @intel_amx_benchmark(min_throughput=50)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=50)
     def test_latency_fp8_moe_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE

-    @intel_amx_benchmark(extra_args=["--quantization", "w8a8_int8"], min_throughput=100)
+    @intel_amx_benchmark(
+        extra_args=["--batch-size", "4", "--quantization", "w8a8_int8"],
+        min_throughput=100,
+    )
     def test_latency_w8a8_default_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_W8A8

     @intel_amx_benchmark(
         extra_args=[
+            "--batch-size",
+            "4",
             "--quantization",
             "w8a8_int8",
             "--mem-fraction-static",
...