Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
62def07d
Unverified
Commit
62def07d
authored
Dec 27, 2025
by
Boyuan Feng
Committed by
GitHub
Dec 28, 2025
Browse files
[BugFix] register quant scale tensors as buffer (#31395)
Signed-off-by:
Boyuan Feng
<
boyuan@meta.com
>
parent
b326598e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
82 additions
and
16 deletions
+82
-16
vllm/attention/layer.py
vllm/attention/layer.py
+82
-16
No files found.
vllm/attention/layer.py
View file @
62def07d
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod
,
UnquantizedLinearMethod
,
)
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizeMethodBase
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
...
@@ -46,6 +47,35 @@ from vllm.v1.kv_cache_interface import (
...
@@ -46,6 +47,35 @@ from vllm.v1.kv_cache_interface import (
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether quantized weights should be loaded for ``quant_method``.

    Args:
        quant_method: The quantization method resolved for a layer, or
            ``None`` when no method applies.

    Returns:
        ``False`` when ``quant_method`` is ``None`` or is an instance of
        ``UnquantizedLinearMethod``; ``True`` for any other method.
    """
    # No quantization method at all: nothing quantized to load.
    if quant_method is None:
        return False
    # The unquantized linear method is treated the same as "no method".
    return not isinstance(quant_method, UnquantizedLinearMethod)
def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Set the q/k/v/prob quantization scales of ``layer`` to 1.0.

    Args:
        layer: Module that carries the ``_q_scale``, ``_k_scale``,
            ``_v_scale`` and ``_prob_scale`` tensors.
        register_buffer: When ``True``, each scale is registered on
            ``layer`` as a new float32 scalar buffer holding 1.0. When
            ``False`` (the default), the scale tensors must already exist
            on ``layer`` and are overwritten in place with 1.0.

    Raises:
        AttributeError: If ``register_buffer`` is ``False`` and a scale
            tensor has not been created on ``layer`` yet.
    """
    # Single source of truth for the scale names, so the tensor attributes
    # and their host-side float mirrors cannot drift apart. The order
    # (k, v, q, prob) is the buffer registration order.
    scale_names = ("k", "v", "q", "prob")

    for name in scale_names:
        attr = f"_{name}_scale"
        if register_buffer:
            layer.register_buffer(attr, torch.tensor(1.0, dtype=torch.float32))
        else:
            # In-place fill keeps the existing tensor object (and whatever
            # device it currently lives on) instead of replacing it.
            getattr(layer, attr).fill_(1.0)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    for name in scale_names:
        setattr(layer, f"_{name}_scale_float", 1.0)
def
_init_kv_cache_quant
(
def
_init_kv_cache_quant
(
layer
:
nn
.
Module
,
layer
:
nn
.
Module
,
quant_config
:
QuantizationConfig
|
None
,
quant_config
:
QuantizationConfig
|
None
,
...
@@ -74,17 +104,21 @@ def _init_kv_cache_quant(
...
@@ -74,17 +104,21 @@ def _init_kv_cache_quant(
# with the model weights.
# with the model weights.
layer
.
kv_cache_dtype
=
kv_cache_dtype
layer
.
kv_cache_dtype
=
kv_cache_dtype
layer
.
calculate_kv_scales
=
calculate_kv_scales
layer
.
calculate_kv_scales
=
calculate_kv_scales
layer
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
layer
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
layer
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
layer
.
_prob_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# We also keep q/k/v_scale on host (cpu) memory for attention
# Note [Register q/k/v/prob scales in state dict]
# backends that require the scales to be on host instead of on device.
# When calling model.to(device), only parameters/buffers in state dict are
# e.g. Flashinfer
# moved. If not registering q/k/v/prob scales in state dict, there would
layer
.
_q_scale_float
=
1.0
# be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
layer
.
_k_scale_float
=
1.0
# on cpu.
layer
.
_v_scale_float
=
1.0
# Registering in state dict means it interacts with weight loading. One edge
# case is when quant_method is None, or quant_method is UnquantizedLinearMethod
# (i.e., should_load_quant_weights(quant_method) == False).
# In this case, the checkpoint does not have the scales. We need to
# initialize the scales to 1.0 and update the scales after weight loading.
# This is especially important when we load dummy weights first (providing
# wrong scales) and then load real weights (which misses scales and keeps the
# wrong scales from dummy load).
set_default_quant_scales
(
layer
,
register_buffer
=
True
)
# The output scale on host memory. This should be the input scale of
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
# the quant op after this attention layer.
...
@@ -93,9 +127,9 @@ def _init_kv_cache_quant(
...
@@ -93,9 +127,9 @@ def _init_kv_cache_quant(
quant_method
=
(
quant_method
=
(
quant_config
.
get_quant_method
(
layer
,
prefix
=
prefix
)
if
quant_config
else
None
quant_config
.
get_quant_method
(
layer
,
prefix
=
prefix
)
if
quant_config
else
None
)
)
if
quant_method
is
not
None
and
not
isinstance
(
quant_method
,
UnquantizedLinearMethod
# See Note [Register q/k/v/prob scales in state dict]
):
if
should_load_quant_weights
(
quant_method
):
assert
isinstance
(
quant_method
,
BaseKVCacheMethod
)
assert
isinstance
(
quant_method
,
BaseKVCacheMethod
)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
# checkpoint config and become the "auto" behavior
...
@@ -169,10 +203,16 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -169,10 +203,16 @@ class Attention(nn.Module, AttentionLayerBase):
assert
num_heads
%
num_kv_heads
==
0
,
(
assert
num_heads
%
num_kv_heads
==
0
,
(
f
"num_heads (
{
num_heads
}
) is not divisible by num_kv_heads (
{
num_kv_heads
}
)"
f
"num_heads (
{
num_heads
}
) is not divisible by num_kv_heads (
{
num_kv_heads
}
)"
)
)
self
.
quant_config
=
quant_config
self
.
layer_name
=
prefix
# Initialize KV cache quantization attributes
# Initialize KV cache quantization attributes
_init_kv_cache_quant
(
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
,
kv_cache_dtype
,
calculate_kv_scales
self
,
self
.
quant_config
,
self
.
layer_name
,
kv_cache_dtype
,
calculate_kv_scales
,
)
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
...
@@ -249,7 +289,6 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -249,7 +289,6 @@ class Attention(nn.Module, AttentionLayerBase):
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
attn_type
=
attn_type
self
.
attn_type
=
attn_type
if
kv_sharing_target_layer_name
is
not
None
:
if
kv_sharing_target_layer_name
is
not
None
:
...
@@ -378,6 +417,17 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -378,6 +417,17 @@ class Attention(nn.Module, AttentionLayerBase):
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See Note [Register q/k/v/prob scales in state dict]
# for more details.
quant_method
=
(
self
.
quant_config
.
get_quant_method
(
self
,
prefix
=
self
.
layer_name
)
if
self
.
quant_config
else
None
)
if
not
should_load_quant_weights
(
quant_method
):
set_default_quant_scales
(
self
,
register_buffer
=
False
)
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
return
self
.
attn_backend
...
@@ -453,10 +503,15 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -453,10 +503,15 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_cache_dtype
=
"auto"
kv_cache_dtype
=
"auto"
block_size
=
16
block_size
=
16
calculate_kv_scales
=
False
calculate_kv_scales
=
False
self
.
quant_config
=
quant_config
# Initialize KV cache quantization attributes
# Initialize KV cache quantization attributes
_init_kv_cache_quant
(
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
,
kv_cache_dtype
,
calculate_kv_scales
self
,
self
.
quant_config
,
self
.
layer_name
,
kv_cache_dtype
,
calculate_kv_scales
,
)
)
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
...
@@ -586,6 +641,17 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -586,6 +641,17 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if
hasattr
(
self
.
impl
,
"process_weights_after_loading"
):
if
hasattr
(
self
.
impl
,
"process_weights_after_loading"
):
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See Note [Register q/k/v/prob scales in state dict]
# for more details.
quant_method
=
(
self
.
quant_config
.
get_quant_method
(
self
,
prefix
=
self
.
layer_name
)
if
self
.
quant_config
else
None
)
if
not
should_load_quant_weights
(
quant_method
):
set_default_quant_scales
(
self
,
register_buffer
=
False
)
def
calc_kv_scales
(
def
calc_kv_scales
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
)
->
None
:
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment