Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a2ae4965
Unverified
Commit
a2ae4965
authored
Mar 15, 2025
by
Li, Jiang
Committed by
GitHub
Mar 14, 2025
Browse files
[CPU] Support FP8 KV cache (#14741)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
877e3522
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
122 additions
and
36 deletions
+122
-36
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+21
-17
csrc/cpu/cpu_types_x86.hpp
csrc/cpu/cpu_types_x86.hpp
+9
-0
docs/source/getting_started/installation/cpu.md
docs/source/getting_started/installation/cpu.md
+1
-1
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+2
-2
tests/models/decoder_only/language/test_fp8.py
tests/models/decoder_only/language/test_fp8.py
+61
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+5
-4
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+19
-11
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+4
-1
No files found.
csrc/cpu/cache.cpp
View file @
a2ae4965
...
...
@@ -3,6 +3,12 @@
#include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace
{
template
<
typename
scalar_t
>
void
copy_blocks_cpu_impl
(
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
...
...
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}
const
int
element_num_per_block
=
key_caches
[
0
][
0
].
numel
();
VLLM_DISPATCH_FLOATING_TYPES
(
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_mapping
,
element_num_per_block
,
num_layers
);
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
});
DISPATCH_MACRO
(
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_mapping
,
element_num_per_block
,
num_layers
);
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
});
}
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
...
...
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
VLLM_DISPATCH_FLOATING_TYPES
(
key
.
scalar_type
(),
"reshape_and_cache_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
reshape_and_cache_cpu_impl
)
reshape_and_cache_cpu_impl
<
scalar_t
>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
num_tokens
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
,
x
);
CPU_KERNEL_GUARD_OUT
(
reshape_and_cache_cpu_impl
)
});
DISPATCH_MACRO
(
key
.
scalar_type
(),
"reshape_and_cache_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
reshape_and_cache_cpu_impl
)
reshape_and_cache_cpu_impl
<
scalar_t
>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
num_tokens
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
,
x
);
CPU_KERNEL_GUARD_OUT
(
reshape_and_cache_cpu_impl
)
});
}
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
...
...
csrc/cpu/cpu_types_x86.hpp
View file @
a2ae4965
...
...
@@ -16,9 +16,18 @@ namespace vec_op {
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
...
...
docs/source/getting_started/installation/cpu.md
View file @
a2ae4965
...
...
@@ -189,7 +189,7 @@ vLLM CPU backend supports the following vLLM features:
-
Model Quantization (
`INT8 W8A8, AWQ, GPTQ`
)
-
Chunked-prefill
-
Prefix-caching
-
FP8-E5M2 KV
-C
ach
ing (TODO)
-
FP8-E5M2 KV
c
ach
e
## Related runtime environment variables
...
...
tests/basic_correctness/test_chunked_prefill.py
View file @
a2ae4965
...
...
@@ -266,7 +266,7 @@ def test_with_prefix_caching(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
,
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
...
...
@@ -303,7 +303,7 @@ def test_models_cpu(
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"chunk_size"
,
[
30
,
32
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
,
"half"
])
@
pytest
.
mark
.
cpu_model
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cpu
(),
reason
=
"CPU only"
)
def
test_with_prefix_caching_cpu
(
...
...
tests/models/decoder_only/language/test_fp8.py
View file @
a2ae4965
...
...
@@ -11,6 +11,7 @@ import pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm.platforms
import
current_platform
from
...utils
import
check_logprobs_close
...
...
@@ -93,3 +94,63 @@ def test_models(
name_0
=
"fp16_kv_cache"
,
name_1
=
"fp8_kv_cache"
,
)
@
pytest
.
mark
.
cpu_model
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cpu
(),
reason
=
"test for the CPU backend."
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype,base_model,test_model"
,
[
# Test BF16 checkpoint w. fp8_e5m2 kv-cache.
(
"fp8_e5m2"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
"meta-llama/Llama-3.2-1B-Instruct"
),
])
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
4
])
# Due to low-precision numerical divergence, this test is too sensitive for
# the async postprocessor
@
pytest
.
mark
.
parametrize
(
"disable_async_output_proc"
,
[
True
])
def
test_cpu_models
(
vllm_runner
,
example_prompts
,
kv_cache_dtype
:
str
,
base_model
:
str
,
test_model
:
str
,
max_tokens
:
int
,
disable_async_output_proc
:
bool
,
)
->
None
:
"""
Only checks log probs match to cover the discrepancy in
numerical sensitive kernels.
"""
MAX_MODEL_LEN
=
1024
NUM_LOG_PROBS
=
8
with
vllm_runner
(
base_model
,
max_model_len
=
MAX_MODEL_LEN
,
dtype
=
"bfloat16"
,
kv_cache_dtype
=
"auto"
,
disable_async_output_proc
=
disable_async_output_proc
,
)
as
vllm_model
:
baseline_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
NUM_LOG_PROBS
)
with
vllm_runner
(
test_model
,
max_model_len
=
MAX_MODEL_LEN
,
dtype
=
"bfloat16"
,
kv_cache_dtype
=
kv_cache_dtype
,
disable_async_output_proc
=
disable_async_output_proc
,
)
as
vllm_model
:
test_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
NUM_LOG_PROBS
)
check_logprobs_close
(
outputs_0_lst
=
baseline_outputs
,
outputs_1_lst
=
test_outputs
,
name_0
=
"bf16_kv_cache"
,
name_1
=
"fp8_kv_cache"
,
)
vllm/attention/backends/torch_sdpa.py
View file @
a2ae4965
...
...
@@ -17,7 +17,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache
)
# yapf: enable
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.ipex_attn
import
PagedAttention
from
vllm.attention.ops.ipex_attn
import
PagedAttention
,
_use_ipex
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.logger
import
init_logger
from
vllm.utils
import
make_tensor_with_pad
...
...
@@ -431,10 +431,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
is_quantized_kv_cache
(
kv_cache_dtype
):
if
is_quantized_kv_cache
(
kv_cache_dtype
)
and
not
_use_ipex
:
raise
NotImplementedError
(
"Torch SDPA backend
does not support
FP8 KV cache
.
"
"
Please use xFormers backend instead
."
)
"Torch SDPA backend FP8 KV cache
requires
"
"
intel_extension_for_pytorch support
."
)
self
.
attn_type
=
attn_type
def
forward
(
...
...
vllm/platforms/cpu.py
View file @
a2ae4965
...
...
@@ -60,9 +60,6 @@ class CpuPlatform(Platform):
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if
not
model_config
.
enforce_eager
:
logger
.
warning
(
"CUDA graph is not supported on CPU, fallback to the eager "
"mode."
)
model_config
.
enforce_eager
=
True
cache_config
=
vllm_config
.
cache_config
...
...
@@ -70,6 +67,25 @@ class CpuPlatform(Platform):
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
scheduler_config
=
vllm_config
.
scheduler_config
if
((
scheduler_config
.
chunked_prefill_enabled
or
cache_config
.
enable_prefix_caching
)
and
cache_config
.
cache_dtype
!=
"auto"
):
raise
RuntimeError
(
"Chunked-prefill and prefix-cache on the CPU "
"backend is not compatible with FP8 KV cache."
)
if
cache_config
.
cache_dtype
==
"fp8_e4m3"
:
cache_config
.
cache_dtype
=
"fp8_e5m2"
logger
.
warning
(
"CPU backend doesn't support fp8_e4m3 KV cache type, "
"cast to fp8_e5m2."
)
if
(
cache_config
.
cache_dtype
!=
"auto"
and
model_config
.
dtype
==
torch
.
half
):
logger
.
warning
(
"FP8 KV cache on the CPU backend only does not"
" support fp16 for now, cast to bf16."
)
model_config
.
dtype
=
torch
.
bfloat16
kv_cache_space
=
envs
.
VLLM_CPU_KVCACHE_SPACE
if
kv_cache_space
>=
0
:
...
...
@@ -85,14 +101,6 @@ class CpuPlatform(Platform):
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f
"
{
kv_cache_space
}
, expect a positive integer value."
)
scheduler_config
=
vllm_config
.
scheduler_config
if
((
scheduler_config
.
chunked_prefill_enabled
or
cache_config
.
enable_prefix_caching
)
and
model_config
.
dtype
==
torch
.
half
):
logger
.
warning
(
"Chunked-prefill on the CPU backend only does not"
" support fp16 for now, cast to bf16."
)
model_config
.
dtype
=
torch
.
bfloat16
parallel_config
=
vllm_config
.
parallel_config
if
(
parallel_config
.
distributed_executor_backend
is
not
None
and
parallel_config
.
distributed_executor_backend
!=
"mp"
):
...
...
vllm/worker/cpu_worker.py
View file @
a2ae4965
...
...
@@ -53,8 +53,11 @@ class CPUCacheEngine:
if
cache_config
.
cache_dtype
==
"auto"
:
self
.
dtype
=
model_config
.
dtype
elif
cache_config
.
cache_dtype
in
[
"fp8"
,
"fp8_e5m2"
]:
self
.
dtype
=
torch
.
float8_e5m2
else
:
self
.
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
raise
NotImplementedError
(
f
"Unsupported KV cache type "
f
"
{
cache_config
.
cache_dtype
}
."
)
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment