Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4555143e
Unverified
Commit
4555143e
authored
Jun 04, 2025
by
Li, Jiang
Committed by
GitHub
Jun 03, 2025
Browse files
[CPU] V1 support for the CPU backend (#16441)
parent
52dceb17
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
465 additions
and
40 deletions
+465
-40
.buildkite/scripts/hardware_ci/run-cpu-test.sh
.buildkite/scripts/hardware_ci/run-cpu-test.sh
+5
-8
docs/usage/v1_guide.md
docs/usage/v1_guide.md
+2
-0
requirements/cpu.txt
requirements/cpu.txt
+3
-0
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+4
-1
tests/models/language/generation/test_common.py
tests/models/language/generation/test_common.py
+0
-1
vllm/attention/backends/cpu_mla.py
vllm/attention/backends/cpu_mla.py
+3
-3
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+12
-4
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+6
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-1
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+57
-10
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+163
-0
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+86
-0
vllm/v1/worker/cpu_worker.py
vllm/v1/worker/cpu_worker.py
+101
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+18
-10
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-1
No files found.
.buildkite/scripts/hardware_ci/run-cpu-test.sh
View file @
4555143e
...
@@ -6,6 +6,7 @@ set -ex
...
@@ -6,6 +6,7 @@ set -ex
# allow to bind to different cores
# allow to bind to different cores
CORE_RANGE
=
${
CORE_RANGE
:-
48
-95
}
CORE_RANGE
=
${
CORE_RANGE
:-
48
-95
}
OMP_CORE_RANGE
=
${
OMP_CORE_RANGE
:-
48
-95
}
NUMA_NODE
=
${
NUMA_NODE
:-
1
}
NUMA_NODE
=
${
NUMA_NODE
:-
1
}
export
CMAKE_BUILD_PARALLEL_LEVEL
=
32
export
CMAKE_BUILD_PARALLEL_LEVEL
=
32
...
@@ -23,10 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
...
@@ -23,10 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl
-C
"
$CORE_RANGE
"
-N
"
$NUMA_NODE
"
docker build
--build-arg
VLLM_CPU_DISABLE_AVX512
=
"true"
--tag
cpu-test-
"
$NUMA_NODE
"
-avx2
--target
vllm-test
-f
docker/Dockerfile.cpu
.
numactl
-C
"
$CORE_RANGE
"
-N
"
$NUMA_NODE
"
docker build
--build-arg
VLLM_CPU_DISABLE_AVX512
=
"true"
--tag
cpu-test-
"
$NUMA_NODE
"
-avx2
--target
vllm-test
-f
docker/Dockerfile.cpu
.
# Run the image, setting --shm-size=4g for tensor parallel.
# Run the image, setting --shm-size=4g for tensor parallel.
docker run
-itd
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--cpuset-cpus
=
"
$CORE_RANGE
"
\
docker run
-itd
--cpuset-cpus
=
"
$CORE_RANGE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--env
VLLM_CPU_OMP_THREADS_BIND
=
"
$OMP_CORE_RANGE
"
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
cpu-test-
"
$NUMA_NODE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
cpu-test-
"
$NUMA_NODE
"
docker run
-itd
--cpuset-cpus
=
"
$CORE_RANGE
"
--cpuset-mems
=
"
$NUMA_NODE
"
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--env
VLLM_CPU_OMP_THREADS_BIND
=
"
$OMP_CORE_RANGE
"
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
-avx2
cpu-test-
"
$NUMA_NODE
"
-avx2
docker run
-itd
--entrypoint
/bin/bash
-v
~/.cache/huggingface:/root/.cache/huggingface
--cpuset-cpus
=
"
$CORE_RANGE
"
\
--cpuset-mems
=
"
$NUMA_NODE
"
--privileged
=
true
-e
HF_TOKEN
--env
VLLM_CPU_KVCACHE_SPACE
=
4
--shm-size
=
4g
--name
cpu-test-
"
$NUMA_NODE
"
-avx2
cpu-test-
"
$NUMA_NODE
"
-avx2
function
cpu_tests
()
{
function
cpu_tests
()
{
set
-e
set
-e
...
@@ -56,7 +55,7 @@ function cpu_tests() {
...
@@ -56,7 +55,7 @@ function cpu_tests() {
# Run AWQ test
# Run AWQ test
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
set -e
pytest -s -v
\
VLLM_USE_V1=0
pytest -s -v
\
tests/quantization/test_ipex_quant.py"
tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
# Run chunked-prefill and prefix-cache test
...
@@ -68,8 +67,6 @@ function cpu_tests() {
...
@@ -68,8 +67,6 @@ function cpu_tests() {
# online serving
# online serving
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
set -e
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=
$1
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py
\
python3 benchmarks/benchmark_serving.py
\
...
@@ -89,4 +86,4 @@ function cpu_tests() {
...
@@ -89,4 +86,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 40 mins.
# All of CPU tests are expected to be finished less than 40 mins.
export
-f
cpu_tests
export
-f
cpu_tests
timeout
40m
bash
-c
"cpu_tests
$CORE_RANGE
$NUMA_NODE
"
timeout
1h
bash
-c
"cpu_tests
$CORE_RANGE
$NUMA_NODE
"
docs/usage/v1_guide.md
View file @
4555143e
...
@@ -40,6 +40,8 @@ This living user guide outlines a few known **important changes and limitations*
...
@@ -40,6 +40,8 @@ This living user guide outlines a few known **important changes and limitations*
|
**NVIDIA**
|
<nobr>
🚀 Natively Supported
</nobr>
|
|
**NVIDIA**
|
<nobr>
🚀 Natively Supported
</nobr>
|
|
**AMD**
|
<nobr>
🚧 WIP
</nobr>
|
|
**AMD**
|
<nobr>
🚧 WIP
</nobr>
|
|
**TPU**
|
<nobr>
🚧 WIP
</nobr>
|
|
**TPU**
|
<nobr>
🚧 WIP
</nobr>
|
|
**CPU**
|
<nobr>
🚧 WIP
</nobr>
|
#### Feature / Model
#### Feature / Model
| Feature / Model | Status |
| Feature / Model | Status |
...
...
requirements/cpu.txt
View file @
4555143e
# Common dependencies
# Common dependencies
-r common.txt
-r common.txt
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
numba == 0.61.2; python_version > '3.9'
# Dependencies for CPUs
# Dependencies for CPUs
packaging>=24.2
packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools>=77.0.3,<80.0.0
...
...
tests/kernels/attention/test_attention_selector.py
View file @
4555143e
...
@@ -85,7 +85,10 @@ def test_env(
...
@@ -85,7 +85,10 @@ def test_env(
CpuPlatform
()):
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
block_size
,
False
)
block_size
,
False
)
assert
backend
.
get_name
()
==
"TORCH_SDPA"
if
use_v1
:
assert
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
else
:
assert
backend
.
get_name
()
==
"TORCH_SDPA"
elif
device
==
"hip"
:
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.selector.current_platform"
,
...
...
tests/models/language/generation/test_common.py
View file @
4555143e
...
@@ -87,7 +87,6 @@ AITER_MODEL_LIST = [
...
@@ -87,7 +87,6 @@ AITER_MODEL_LIST = [
pytest
.
param
(
"bigcode/starcoder2-3b"
),
# starcoder2
pytest
.
param
(
"bigcode/starcoder2-3b"
),
# starcoder2
pytest
.
param
(
pytest
.
param
(
"TitanML/tiny-mixtral"
,
# mixtral
"TitanML/tiny-mixtral"
,
# mixtral
marks
=
[
pytest
.
mark
.
cpu_model
],
)
)
])
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
...
...
vllm/attention/backends/cpu_mla.py
View file @
4555143e
...
@@ -178,7 +178,7 @@ class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
...
@@ -178,7 +178,7 @@ class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_kv_len
=
max_kv_len
,
max_kv_len
=
max_kv_len
,
query_start_loc
=
query_start_loc
,
prefill_
query_start_loc
=
query_start_loc
,
kv_start_loc
=
kv_start_loc
,
kv_start_loc
=
kv_start_loc
,
max_decode_seq_len
=
input_data
.
max_decode_seq_len
,
max_decode_seq_len
=
input_data
.
max_decode_seq_len
,
num_prefills
=
input_data
.
num_prefills
,
num_prefills
=
input_data
.
num_prefills
,
...
@@ -264,8 +264,8 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
...
@@ -264,8 +264,8 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
key
=
k
,
key
=
k
,
value
=
v_padded
,
value
=
v_padded
,
out
=
output
,
out
=
output
,
seqlen_q
=
prefill_metadata
.
query_start_loc
,
seqlen_q
=
prefill_metadata
.
prefill_
query_start_loc
,
seqlen_k
=
prefill_metadata
.
query_start_loc
,
seqlen_k
=
prefill_metadata
.
prefill_
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
max_query_len
,
pdropout
=
0.0
,
pdropout
=
0.0
,
...
...
vllm/attention/backends/torch_sdpa.py
View file @
4555143e
...
@@ -87,10 +87,13 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -87,10 +87,13 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# For chunked prefill only
# For chunked prefill only
max_query_len
:
Optional
[
int
]
=
None
max_query_len
:
Optional
[
int
]
=
None
max_kv_len
:
Optional
[
int
]
=
None
max_kv_len
:
Optional
[
int
]
=
None
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
kv_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
kv_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
prefill_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# For V1 logits index only
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
...
@@ -375,7 +378,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
...
@@ -375,7 +378,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_kv_len
=
max_kv_len
,
max_kv_len
=
max_kv_len
,
query_start_loc
=
query_start_loc
,
prefill_
query_start_loc
=
query_start_loc
,
kv_start_loc
=
kv_start_loc
,
kv_start_loc
=
kv_start_loc
,
max_decode_seq_len
=
input_data
.
max_decode_seq_len
,
max_decode_seq_len
=
input_data
.
max_decode_seq_len
,
num_prefills
=
input_data
.
num_prefills
,
num_prefills
=
input_data
.
num_prefills
,
...
@@ -470,6 +473,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -470,6 +473,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
# For warming-up
if
attn_metadata
is
None
:
return
query
attn_type
=
self
.
attn_type
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
...
@@ -537,8 +545,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -537,8 +545,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
assert
attn_metadata
.
seq_lens
is
not
None
if
not
prefill_meta
.
prefill_metadata
.
chunked_prefill
:
# type: ignore
if
not
prefill_meta
.
prefill_metadata
.
chunked_prefill
:
# type: ignore
assert
attn_metadata
.
seq_lens
is
not
None
self
.
_run_sdpa_forward
(
output
,
self
.
_run_sdpa_forward
(
output
,
query
,
query
,
key
,
key
,
...
@@ -555,7 +563,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -555,7 +563,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
query
[:
prefill_meta
.
num_prefill_tokens
,
:,
:],
query
[:
prefill_meta
.
num_prefill_tokens
,
:,
:],
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
prefill_
query_start_loc
,
prefill_meta
.
kv_start_loc
,
prefill_meta
.
kv_start_loc
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_kv_len
,
prefill_meta
.
max_kv_len
,
...
...
vllm/compilation/wrapper.py
View file @
4555143e
...
@@ -41,11 +41,16 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -41,11 +41,16 @@ class TorchCompileWrapperWithCustomDispatcher:
# compiling the forward method
# compiling the forward method
backend
=
vllm_config
.
compilation_config
.
init_backend
(
vllm_config
)
backend
=
vllm_config
.
compilation_config
.
init_backend
(
vllm_config
)
options
=
None
if
isinstance
(
backend
,
str
)
and
backend
==
"inductor"
:
options
=
get_current_vllm_config
(
).
compilation_config
.
inductor_compile_config
compiled_callable
=
torch
.
compile
(
compiled_callable
=
torch
.
compile
(
self
.
forward
,
self
.
forward
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
backend
=
backend
)
backend
=
backend
,
options
=
options
)
self
.
compiled_callable
=
compiled_callable
self
.
compiled_callable
=
compiled_callable
self
.
original_code_object
=
self
.
__class__
.
forward
.
__code__
self
.
original_code_object
=
self
.
__class__
.
forward
.
__code__
...
...
vllm/engine/arg_utils.py
View file @
4555143e
...
@@ -1399,6 +1399,7 @@ class EngineArgs:
...
@@ -1399,6 +1399,7 @@ class EngineArgs:
"FLASHINFER"
,
"FLASHINFER"
,
"FLASHINFER_VLLM_V1"
,
"FLASHINFER_VLLM_V1"
,
"ROCM_AITER_MLA"
,
"ROCM_AITER_MLA"
,
"TORCH_SDPA_VLLM_V1"
,
]
]
if
(
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
if
(
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
and
envs
.
VLLM_ATTENTION_BACKEND
not
in
V1_BACKENDS
):
and
envs
.
VLLM_ATTENTION_BACKEND
not
in
V1_BACKENDS
):
...
@@ -1431,7 +1432,8 @@ class EngineArgs:
...
@@ -1431,7 +1432,8 @@ class EngineArgs:
# Non-[CUDA, TPU] may be supported on V1, but off by default for now.
# Non-[CUDA, TPU] may be supported on V1, but off by default for now.
v0_hardware
=
not
any
(
v0_hardware
=
not
any
(
(
current_platform
.
is_cuda
(),
current_platform
.
is_tpu
()))
(
current_platform
.
is_cuda
(),
current_platform
.
is_tpu
(),
current_platform
.
is_cpu
()))
if
v0_hardware
and
_warn_or_fallback
(
# noqa: SIM103
if
v0_hardware
and
_warn_or_fallback
(
# noqa: SIM103
current_platform
.
device_name
):
current_platform
.
device_name
):
return
False
return
False
...
...
vllm/platforms/cpu.py
View file @
4555143e
...
@@ -57,7 +57,10 @@ class CpuPlatform(Platform):
...
@@ -57,7 +57,10 @@ class CpuPlatform(Platform):
logger
.
info
(
"Using CPU MLA backend."
)
logger
.
info
(
"Using CPU MLA backend."
)
return
"vllm.attention.backends.cpu_mla.CPUMLABackend"
return
"vllm.attention.backends.cpu_mla.CPUMLABackend"
logger
.
info
(
"Using Torch SDPA backend."
)
logger
.
info
(
"Using Torch SDPA backend."
)
return
"vllm.attention.backends.torch_sdpa.TorchSDPABackend"
if
use_v1
:
return
"vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
else
:
return
"vllm.attention.backends.torch_sdpa.TorchSDPABackend"
@
classmethod
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
...
@@ -81,6 +84,8 @@ class CpuPlatform(Platform):
...
@@ -81,6 +84,8 @@ class CpuPlatform(Platform):
if
not
model_config
.
enforce_eager
:
if
not
model_config
.
enforce_eager
:
model_config
.
enforce_eager
=
True
model_config
.
enforce_eager
=
True
model_config
.
disable_cascade_attn
=
True
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
ipex_available
=
find_spec
(
"intel_extension_for_pytorch"
)
is
not
None
ipex_available
=
find_spec
(
"intel_extension_for_pytorch"
)
is
not
None
...
@@ -128,7 +133,8 @@ class CpuPlatform(Platform):
...
@@ -128,7 +133,8 @@ class CpuPlatform(Platform):
f
"
{
kv_cache_space
}
, expect a positive integer value."
)
f
"
{
kv_cache_space
}
, expect a positive integer value."
)
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
if
(
parallel_config
.
distributed_executor_backend
is
not
None
if
(
parallel_config
.
world_size
>
1
and
parallel_config
.
distributed_executor_backend
is
not
None
and
parallel_config
.
distributed_executor_backend
!=
"mp"
):
and
parallel_config
.
distributed_executor_backend
!=
"mp"
):
logger
.
warning
((
"%s is not supported on CPU, fallback to mp "
logger
.
warning
((
"%s is not supported on CPU, fallback to mp "
"distributed executor backend."
),
"distributed executor backend."
),
...
@@ -141,7 +147,38 @@ class CpuPlatform(Platform):
...
@@ -141,7 +147,38 @@ class CpuPlatform(Platform):
parallel_config
.
sd_worker_cls
=
\
parallel_config
.
sd_worker_cls
=
\
"vllm.worker.cpu_worker.CPUWorker"
"vllm.worker.cpu_worker.CPUWorker"
else
:
else
:
parallel_config
.
worker_cls
=
"vllm.worker.cpu_worker.CPUWorker"
if
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
"vllm.v1.worker.cpu_worker.CPUWorker"
else
:
parallel_config
.
worker_cls
=
\
"vllm.worker.cpu_worker.CPUWorker"
# Note: workaround for v1 gpu_model_runner
from
vllm.config
import
CompilationLevel
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
=
[]
compilation_config
=
vllm_config
.
compilation_config
if
(
envs
.
VLLM_USE_V1
and
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
):
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
compilation_config
.
backend
=
"eager"
compilation_config
.
custom_ops
+=
[
"none"
]
compilation_config
.
inductor_compile_config
.
update
({
"dce"
:
True
,
"size_asserts"
:
False
,
"nan_asserts"
:
False
,
"memory_planning"
:
True
,
"epilogue_fusion"
:
True
,
})
if
vllm_config
.
lora_config
is
not
None
:
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
assert
vllm_config
.
device_config
.
device_type
==
"cpu"
assert
vllm_config
.
device_config
.
device_type
==
"cpu"
...
@@ -149,6 +186,12 @@ class CpuPlatform(Platform):
...
@@ -149,6 +186,12 @@ class CpuPlatform(Platform):
# Environment variables for CPU executor
# Environment variables for CPU executor
#
#
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
# Note: to avoid the error 'nthreads cannot be larger than environment
# variable "NUMEXPR_MAX_THREADS" (64)'.
os
.
environ
[
"NUMEXPR_MAX_THREADS"
]
=
str
(
len
(
os
.
sched_getaffinity
(
0
)))
# Set default threads num for OpenMP parallel
# Set default threads num for OpenMP parallel
os
.
environ
[
"OMP_NUM_THREADS"
]
=
str
(
torch
.
get_num_threads
())
os
.
environ
[
"OMP_NUM_THREADS"
]
=
str
(
torch
.
get_num_threads
())
...
@@ -171,13 +214,6 @@ class CpuPlatform(Platform):
...
@@ -171,13 +214,6 @@ class CpuPlatform(Platform):
# To hint IPEX uses shared memory based AllReduce
# To hint IPEX uses shared memory based AllReduce
os
.
environ
[
"LOCAL_WORLD_SIZE"
]
=
str
(
os
.
environ
[
"LOCAL_WORLD_SIZE"
]
=
str
(
vllm_config
.
parallel_config
.
tensor_parallel_size
)
vllm_config
.
parallel_config
.
tensor_parallel_size
)
if
sys
.
platform
==
"darwin"
and
\
envs
.
VLLM_WORKER_MULTIPROC_METHOD
==
"fork"
:
if
os
.
environ
.
get
(
'VLLM_WORKER_MULTIPROC_METHOD'
,
None
)
is
None
:
logger
.
warning
(
"Default to spawn method on MacOS. If this is not desired,"
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly."
)
os
.
environ
[
'VLLM_WORKER_MULTIPROC_METHOD'
]
=
'spawn'
if
vllm_config
.
model_config
and
vllm_config
.
model_config
.
use_mla
:
if
vllm_config
.
model_config
and
vllm_config
.
model_config
.
use_mla
:
logger
.
info
(
logger
.
info
(
...
@@ -204,3 +240,14 @@ class CpuPlatform(Platform):
...
@@ -204,3 +240,14 @@ class CpuPlatform(Platform):
Get device specific communicator class for distributed communication.
Get device specific communicator class for distributed communication.
"""
"""
return
"vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"
# noqa
return
"vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"
# noqa
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_v1
(
cls
,
model_config
)
->
bool
:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
return
True
vllm/v1/attention/backends/cpu_attn.py
0 → 100644
View file @
4555143e
# SPDX-License-Identifier: Apache-2.0
import
numpy
as
np
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.torch_sdpa
import
(
TorchSDPABackendImpl
,
TorchSDPAMetadata
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.ipex_attn
import
PagedAttention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
class TorchSDPABackend:
    """Descriptor for the V1 Torch SDPA attention backend.

    The class carries no instance state: every accessor is a static method
    that returns the name, implementation class, metadata class, state class
    or metadata-builder class associated with this backend.
    """

    # Class-level flag exposed to callers; this backend leaves it disabled.
    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        """Return the attention implementation class."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        """Return the metadata builder class used with this backend."""
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV cache shape, as computed by ``PagedAttention``."""
        kv_cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always return False, whatever arguments are supplied."""
        return False
class TorchSDPAMetadataBuilderV1:
    """Build ``TorchSDPAMetadata`` for the V1 Torch SDPA backend.

    ``reorder_batch`` places requests that are still in the prompt stage
    before requests in the decode stage and records how many prompt-stage
    requests there are; ``build`` relies on that layout when slicing the
    per-request tensors into a prefill part and a decode part.
    """

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        """Keep references to the runner and block table, allocate buffers.

        ``kv_cache_spec`` is accepted but not used by this builder.
        """
        self.runner = runner
        self.block_table = block_table

        # Scratch index buffers for reorder_batch; only the leading entries
        # written during a call are meaningful (np.empty is uninitialised).
        max_reqs = self.runner.max_num_reqs
        self.reorder_prompt_req_index_list = np.empty(max_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(max_reqs,
                                                      dtype=np.int64)
        self.num_prompt_req: int = 0

        # The numpy view shares memory with the torch tensor, so writing
        # through one is visible through the other.
        self.seq_start_loc_cpu = torch.zeros(
            max_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        """Move prompt-stage requests in front of decode-stage requests.

        A request is in the prompt stage while its computed-token count is
        below its prompt-token count. ``scheduler_output`` is not used.

        Returns:
            True if any request states were swapped, False otherwise.
        """
        prompt_slots = self.reorder_prompt_req_index_list
        decode_slots = self.reorder_decode_req_index_list
        computed_tokens = input_batch.num_computed_tokens_cpu
        prompt_tokens = input_batch.num_prompt_tokens

        # Partition request indices by stage, keeping ascending order.
        num_prompts = 0
        num_decodes = 0
        for req_index in range(input_batch.num_reqs):
            in_prompt_stage = (computed_tokens[req_index]
                               < prompt_tokens[req_index])
            if in_prompt_stage:
                prompt_slots[num_prompts] = req_index
                num_prompts += 1
            else:
                decode_slots[num_decodes] = req_index
                num_decodes += 1
        assert num_decodes + num_prompts == input_batch.num_reqs

        # Remember the prompt request count for build().
        self.num_prompt_req = num_prompts

        # Count the leading decode requests that currently occupy one of the
        # first `num_prompts` batch positions; decode indices are ascending,
        # so the scan can stop at the first one that is already in place.
        num_misplaced = 0
        for decode_index in decode_slots[:num_decodes]:
            if decode_index >= num_prompts:
                break
            num_misplaced += 1

        if num_misplaced == 0:
            return False

        # Pair the last `num_misplaced` prompt requests with the first
        # `num_misplaced` decode requests and swap each pair.
        trailing_prompts = prompt_slots[:num_prompts][-num_misplaced:]
        leading_decodes = decode_slots[:num_decodes][:num_misplaced]
        assert leading_decodes.size == trailing_prompts.size

        swap_pairs = zip(trailing_prompts.tolist(), leading_decodes.tolist())
        for prompt_req_index, decode_req_index in swap_pairs:
            input_batch.swap_states(prompt_req_index, decode_req_index)

        return True

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble the ``TorchSDPAMetadata`` for the current batch.

        The first ``self.num_prompt_req`` requests are treated as prefill and
        the remaining ones up to ``num_reqs`` as decode.
        ``common_prefix_len`` and ``common_attn_metadata`` are not used.
        """
        runner = self.runner
        block_table = self.block_table
        num_prompt_req = self.num_prompt_req
        seq_lens_np = runner.seq_lens_np[:num_reqs]

        # Longest sequence in each part; 0 when the part is empty.
        if num_prompt_req > 0:
            max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item()
        else:
            max_prefill_seq_len = 0
        if num_prompt_req < num_reqs:
            max_decode_seq_len = (
                seq_lens_np[num_prompt_req:num_reqs].max().item())
        else:
            max_decode_seq_len = 0

        # Cumulative sequence lengths, written in place into the buffer that
        # backs self.seq_start_loc_cpu.
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])

        query_start_loc_np = runner.query_start_loc_np
        num_prefill_tokens = query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = (query_start_loc_np[num_reqs].item() -
                             num_prefill_tokens)

        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()

        return TorchSDPAMetadata(
            num_prefills=num_prompt_req,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # decode part
            seq_lens_tensor=runner.seq_lens_cpu[num_prompt_req:num_reqs],
            max_decode_seq_len=max_decode_seq_len,
            block_tables=block_table_tensor[num_prompt_req:num_reqs],
            chunked_prefill=True,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            # prefill part
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:num_prompt_req + 1],
            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + 1],
            prefill_block_tables=block_table_tensor[:num_prompt_req],
            # used for the logits index
            query_start_loc=runner.query_start_loc_cpu[:num_reqs + 1],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )
vllm/v1/worker/cpu_model_runner.py
0 → 100644
View file @
4555143e
# SPDX-License-Identifier: Apache-2.0
from
contextlib
import
contextmanager
from
typing
import
Any
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
logger
=
init_logger
(
__name__
)
class CPUModelRunner(GPUModelRunner):
    """Model runner for the CPU device, built on top of ``GPUModelRunner``."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Initialise the base runner, then apply CPU-specific settings."""
        super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tenosrs()

    def _postprocess_tenosrs(self) -> None:
        """Point "device" tensor attributes at their CPU counterparts.

        For the runner, its input batch and the input batch's block table,
        every tensor attribute whose name carries the CPU suffix is assigned
        to the attribute with the suffix stripped, provided that attribute
        already holds a tensor.
        """

        def alias_device_attrs(owner: Any, cpu_suffix: str) -> None:
            # Walk the attributes of `owner` and, for each `<name><suffix>`
            # tensor, rebind `<name>` to that same tensor.
            suffix_len = len(cpu_suffix)
            for attr_name, attr_value in vars(owner).items():
                if not attr_name.endswith(cpu_suffix):
                    continue
                if not isinstance(attr_value, torch.Tensor):
                    continue
                device_attr_name = attr_name[:-suffix_len]
                cpu_tensor = getattr(owner, attr_name, None)
                device_tensor = getattr(owner, device_attr_name, None)
                if cpu_tensor is None or device_tensor is None:
                    # No matching pair of attributes: nothing to replace.
                    continue
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(owner, device_attr_name, cpu_tensor)

        alias_device_attrs(self, "_cpu")
        alias_device_attrs(self.input_batch, "_cpu_tensor")
        alias_device_attrs(self.input_batch.block_table, "_cpu")

    def load_model(self) -> None:
        """Load the model and, when LoRA is configured, wrap it for LoRA."""
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if not self.lora_config:
            return
        self.model = self.load_lora_model(self.model, self.model_config,
                                          self.scheduler_config,
                                          self.lora_config, self.device)

    def warming_up_model(self) -> None:
        """Run one dummy forward pass so compilation happens up front."""
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        num_warmup_tokens = max(16, self.max_num_reqs)
        self._dummy_run(num_warmup_tokens)
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        """Intentionally a no-op for the CPU runner."""
        return None

    def _sync_device(self) -> None:
        """Intentionally a no-op for the CPU runner."""
        return None
@
contextmanager
def
_set_global_compilation_settings
():
import
torch._inductor.config
# Note: The CPPGEMM backend requires freezing parameters.
freezing_value
=
torch
.
_inductor
.
config
.
freezing
torch
.
_inductor
.
config
.
freezing
=
True
# Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
# including object type dict"
force_disable_caches
=
torch
.
_inductor
.
config
.
force_disable_caches
torch
.
_inductor
.
config
.
force_disable_caches
=
True
yield
torch
.
_inductor
.
config
.
freezing
=
freezing_value
torch
.
_inductor
.
config
.
force_disable_caches
=
force_disable_caches
vllm/v1/worker/cpu_worker.py
0 → 100644
View file @
4555143e
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Optional
import
torch
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.v1.worker.gpu_worker
import
(
Worker
,
init_worker_distributed_environment
)
logger
=
init_logger
(
__name__
)
class CPUWorker(Worker):
    """Worker that runs a ``CPUModelRunner``, built on the V1 ``Worker``."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        """Initialise the base worker and disable custom all-reduce."""
        super().__init__(vllm_config,
                         local_rank,
                         rank,
                         distributed_init_method,
                         is_driver_worker=is_driver_worker)

        self.parallel_config.disable_custom_all_reduce = True

    def init_device(self):
        """Bind OpenMP threads, set up distributed state, build the runner."""
        # Setup OpenMP threads affinity: either the literal "all", or the
        # "|"-separated entry selected by this worker's rank.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = ("all" if omp_cpuids == "all" else
                                omp_cpuids.split("|")[self.rank])
        bind_report = torch.ops._C_utils.init_cpu_threads_env(
            self.local_omp_cpuid)
        if bind_report:
            logger.info(bind_report)

        # Note: unique identifier for creating allreduce shared memory,
        # taken from the text after the last ":" of the init method.
        dist_ident = self.distributed_init_method.rpartition(":")[-1]
        os.environ["VLLM_DIST_IDENT"] = dist_ident

        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank, "gloo")
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        cpu_device = torch.device("cpu")
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, cpu_device)

    def sleep(self, level: int = 1) -> None:
        """Log a warning and do nothing else; ``level`` is ignored."""
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Log a warning and do nothing else; ``tags`` is ignored."""
        logger.warning("sleep mode is not supported on CPU, ignore it.")

    def determine_available_memory(self) -> int:
        """Return the configured CPU KV cache space, in bytes."""
        return self.cache_config.cpu_kvcache_space_bytes  # type: ignore

    def compile_or_warm_up_model(self) -> None:
        """Re-seed the RNG, then warm up the model runner."""
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step for this pipeline-parallel rank.

        Ranks other than the first receive intermediate tensors from the
        pipeline group before running; ranks other than the last send their
        output tensors onward and return None. On the last rank the runner
        output is returned only when this worker is the driver worker.
        """
        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            received = get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group())
            intermediate_tensors = IntermediateTensors(received)

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        if not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None

        assert isinstance(output, ModelRunnerOutput)
        if self.is_driver_worker:
            return output
        return None
vllm/v1/worker/gpu_model_runner.py
View file @
4555143e
...
@@ -5,7 +5,7 @@ import copy
...
@@ -5,7 +5,7 @@ import copy
import
gc
import
gc
import
time
import
time
import
weakref
import
weakref
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -38,7 +38,6 @@ from vllm.sequence import IntermediateTensors
...
@@ -38,7 +38,6 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
...
@@ -203,8 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -203,8 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
# Cache the device properties.
# Cache the device properties.
self
.
device_properties
=
torch
.
cuda
.
get_device_properties
(
self
.
device
)
self
.
_init_device_properties
()
self
.
num_sms
=
self
.
device_properties
.
multi_processor_count
# Persistent buffers for CUDA graphs.
# Persistent buffers for CUDA graphs.
self
.
input_ids
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
input_ids
=
torch
.
zeros
(
self
.
max_num_tokens
,
...
@@ -315,6 +313,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -315,6 +313,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
,
scheduler_output
)
self
.
input_batch
,
scheduler_output
)
return
batch_reordered
return
batch_reordered
# Note: used for model runner override.
def _init_device_properties(self) -> None:
    """Cache the CUDA device properties and the multiprocessor count.

    Stores the result of ``torch.cuda.get_device_properties`` for
    ``self.device`` in ``self.device_properties`` and its
    ``multi_processor_count`` in ``self.num_sms``.
    """
    props = torch.cuda.get_device_properties(self.device)
    self.device_properties = props
    self.num_sms = props.multi_processor_count
# Note: used for model runner override.
def _sync_device(self) -> None:
    """Synchronize the device by calling ``torch.cuda.synchronize``."""
    torch.cuda.synchronize()
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""Update the cached states and the persistent batch with the scheduler
"""Update the cached states and the persistent batch with the scheduler
output.
output.
...
@@ -538,8 +547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -538,8 +547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
dict
[
str
,
FlashAttentionMetadata
],
torch
.
Tensor
,
)
->
tuple
[
dict
[
str
,
Any
],
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
Optional
[
SpecDecodeMetadata
]]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
...
@@ -652,7 +660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -652,7 +660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
attn_metadata
:
dict
[
str
,
FlashAttentionMetadata
]
=
{}
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# in the same group share the same metadata.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
...
@@ -1710,7 +1718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1710,7 +1718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Must synchronize the non-blocking GPU->CPU transfers.
# Must synchronize the non-blocking GPU->CPU transfers.
if
prompt_logprobs_dict
:
if
prompt_logprobs_dict
:
torch
.
cuda
.
synchroniz
e
()
self
.
_sync_devic
e
()
return
prompt_logprobs_dict
return
prompt_logprobs_dict
...
@@ -1740,7 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1740,7 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
if
skip_attn
:
if
skip_attn
:
attn_metadata
:
Optional
[
dict
[
str
,
FlashAttentionMetadata
]]
=
None
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
else
:
else
:
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
...
@@ -1964,7 +1972,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1964,7 +1972,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampler_output
=
self
.
_dummy_sampler_run
(
hidden_states
)
sampler_output
=
self
.
_dummy_sampler_run
(
hidden_states
)
else
:
else
:
sampler_output
=
None
sampler_output
=
None
torch
.
cuda
.
synchroniz
e
()
self
.
_sync_devic
e
()
del
hidden_states
,
sampler_output
del
hidden_states
,
sampler_output
self
.
encoder_cache
.
clear
()
self
.
encoder_cache
.
clear
()
gc
.
collect
()
gc
.
collect
()
...
...
vllm/v1/worker/gpu_worker.py
View file @
4555143e
...
@@ -342,13 +342,14 @@ def init_worker_distributed_environment(
...
@@ -342,13 +342,14 @@ def init_worker_distributed_environment(
rank
:
int
,
rank
:
int
,
distributed_init_method
:
Optional
[
str
]
=
None
,
distributed_init_method
:
Optional
[
str
]
=
None
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
backend
:
str
=
"nccl"
,
)
->
None
:
)
->
None
:
"""Initialize the distributed environment."""
"""Initialize the distributed environment."""
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
set_custom_all_reduce
(
not
parallel_config
.
disable_custom_all_reduce
)
set_custom_all_reduce
(
not
parallel_config
.
disable_custom_all_reduce
)
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
distributed_init_method
,
local_rank
)
distributed_init_method
,
local_rank
,
backend
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment