Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19fa90ed
Unverified
Commit
19fa90ed
authored
Apr 16, 2026
by
Asaf Gardin
Committed by
GitHub
Apr 15, 2026
Browse files
[Quantization] - Layerwise reloading of Attention/KV quantized models (#38995)
Signed-off-by:
Josephasafg
<
ajgard7@gmail.com
>
parent
03f8d3a5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
98 additions
and
27 deletions
+98
-27
tests/model_executor/model_loader/test_reload.py
tests/model_executor/model_loader/test_reload.py
+28
-0
vllm/model_executor/model_loader/reload/__init__.py
vllm/model_executor/model_loader/reload/__init__.py
+2
-3
vllm/model_executor/model_loader/reload/layerwise.py
vllm/model_executor/model_loader/reload/layerwise.py
+66
-22
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+2
-2
No files found.
tests/model_executor/model_loader/test_reload.py
View file @
19fa90ed
...
...
@@ -164,6 +164,34 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert
add_perp
<
mul_perp
def
test_kv_scale_reload
(
vllm_runner
):
"""Test reloading a checkpoint that contains k_scale/v_scale weights."""
if
not
current_platform
.
supports_fp8
():
pytest
.
skip
(
reason
=
"Requires FP8 support"
)
model
=
"nm-testing/Llama-3.2-1B-Instruct-FP8-KV"
# Load dummy weights, then reload real checkpoint
with
vllm_runner
(
model_name
=
model
,
load_format
=
"dummy"
,
enable_prefix_caching
=
False
,
max_model_len
=
16
,
max_num_seqs
=
1
,
)
as
llm
:
llm
.
collective_rpc
(
"update_config"
,
kwargs
=
{
"overrides"
:
{
"load_config"
:
{
"load_format"
:
"auto"
}}},
)
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
model
})
reloaded_perp
=
llm
.
generate_prompt_perplexity
(
[
"The capital of France is the city of Paris"
],
mask
=
[
"The capital of France is"
],
)[
0
]
assert
reloaded_perp
<
10
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
pytest
.
param
(
1
),
pytest
.
param
(
2
,
marks
=
[
pytest
.
mark
.
slow_test
])]
)
...
...
vllm/model_executor/model_loader/reload/__init__.py
View file @
19fa90ed
...
...
@@ -8,10 +8,9 @@ which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
3. Tied parameters will only reflect processing from one of the parent layers (for
2. Tied parameters will only reflect processing from one of the parent layers (for
example, only processing from embed_tokens will have an effect)
4
. This design assumes that the number of weights loaded from disk is the same as the
3
. This design assumes that the number of weights loaded from disk is the same as the
number of weights created at model init time. This is not true for quant methods
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
cases are non-issues for today's quant methods, but future quantizations may cause
...
...
vllm/model_executor/model_loader/reload/layerwise.py
View file @
19fa90ed
...
...
@@ -200,6 +200,8 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
if
hasattr
(
model
,
"_original_do_torchao_reload"
):
model
.
_do_torchao_reload
=
model
.
_original_do_torchao_reload
deferred_attn
:
list
[
tuple
[
torch
.
nn
.
Module
,
LayerReloadingInfo
]]
=
[]
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
if
not
info
.
can_load
():
...
...
@@ -208,22 +210,11 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
# Attention/MLA layers are processed after all other layers
if
isinstance
(
layer
,
(
Attention
,
MLAAttention
)):
if
info
.
load_numel
>
0
:
raise
NotImplementedError
(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
elif
info
.
kernel_tensors
is
None
:
raise
NotImplementedError
(
"Layerwise loading of Q/K/V scale weights is not implemented yet"
)
else
:
_place_kernel_tensors
(
layer
,
info
)
layer
.
process_weights_after_loading
(
model_config
.
dtype
)
deferred_attn
.
append
((
layer
,
info
))
continue
# No weights were loaded
el
if
info
.
load_numel
<=
0
:
if
info
.
load_numel
<=
0
:
# first load: checkpoint did not contain weights for this layer
if
info
.
kernel_tensors
is
None
:
_layerwise_process
(
layer
,
info
)
...
...
@@ -244,11 +235,58 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
info
.
reset
()
# Process attention layers after all other layers are done
for
layer
,
info
in
deferred_attn
:
_finalize_attention_layer
(
layer
,
info
,
model_config
)
info
.
reset
()
def
finalize_layerwise_reload
(
*
args
,
**
kwargs
):
finalize_layerwise_processing
(
*
args
,
**
kwargs
)
def
_finalize_attention_layer
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
,
model_config
:
ModelConfig
)
->
None
:
if
info
.
load_numel
>
0
and
info
.
kernel_tensors
is
not
None
:
# Reload with new scale weights from checkpoint
_place_kernel_tensors
(
layer
,
info
)
_reload_attention_scales
(
layer
,
info
)
elif
info
.
load_numel
>
0
or
info
.
kernel_tensors
is
None
:
raise
ValueError
(
"Layerwise loading of attention layers is not supported. "
"Attention must always process after linears."
)
else
:
_place_kernel_tensors
(
layer
,
info
)
layer
.
process_weights_after_loading
(
model_config
.
dtype
)
def
_reload_attention_scales
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
)
->
None
:
"""Load and process attention scale weights (k_scale, v_scale, etc.)
during reload.
Assumes dtype/shapes of attention tensors do not change during
processing, since we use .data.copy_() to preserve kernel tensor
references."""
quant_method
=
getattr
(
layer
,
"quant_method"
,
None
)
if
quant_method
is
None
:
return
# Re-create scale Parameters with sentinel values so unloaded scales
# are correctly detected by process_weights_after_loading
quant_method
.
create_weights
(
layer
)
for
name
,
args
in
info
.
loaded_weights
:
param
=
getattr
(
layer
,
name
)
args
.
arguments
[
"param"
]
=
param
_get_weight_loader
(
param
)(
*
args
.
args
,
**
args
.
kwargs
)
quant_method
.
process_weights_after_loading
(
layer
)
_copy_and_restore_kernel_tensors
(
layer
,
info
)
def
_layerwise_process
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
"""
Finalize layer loading after all weights have been buffered.
...
...
@@ -278,7 +316,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
param
.
weight_loader
(
*
args
.
args
,
**
args
.
kwargs
)
# Process weights (quantization, repacking, etc.)
# Attention/MLA are processed in `finalize_layerwise_reload`
quant_method
=
getattr
(
layer
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
quant_method
.
process_weights_after_loading
(
layer
)
...
...
@@ -286,13 +323,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
if
info
.
kernel_tensors
is
not
None
:
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
for
name
,
buffer
in
buffers
.
items
():
buffer
.
data
.
copy_
(
getattr
(
layer
,
name
))
_place_kernel_tensors
(
layer
,
info
)
_copy_and_restore_kernel_tensors
(
layer
,
info
)
info
.
reset
()
logger
.
debug
(
"%s: Processed"
,
layer
.
__class__
.
__name__
)
...
...
@@ -311,6 +342,19 @@ def _get_weight_loader(tensor: torch.Tensor):
return
getattr
(
tensor
,
"weight_loader"
,
default_weight_loader
)
def
_copy_and_restore_kernel_tensors
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
"""Copy processed values into original kernel tensor storage and restore
kernel tensor references on the layer. Preserves cudagraph references."""
assert
info
.
kernel_tensors
is
not
None
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
for
name
,
buffer
in
buffers
.
items
():
buffer
.
data
.
copy_
(
getattr
(
layer
,
name
))
_place_kernel_tensors
(
layer
,
info
)
def
_place_kernel_tensors
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
for
name
in
get_layer_tensors
(
layer
):
delattr
(
layer
,
name
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
19fa90ed
...
...
@@ -1364,8 +1364,8 @@ def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> N
if
param
.
numel
()
==
1
and
loaded_weight
.
numel
()
==
1
:
# Sometimes scalar values aren't considered tensors with shapes
# so if both param and loaded_weight are a scalar,
#
"broadcast" instead of
copy
param
.
data
.
fill
_
(
loaded_weight
.
item
(
))
#
reshape to match before
copy
ing
param
.
data
.
copy
_
(
loaded_weight
.
view
(
param
.
shape
))
else
:
assert
param
.
size
()
==
loaded_weight
.
size
(),
(
f
"Attempted to load weight (
{
loaded_weight
.
size
()
}
) "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment