Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64587211
Unverified
Commit
64587211
authored
Jun 13, 2025
by
Li, Jiang
Committed by
GitHub
Jun 13, 2025
Browse files
[CPU] Refine default config for the CPU backend (#19539)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
bb4a0dec
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
60 additions
and
25 deletions
+60
-25
.buildkite/scripts/hardware_ci/run-cpu-test.sh
.buildkite/scripts/hardware_ci/run-cpu-test.sh
+12
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+22
-4
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+14
-6
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+12
-12
No files found.
.buildkite/scripts/hardware_ci/run-cpu-test.sh
View file @
64587211
...
@@ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
...
@@ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl
-C
"
$CORE_RANGE
"
-N
"
$NUMA_NODE
"
docker build
--build-arg
VLLM_CPU_DISABLE_AVX512
=
"true"
--tag
cpu-test-
"
$NUMA_NODE
"
-avx2
--target
vllm-test
-f
docker/Dockerfile.cpu
.
numactl
-C
"
$CORE_RANGE
"
-N
"
$NUMA_NODE
"
docker build
--build-arg
VLLM_CPU_DISABLE_AVX512
=
"true"
--tag
cpu-test-
"
$NUMA_NODE
"
-avx2
--target
vllm-test
-f
docker/Dockerfile.cpu
.
# Run the image, setting --shm-size=4g for tensor parallel.
# Run the image, setting --shm-size=4g for tensor parallel.
docker run
-itd
--cpuset-cpus
=
"
$CORE_RANGE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--env
VLLM_CPU_OMP_THREADS_BIND
=
"
$OMP_CORE_RANGE
"
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
cpu-test-
"
$NUMA_NODE
"
docker run
-itd
--cpuset-cpus
=
"
$CORE_RANGE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--env
VLLM_CPU_OMP_THREADS_BIND
=
"
$OMP_CORE_RANGE
"
--env
VLLM_CPU_CI_ENV
=
1
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
cpu-test-
"
$NUMA_NODE
"
docker run
-itd
--cpuset-cpus
=
"
$CORE_RANGE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--env
VLLM_CPU_OMP_THREADS_BIND
=
"
$OMP_CORE_RANGE
"
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
-avx2
cpu-test-
"
$NUMA_NODE
"
-avx2
docker run
-itd
--cpuset-cpus
=
"
$CORE_RANGE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--env
VLLM_CPU_OMP_THREADS_BIND
=
"
$OMP_CORE_RANGE
"
--env
VLLM_CPU_CI_ENV
=
1
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
-avx2
cpu-test-
"
$NUMA_NODE
"
-avx2
function
cpu_tests
()
{
function
cpu_tests
()
{
set
-e
set
-e
export
NUMA_NODE
=
$2
export
NUMA_NODE
=
$2
# list packages
docker
exec
cpu-test-
"
$NUMA_NODE
"
-avx2
bash
-c
"
set -e
pip list"
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
pip list"
# offline inference
# offline inference
docker
exec
cpu-test-
"
$NUMA_NODE
"
-avx2
bash
-c
"
docker
exec
cpu-test-
"
$NUMA_NODE
"
-avx2
bash
-c
"
set -e
set -e
...
@@ -72,7 +81,7 @@ function cpu_tests() {
...
@@ -72,7 +81,7 @@ function cpu_tests() {
set -e
set -e
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py
\
VLLM_CPU_CI_ENV=0
python3 benchmarks/benchmark_serving.py
\
--backend vllm
\
--backend vllm
\
--dataset-name random
\
--dataset-name random
\
--model facebook/opt-125m
\
--model facebook/opt-125m
\
...
...
vllm/engine/arg_utils.py
View file @
64587211
...
@@ -1562,14 +1562,20 @@ class EngineArgs:
...
@@ -1562,14 +1562,20 @@ class EngineArgs:
UsageContext
.
LLM_CLASS
:
16384
,
UsageContext
.
LLM_CLASS
:
16384
,
UsageContext
.
OPENAI_API_SERVER
:
8192
,
UsageContext
.
OPENAI_API_SERVER
:
8192
,
}
}
default_max_num_seqs
=
1024
default_max_num_seqs
=
{
UsageContext
.
LLM_CLASS
:
1024
,
UsageContext
.
OPENAI_API_SERVER
:
1024
,
}
else
:
else
:
# TODO(woosuk): Tune the default values for other hardware.
# TODO(woosuk): Tune the default values for other hardware.
default_max_num_batched_tokens
=
{
default_max_num_batched_tokens
=
{
UsageContext
.
LLM_CLASS
:
8192
,
UsageContext
.
LLM_CLASS
:
8192
,
UsageContext
.
OPENAI_API_SERVER
:
2048
,
UsageContext
.
OPENAI_API_SERVER
:
2048
,
}
}
default_max_num_seqs
=
256
default_max_num_seqs
=
{
UsageContext
.
LLM_CLASS
:
256
,
UsageContext
.
OPENAI_API_SERVER
:
256
,
}
# tpu specific default values.
# tpu specific default values.
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
...
@@ -1586,6 +1592,17 @@ class EngineArgs:
...
@@ -1586,6 +1592,17 @@ class EngineArgs:
}
}
}
}
# cpu specific default values.
if
current_platform
.
is_cpu
():
default_max_num_batched_tokens
=
{
UsageContext
.
LLM_CLASS
:
4096
,
UsageContext
.
OPENAI_API_SERVER
:
2048
,
}
default_max_num_seqs
=
{
UsageContext
.
LLM_CLASS
:
128
,
UsageContext
.
OPENAI_API_SERVER
:
32
,
}
use_context_value
=
usage_context
.
value
if
usage_context
else
None
use_context_value
=
usage_context
.
value
if
usage_context
else
None
if
(
self
.
max_num_batched_tokens
is
None
if
(
self
.
max_num_batched_tokens
is
None
and
usage_context
in
default_max_num_batched_tokens
):
and
usage_context
in
default_max_num_batched_tokens
):
...
@@ -1606,8 +1623,9 @@ class EngineArgs:
...
@@ -1606,8 +1623,9 @@ class EngineArgs:
"Setting max_num_batched_tokens to %d for %s usage context."
,
"Setting max_num_batched_tokens to %d for %s usage context."
,
self
.
max_num_batched_tokens
,
use_context_value
)
self
.
max_num_batched_tokens
,
use_context_value
)
if
self
.
max_num_seqs
is
None
:
if
(
self
.
max_num_seqs
is
None
self
.
max_num_seqs
=
default_max_num_seqs
and
usage_context
in
default_max_num_seqs
):
self
.
max_num_seqs
=
default_max_num_seqs
[
usage_context
]
logger
.
debug
(
"Setting max_num_seqs to %d for %s usage context."
,
logger
.
debug
(
"Setting max_num_seqs to %d for %s usage context."
,
self
.
max_num_seqs
,
use_context_value
)
self
.
max_num_seqs
,
use_context_value
)
...
...
vllm/platforms/cpu.py
View file @
64587211
...
@@ -89,10 +89,6 @@ class CpuPlatform(Platform):
...
@@ -89,10 +89,6 @@ class CpuPlatform(Platform):
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.utils
import
GiB_bytes
from
vllm.utils
import
GiB_bytes
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
# Reminder: Please update docs/features/compatibility_matrix.md
# if the feature combination becomes valid.
if
not
model_config
.
enforce_eager
:
model_config
.
enforce_eager
=
True
model_config
.
disable_cascade_attn
=
True
model_config
.
disable_cascade_attn
=
True
...
@@ -171,9 +167,21 @@ class CpuPlatform(Platform):
...
@@ -171,9 +167,21 @@ class CpuPlatform(Platform):
compilation_config
=
vllm_config
.
compilation_config
compilation_config
=
vllm_config
.
compilation_config
if
(
envs
.
VLLM_USE_V1
and
vllm_config
.
compilation_config
.
level
if
(
envs
.
VLLM_USE_V1
and
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
):
==
CompilationLevel
.
PIECEWISE
):
# Note: vLLM V1 uses PIECEWISE level compilation, which spends time
# compiling kernels just-in-time with the inductor backend. Most CPU
# CI tests execute quickly, so compilation consumes a disproportionate
# share of their time, even with the torch compile cache. VLLM_CPU_CI_ENV
# therefore indicates the CI environment, in which the model is executed
# with dynamo + eager mode to save time.
# VLLM_CPU_CI_ENV is only used as an internal variable.
if
os
.
environ
.
get
(
"VLLM_CPU_CI_ENV"
,
"0"
)
!=
"0"
:
backend
=
"eager"
else
:
backend
=
"inductor"
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
compilation_config
.
backend
=
"eager"
compilation_config
.
backend
=
backend
compilation_config
.
custom_ops
+=
[
"none"
]
compilation_config
.
inductor_compile_config
.
update
({
compilation_config
.
inductor_compile_config
.
update
({
"dce"
:
"dce"
:
True
,
True
,
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
64587211
...
@@ -60,7 +60,8 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -60,7 +60,8 @@ class CPUModelRunner(GPUModelRunner):
def
warming_up_model
(
self
)
->
None
:
def
warming_up_model
(
self
)
->
None
:
logger
.
info
(
"Warming up model for the compilation..."
)
logger
.
info
(
"Warming up model for the compilation..."
)
# Only generate graph for the generic shape
# Only generate graph for the generic shape
self
.
_dummy_run
(
max
(
16
,
self
.
max_num_reqs
))
with
_set_global_compilation_settings
(
self
.
vllm_config
):
self
.
_dummy_run
(
max
(
16
,
self
.
max_num_reqs
))
logger
.
info
(
"Warming up done."
)
logger
.
info
(
"Warming up done."
)
def
_init_device_properties
(
self
)
->
None
:
def
_init_device_properties
(
self
)
->
None
:
...
@@ -71,16 +72,15 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -71,16 +72,15 @@ class CPUModelRunner(GPUModelRunner):
@
contextmanager
@
contextmanager
def
_set_global_compilation_settings
():
def
_set_global_compilation_settings
(
config
:
VllmConfig
):
import
torch._inductor.config
import
torch._inductor.config
# Note: The CPPGEMM backend requires freezing parameters.
inductor_config
=
config
.
compilation_config
.
inductor_compile_config
freezing_value
=
torch
.
_inductor
.
config
.
freezing
try
:
torch
.
_inductor
.
config
.
freezing
=
True
# Note: The MKLDNN and CPPGEMM backends require freezing parameters.
# Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
freezing_value
=
torch
.
_inductor
.
config
.
freezing
# including object type dict"
if
inductor_config
.
get
(
"max_autotune"
,
False
):
force_disable_caches
=
torch
.
_inductor
.
config
.
force_disable_caches
torch
.
_inductor
.
config
.
freezing
=
True
torch
.
_inductor
.
config
.
force_disable_caches
=
True
yield
yield
finally
:
torch
.
_inductor
.
config
.
freezing
=
freezing_value
torch
.
_inductor
.
config
.
freezing
=
freezing_value
torch
.
_inductor
.
config
.
force_disable_caches
=
force_disable_caches
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment