Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb14d53c
Unverified
Commit
fb14d53c
authored
Jul 03, 2025
by
Ning Xie
Committed by
GitHub
Jul 03, 2025
Browse files
[Kernel] refactor cpu worker v0 cache dtype (#20080)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
b024a42e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
18 deletions
+20
-18
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+20
-18
No files found.
vllm/worker/cpu_worker.py
View file @
fb14d53c
...
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
bind_kv_cache
from
vllm.utils
import
bind_kv_cache
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
,
CPUModelRunnerBase
from
vllm.worker.cpu_pooling_model_runner
import
CPUPoolingModelRunner
...
...
@@ -54,13 +54,8 @@ class CPUCacheEngine:
# in the scheduler.
self
.
num_cpu_blocks
=
cache_config
.
num_gpu_blocks
if
cache_config
.
cache_dtype
==
"auto"
:
self
.
dtype
=
model_config
.
dtype
elif
cache_config
.
cache_dtype
in
[
"fp8"
,
"fp8_e5m2"
]:
self
.
dtype
=
torch
.
float8_e5m2
else
:
raise
NotImplementedError
(
f
"Unsupported KV cache type "
f
"
{
cache_config
.
cache_dtype
}
."
)
self
.
dtype
=
CPUCacheEngine
.
get_kv_cache_dtype
(
cache_config
,
model_config
)
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
...
...
@@ -97,10 +92,20 @@ class CPUCacheEngine:
def
copy
(
self
,
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
copy_blocks
(
self
.
cpu_cache
,
src_to_dsts
)
@
staticmethod
def
get_kv_cache_dtype
(
cache_config
:
CacheConfig
,
model_config
:
ModelConfig
):
if
cache_config
.
cache_dtype
==
"auto"
:
return
model_config
.
dtype
elif
cache_config
.
cache_dtype
in
[
"fp8"
,
"fp8_e5m2"
]:
return
torch
.
float8_e5m2
else
:
raise
NotImplementedError
(
f
"Unsupported KV cache type "
f
"
{
cache_config
.
cache_dtype
}
."
)
@
staticmethod
def
get_cache_block_size
(
block_size
:
int
,
cache_dtype
:
str
,
cache_config
:
CacheConfig
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
)
->
int
:
...
...
@@ -108,13 +113,10 @@ class CPUCacheEngine:
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
key_cache_block
=
block_size
*
num_heads
*
head_size
key_cache_block
=
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
if
not
model_config
.
use_mla
else
0
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
if
cache_dtype
==
"auto"
:
dtype
=
model_config
.
dtype
else
:
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
dtype
=
CPUCacheEngine
.
get_kv_cache_dtype
(
cache_config
,
model_config
)
dtype_size
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
dtype_size
*
total
...
...
@@ -399,9 +401,9 @@ class CPUWorker(LocalOrDistributedWorkerBase):
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Return the size in bytes of a single KV cache block.
"""
return
CPUCacheEngine
.
get_cache_block_size
(
self
.
cache_config
.
block_size
,
self
.
cache_config
.
cache_dtype
,
self
.
model_config
,
self
.
parallel_config
)
return
CPUCacheEngine
.
get_cache_block_size
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
)
def
get_cpus_id_binding_based_on_numa_nodes
(
self
)
->
str
:
"""Return CPUs id binding based on NUMA nodes.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment