Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a99564ac
Unverified
Commit
a99564ac
authored
Oct 25, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 25, 2025
Browse files
[Attention] Add missing kv cache scale setup (#27490)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
4c5f6321
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
59 deletions
+72
-59
vllm/attention/layer.py
vllm/attention/layer.py
+72
-59
No files found.
vllm/attention/layer.py
View file @
a99564ac
...
...
@@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
return
attn_backend
,
flash_attn_varlen_func
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Attach KV-cache scale attributes and the quant method to ``layer``.

    Shared setup used by both ``Attention`` and ``MLAAttention``. Every scale
    (query, key, value, probability) starts out at 1.0, stored both as a
    float32 tensor and, for q/k/v, as a plain Python float. When
    ``quant_config`` yields a KV-cache quantization method for this layer, that
    method is stored on the layer and asked to create its weights.

    Args:
        layer: The attention layer instance to populate.
        quant_config: Optional quantization configuration; a falsy value
            means no quantization method is looked up.
        prefix: Layer name prefix passed to the quantization method lookup.
        kv_cache_dtype: The KV cache data type string.
        calculate_kv_scales: Whether KV scales are calculated dynamically.

    Raises:
        ValueError: If a KV-cache quantization method applies and
            ``kv_cache_dtype`` is ``"fp8_e5m2"``.
    """
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales

    # Default scales are 1.0. They are ignored when the kv-cache is not fp8
    # and are the values to use with an fp8_e5m2 kv-cache. For fp8_e4m3 the
    # pre-quantized k/v_scale is expected to be loaded with the model weights.
    for tensor_attr in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        # A fresh tensor per attribute, so the scales never alias each other.
        setattr(layer, tensor_attr, torch.tensor(1.0, dtype=torch.float32))

    # Host (cpu) copies of q/k/v_scale for attention backends that need the
    # scales on host rather than on device, e.g. Flashinfer.
    for float_attr in ("_q_scale_float", "_k_scale_float", "_v_scale_float"):
        setattr(layer, float_attr, 1.0)

    # Output scale on host memory; this should be the input scale of the
    # quant op that follows this attention layer.
    layer._o_scale_float = None

    quant_method = None
    if quant_config:
        quant_method = quant_config.get_quant_method(layer, prefix=prefix)

    # Nothing more to do without a real (non-passthrough) quant method.
    if quant_method is None or isinstance(quant_method, UnquantizedLinearMethod):
        return

    assert isinstance(quant_method, BaseKVCacheMethod)

    # TODO (mgoin): kv cache dtype should be specified in the FP8
    # checkpoint config and become the "auto" behavior
    if kv_cache_dtype == "fp8_e5m2":
        raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")

    # With quantization enabled, "k_scale" and "v_scale" become parameters so
    # they can be loaded from the model checkpoint. They are converted back
    # to native float32 values after weight loading.
    layer.quant_method = quant_method
    layer.quant_method.create_weights(layer)
class
Attention
(
nn
.
Module
,
AttentionLayerBase
):
"""Attention layer.
...
...
@@ -184,30 +247,10 @@ class Attention(nn.Module, AttentionLayerBase):
f
"num_heads (
{
num_heads
}
) is not divisible by num_kv_heads (
{
num_kv_heads
}
)"
)
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_prob_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self
.
_q_scale_float
=
1.0
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self
.
_o_scale_float
:
float
|
None
=
None
# Initialize KV cache quantization attributes
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
,
kv_cache_dtype
,
calculate_kv_scales
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
...
...
@@ -215,26 +258,6 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
sliding_window
=
sliding_window
self
.
has_sink
=
extra_impl_args
.
get
(
"sinks"
)
is
not
None
quant_method
=
(
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
quant_config
else
None
)
if
quant_method
is
not
None
and
not
isinstance
(
quant_method
,
UnquantizedLinearMethod
):
assert
isinstance
(
quant_method
,
BaseKVCacheMethod
)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if
self
.
kv_cache_dtype
==
"fp8_e5m2"
:
raise
ValueError
(
"fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
)
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self
.
quant_method
=
quant_method
self
.
quant_method
.
create_weights
(
self
)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
...
...
@@ -636,7 +659,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_cache_dtype
=
"auto"
block_size
=
16
calculate_kv_scales
=
False
self
.
kv_cache_dtype
=
kv_cache_dtype
# Initialize KV cache quantization attributes
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
,
kv_cache_dtype
,
calculate_kv_scales
)
dtype
=
torch
.
get_default_dtype
()
self
.
attn_backend
=
get_attn_backend
(
...
...
@@ -685,20 +712,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
]
# Align with Attention's scale attributes for MLA backends.
self
.
calculate_kv_scales
=
calculate_kv_scales
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_prob_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# Host-side mirrors used by some attention backends
self
.
_q_scale_float
=
1.0
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
self
.
_o_scale_float
:
float
|
None
=
None
self
.
use_sparse
=
use_sparse
# Initialize q/k/v range constants.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment