Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98e7f223
Unverified
Commit
98e7f223
authored
Mar 27, 2026
by
Jonas M. Kübler
Committed by
GitHub
Mar 27, 2026
Browse files
enable skipping of SW attention layers when using FP8 KV cache (#33695)
Signed-off-by:
Jonas Kuebler
<
kuebj@amazon.com
>
parent
b111f8a6
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
58 additions
and
0 deletions
+58
-0
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+23
-0
vllm/config/cache.py
vllm/config/cache.py
+3
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+7
-0
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+25
-0
No files found.
tests/quantization/test_fp8.py
View file @
98e7f223
...
...
@@ -466,3 +466,26 @@ def test_fp8_reloading(
weight_loader
(
param
,
torch
.
zeros
(
shape
))
# cannot use empty
method
.
process_weights_after_loading
(
layer
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
)
def
test_kv_cache_dtype_skip_layers
(
vllm_runner
,
monkeypatch
):
"""Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
with
vllm_runner
(
"facebook/opt-125m"
,
kv_cache_dtype
=
"fp8"
,
kv_cache_dtype_skip_layers
=
[
"0"
,
"2"
],
enforce_eager
=
True
,
)
as
llm
:
def
check_layers
(
model
):
for
i
,
layer
in
enumerate
(
model
.
model
.
decoder
.
layers
):
expected
=
"auto"
if
str
(
i
)
in
[
"0"
,
"2"
]
else
"fp8"
assert
layer
.
self_attn
.
attn
.
kv_cache_dtype
==
expected
llm
.
apply_model
(
check_layers
)
vllm/config/cache.py
View file @
98e7f223
...
...
@@ -87,6 +87,9 @@ class CacheConfig:
It enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""
kv_cache_dtype_skip_layers
:
list
[
str
]
=
field
(
default_factory
=
list
)
"""Layer patterns to skip KV cache quantization. Accepts layer indices
(e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
cpu_kvcache_space_bytes
:
int
|
None
=
None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded
:
int
|
None
=
None
...
...
vllm/engine/arg_utils.py
View file @
98e7f223
...
...
@@ -597,6 +597,9 @@ class EngineArgs:
attention_backend
:
AttentionBackendEnum
|
None
=
AttentionConfig
.
backend
calculate_kv_scales
:
bool
=
CacheConfig
.
calculate_kv_scales
kv_cache_dtype_skip_layers
:
list
[
str
]
=
get_field
(
CacheConfig
,
"kv_cache_dtype_skip_layers"
)
mamba_cache_dtype
:
MambaDType
=
CacheConfig
.
mamba_cache_dtype
mamba_ssm_cache_dtype
:
MambaDType
=
CacheConfig
.
mamba_ssm_cache_dtype
mamba_block_size
:
int
|
None
=
get_field
(
CacheConfig
,
"mamba_block_size"
)
...
...
@@ -1003,6 +1006,9 @@ class EngineArgs:
cache_group
.
add_argument
(
"--calculate-kv-scales"
,
**
cache_kwargs
[
"calculate_kv_scales"
]
)
cache_group
.
add_argument
(
"--kv-cache-dtype-skip-layers"
,
**
cache_kwargs
[
"kv_cache_dtype_skip_layers"
]
)
cache_group
.
add_argument
(
"--kv-sharing-fast-prefill"
,
**
cache_kwargs
[
"kv_sharing_fast_prefill"
]
)
...
...
@@ -1578,6 +1584,7 @@ class EngineArgs:
enable_prefix_caching
=
self
.
enable_prefix_caching
,
prefix_caching_hash_algo
=
self
.
prefix_caching_hash_algo
,
calculate_kv_scales
=
self
.
calculate_kv_scales
,
kv_cache_dtype_skip_layers
=
self
.
kv_cache_dtype_skip_layers
,
kv_sharing_fast_prefill
=
self
.
kv_sharing_fast_prefill
,
mamba_cache_dtype
=
self
.
mamba_cache_dtype
,
mamba_ssm_cache_dtype
=
self
.
mamba_ssm_cache_dtype
,
...
...
vllm/model_executor/layers/attention/attention.py
View file @
98e7f223
...
...
@@ -240,6 +240,31 @@ class Attention(nn.Module, AttentionLayerBase):
and
kv_cache_scheme
.
get
(
"strategy"
)
==
"attn_head"
)
# Skip quantization for specified layers
if
cache_config
is
not
None
and
cache_config
.
kv_cache_dtype_skip_layers
:
from
vllm.model_executor.models.utils
import
extract_layer_index
skip
=
False
# Check attention type
if
(
sliding_window
is
not
None
and
"sliding_window"
in
cache_config
.
kv_cache_dtype_skip_layers
):
skip
=
True
# Check layer index
layer_idx
=
extract_layer_index
(
prefix
)
if
str
(
layer_idx
)
in
cache_config
.
kv_cache_dtype_skip_layers
:
skip
=
True
if
skip
:
kv_cache_dtype
=
"auto"
calculate_kv_scales
=
False
logger
.
info
(
"Layer %s: kv_cache_dtype=%s, sliding_window=%s"
,
prefix
,
kv_cache_dtype
,
sliding_window
,
)
self
.
kv_cache_torch_dtype
=
kv_cache_dtype_str_to_dtype
(
kv_cache_dtype
,
vllm_config
.
model_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment