Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
648edcf7
Unverified
Commit
648edcf7
authored
Mar 27, 2026
by
Kyle Sayers
Committed by
GitHub
Mar 27, 2026
Browse files
[QeRL] Compose online quantization with quantized reloading (#38032)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
7ba425e9
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
184 additions
and
260 deletions
+184
-260
tests/model_executor/model_loader/test_reload.py
tests/model_executor/model_loader/test_reload.py
+57
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+27
-202
vllm/model_executor/layers/quantization/mxfp8.py
vllm/model_executor/layers/quantization/mxfp8.py
+2
-0
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+19
-7
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+9
-0
vllm/model_executor/model_loader/reload/__init__.py
vllm/model_executor/model_loader/reload/__init__.py
+2
-0
vllm/model_executor/model_loader/reload/layerwise.py
vllm/model_executor/model_loader/reload/layerwise.py
+61
-29
vllm/model_executor/model_loader/reload/meta.py
vllm/model_executor/model_loader/reload/meta.py
+3
-4
vllm/model_executor/model_loader/reload/types.py
vllm/model_executor/model_loader/reload/types.py
+3
-3
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+1
-15
No files found.
tests/model_executor/model_loader/test_reload.py
View file @
648edcf7
...
...
@@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
add_perp
<
mul_perp
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"base_model,mul_model,add_model,quantization"
,
[
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
"fp8"
,
),
(
"inference-optimization/DeepSeek-V3-debug-empty"
,
"inference-optimization/DeepSeek-V3-debug-multiply"
,
"inference-optimization/DeepSeek-V3-debug-add"
,
"fp8"
,
),
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
"mxfp8"
,
),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
],
)
def
test_online_quantize_reload
(
base_model
,
mul_model
,
add_model
,
quantization
,
tp_size
,
vllm_runner
):
if
cuda_device_count_stateless
()
<
tp_size
:
pytest
.
skip
(
reason
=
"Not enough CUDA devices"
)
if
quantization
==
"fp8"
and
not
current_platform
.
supports_fp8
():
pytest
.
skip
(
reason
=
"Requires FP8 support"
)
with
vllm_runner
(
model_name
=
base_model
,
quantization
=
quantization
,
tensor_parallel_size
=
tp_size
,
enable_expert_parallel
=
(
tp_size
>
1
and
"DeepSeek"
in
base_model
),
enable_prefix_caching
=
False
,
)
as
llm
:
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
mul_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
mul_perp
<
add_perp
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
add_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
add_perp
<
mul_perp
vllm/model_executor/layers/quantization/fp8.py
View file @
648edcf7
...
...
@@ -73,7 +73,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
)
from
vllm.model_executor.model_loader.weight_utils
import
initialize_single_dummy_weight
from
vllm.model_executor.model_loader.reload.layerwise
import
(
initialize_online_processing
,
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
...
...
@@ -496,8 +498,8 @@ class Fp8LinearMethod(LinearMethodBase):
class
Fp8OnlineLinearMethod
(
Fp8LinearMethod
):
"""Online version of Fp8LinearMethod
,
loads
the fp16/bf16
checkpoint
and quantize
d the
weights during loading."""
"""Online version of Fp8LinearMethod
which
loads
a full precision
checkpoint
and quantize
s
weights during loading."""
uses_meta_device
:
bool
=
True
...
...
@@ -519,84 +521,25 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# WEIGHT
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param
=
layer
.
weight
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return
res
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_
weight_loader
,
weight_loader
=
weight_loader
,
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
layer
.
register_parameter
(
"weight"
,
weight
)
initialize_online_processing
(
layer
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
weight
.
device
==
torch
.
device
(
"meta"
):
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
layer
.
weight
.
weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
initialize_single_dummy_weight
(
layer
.
weight
)
# TODO(future): support block_quant in online quant path
assert
not
self
.
block_quant
...
...
@@ -845,9 +788,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# Allow for accessing weights and scales in standard way.
w13
=
layer
.
w13_weight
w2
=
layer
.
w2_weight
...
...
@@ -892,9 +832,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
w13_input_scale
,
w2_input_scale
)
# Prevent duplicate processing (e.g., during weight reload)
layer
.
_already_called_process_weights_after_loading
=
True
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
...
...
@@ -1013,86 +950,12 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# We are doing online quantization, patch the weight loaded
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs
=
extra_weight_attrs
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# add a counter to track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer
.
_w13_weight_orig_id
=
id
(
layer
.
w13_weight
)
layer
.
_w2_weight_orig_id
=
id
(
layer
.
w2_weight
)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if
id
(
param
)
==
layer
.
_w13_weight_orig_id
:
param
=
layer
.
w13_weight
elif
id
(
param
)
==
layer
.
_w2_weight_orig_id
:
param
=
layer
.
w2_weight
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
w13_weight
.
numel
()
+
layer
.
w2_weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return
res
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
extra_weight_attrs
=
new_extra_weight_attrs
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
),
...
...
@@ -1106,91 +969,53 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
# BIASES (for models like GPT-OSS that have biased MoE)
if
self
.
moe
.
has_bias
:
# Use the original weight_loader (not patched) for biases
orig_extra_weight_attrs
=
dict
(
extra_weight_attrs
)
orig_extra_weight_attrs
[
"weight_loader"
]
=
weight_loader
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
orig_extra_weight_attrs
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
layer
.
orig_dtype
),
torch
.
zeros
(
num_experts
,
hidden_size
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
orig_extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
initialize_online_processing
(
layer
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
w13_weight
.
device
==
torch
.
device
(
"meta"
):
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
{
"weight_loader"
:
layer
.
w13_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
initialize_single_dummy_weight
(
layer
.
w13_weight
)
if
layer
.
w2_weight
.
device
==
torch
.
device
(
"meta"
):
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
{
"weight_loader"
:
layer
.
w2_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
initialize_single_dummy_weight
(
layer
.
w2_weight
)
# If checkpoint is fp16, quantize in place.
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
w13_scale
=
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
for
expert
in
range
(
layer
.
local_num_experts
):
w13
[
expert
,
:,
:],
w13_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
...
...
@@ -1207,8 +1032,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
w2
,
w13_scale
,
w2_scale
,
layer
.
w13_input_scale
,
layer
.
w2_input_scale
,
w13_input_scale
=
layer
.
w13_input_scale
,
w2_input_scale
=
layer
.
w2_input_scale
,
)
# Prevent duplicate processing (e.g., during weight reload)
...
...
vllm/model_executor/layers/quantization/mxfp8.py
View file @
648edcf7
...
...
@@ -337,6 +337,8 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
w13
,
w13_scale
=
self
.
_quantize_mxfp8_moe_weight
(
layer
.
w13_weight
)
w2
,
w2_scale
=
self
.
_quantize_mxfp8_moe_weight
(
layer
.
w2_weight
)
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
648edcf7
...
...
@@ -9,6 +9,7 @@ import vllm.envs as envs
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config.load
import
LoadConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.reload
import
finalize_layerwise_processing
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
...
...
@@ -49,16 +50,13 @@ class BaseModelLoader(ABC):
device_config
.
device
if
load_config
.
device
is
None
else
load_config
.
device
)
target_device
=
torch
.
device
(
load_device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
with
set_default_torch_dtype
(
model_config
.
dtype
),
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
log_model_inspection
(
model
)
logger
.
debug
(
"Loading weights on %s ..."
,
load_device
)
# Quantization does not happen in `load_weights` but after it
self
.
load_weights
(
model
,
model_config
)
# Log peak GPU memory after loading weights. This is needed
...
...
@@ -71,6 +69,11 @@ class BaseModelLoader(ABC):
scope
=
"local"
,
)
# Process weights into kernel format. Note that when using online
# quantization, weights are (typically) quantized as they are loaded.
if
_has_online_quant
(
model
):
finalize_layerwise_processing
(
model
,
model_config
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
...
...
@@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None:
from
vllm.model_inspection
import
format_model_inspection
logger
.
info
(
"vLLM model structure:
\n
%s"
,
format_model_inspection
(
model
))
def
_has_online_quant
(
model
:
nn
.
Module
):
for
module
in
model
.
modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
getattr
(
quant_method
,
"uses_meta_device"
,
False
):
return
True
return
False
vllm/model_executor/model_loader/dummy_loader.py
View file @
648edcf7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ModelConfig
from
vllm.config.load
import
LoadConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.reload.meta
import
materialize_meta_tensor
from
vllm.model_executor.model_loader.reload.utils
import
get_layer_tensors
from
vllm.model_executor.model_loader.weight_utils
import
initialize_dummy_weights
...
...
@@ -23,6 +26,12 @@ class DummyModelLoader(BaseModelLoader):
pass
# Nothing to download
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
# materialize meta tensors as part of online quantization lifecycle
for
layer
in
model
.
modules
():
for
name
,
param
in
get_layer_tensors
(
layer
).
items
():
if
param
.
device
==
torch
.
device
(
"meta"
):
setattr
(
layer
,
name
,
materialize_meta_tensor
(
param
))
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
,
model_config
)
vllm/model_executor/model_loader/reload/__init__.py
View file @
648edcf7
...
...
@@ -21,12 +21,14 @@ Limitations:
__all__
=
[
"record_metadata_for_reloading"
,
"initialize_layerwise_reload"
,
"finalize_layerwise_processing"
,
"finalize_layerwise_reload"
,
"set_torchao_reload_attrs"
,
"support_quantized_model_reload_from_hp_weights"
,
]
from
.layerwise
import
(
finalize_layerwise_processing
,
finalize_layerwise_reload
,
initialize_layerwise_reload
,
record_metadata_for_reloading
,
...
...
vllm/model_executor/model_loader/reload/layerwise.py
View file @
648edcf7
...
...
@@ -28,6 +28,7 @@ __all__ = [
"get_layerwise_info"
,
"record_metadata_for_reloading"
,
"initialize_layerwise_reload"
,
"finalize_layerwise_processing"
,
"finalize_layerwise_reload"
,
]
...
...
@@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module):
info
=
get_layerwise_info
(
layer
)
# Skip if the layer has already been initialized
if
info
.
can_
process
():
if
info
.
can_
load
():
continue
# Save current tensors for later copying
...
...
@@ -98,15 +99,21 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta
(
layer
,
info
)
# Track loading progress to determine when to process/copy
info
.
load_numel
=
0
info
.
load_numel_total
=
get_layer_size
(
layer
)
initialize_online_processing
(
layer
)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
if
_get_weight_loader
(
tensor
).
__name__
!=
"online_process_loader"
:
tensor
.
weight_loader
=
make_online_process_loader
(
layer
,
name
)
def
initialize_online_processing
(
layer
:
torch
.
nn
.
Module
):
info
=
get_layerwise_info
(
layer
)
# Track loading progress to determine when to process/copy
info
.
load_numel
=
0
info
.
load_numel_total
=
get_layer_size
(
layer
)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
if
_get_weight_loader
(
tensor
).
__name__
!=
"online_process_loader"
:
tensor
.
weight_loader
=
make_online_process_loader
(
layer
,
name
)
def
make_online_process_loader
(
layer
:
torch
.
nn
.
Module
,
param_name
:
str
)
->
Callable
:
...
...
@@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
@
wraps
(
original_loader
,
assigned
=
(
"__doc__"
,
"__annotations__"
))
def
online_process_loader
(
*
args
,
**
kwargs
):
if
not
info
.
can_
process
():
if
not
info
.
can_
load
():
# Unfortunately, some qconfigs are set up to load the same weight
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# each of the qkv partitions. This results in layers loading extra
...
...
@@ -140,7 +147,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
bound_args
=
loader_signature
.
bind
(
*
args
,
**
kwargs
)
bound_args
.
apply_defaults
()
#
Cache
loaded weights, track loading progress
#
Buffer
loaded weights, track loading progress
info
.
loaded_weights
.
append
((
param_name
,
bound_args
))
num_loaded
,
ret
=
get_numel_loaded
(
original_loader
,
bound_args
)
info
.
load_numel
+=
num_loaded
...
...
@@ -163,19 +170,26 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
return
online_process_loader
def
finalize_layerwise_
reload
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
):
def
finalize_layerwise_
processing
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
):
"""
Remove the outermost layer of weight loading wrappers.
Apply processing to any layers which were not layerwise processed during loading.
This includes attention layers and layers which have weight elements which are not
loaded (due to padding).
This function should be applied after `initialize_layerwise_reload` is applied
unwrap the layerwise weight loaders.
Also processes Attention/MLA layers, which must be processed after all other layers
:param model: model to finalize processing for
:param model_config: config needed for applying processing to attention layers
"""
model
.
_do_torchao_reload
=
model
.
_original_do_torchao_reload
if
hasattr
(
model
,
"_original_do_torchao_reload"
):
model
.
_do_torchao_reload
=
model
.
_original_do_torchao_reload
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
if
not
info
.
can_load
():
info
.
reset
()
continue
# Attention/MLA layers are processed after all other layers
if
isinstance
(
layer
,
(
Attention
,
MLAAttention
)):
...
...
@@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
elif
info
.
kernel_tensors
is
None
:
raise
NotImplementedError
(
"Layerwise loading of Q/K/V scale weights is not implemented yet"
)
else
:
_place_kernel_tensors
(
layer
,
info
)
layer
.
process_weights_after_loading
(
model_config
.
dtype
)
# No weights were loaded, place kernel tensors back
elif
info
.
can_process
()
and
info
.
load_numel
<=
0
:
_place_kernel_tensors
(
layer
,
info
)
# No weights were loaded
elif
info
.
load_numel
<=
0
:
# first load but received no weights. This happens on dummy load
if
info
.
kernel_tensors
is
None
:
materialize_layer
(
layer
)
# reloading: place kernel tensors back as a fallback
else
:
logger
.
warning
(
"%s: Failed to load weights"
,
layer
.
__class__
.
__name__
)
_place_kernel_tensors
(
layer
,
info
)
# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to ex
e
cess memory usage
# Having too many of these delayed layers can lead to excess memory usage
# see Limitations(4)
elif
info
.
load_numel
>
0
and
info
.
load_numel
<
info
.
load_numel_total
:
# type: ignore[operator]
logger
.
debug
(
"%s: Delayed processing"
,
layer
.
__class__
.
__name__
)
...
...
@@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
info
.
reset
()
def
finalize_layerwise_reload
(
*
args
,
**
kwargs
):
finalize_layerwise_processing
(
*
args
,
**
kwargs
)
def
_layerwise_process
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
"""
Finalize layer loading after all weights have been
cach
ed.
Finalize layer loading after all weights have been
buffer
ed.
This function:
1. Materializes the layer onto the target device
2. Loads all
cach
ed weights
2. Loads all
buffer
ed weights
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer
(
layer
)
# Reset
FP8
online quantization flag so process_weights_after_loading
# Reset online quantization flag so process_weights_after_loading
# will run again during reload
if
hasattr
(
layer
,
"_already_called_process_weights_after_loading"
):
delattr
(
layer
,
"_already_called_process_weights_after_loading"
)
...
...
@@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
for
param
in
get_layer_tensors
(
layer
).
values
():
param
.
weight_loader
=
_get_original_loader
(
param
)
# Load all
cach
ed weights into materialized layer (using original loaders)
# Load all
buffer
ed weights into materialized layer (using original loaders)
for
name
,
args
in
info
.
loaded_weights
:
param
=
getattr
(
layer
,
name
)
args
.
arguments
[
"param"
]
=
param
...
...
@@ -239,13 +269,14 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
for
name
,
buffer
in
buffers
.
items
():
buffer
.
data
.
copy_
(
getattr
(
layer
,
name
))
if
info
.
kernel_tensors
is
not
None
:
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
for
name
,
buffer
in
buffers
.
items
():
buffer
.
data
.
copy_
(
getattr
(
layer
,
name
))
_place_kernel_tensors
(
layer
,
info
)
_place_kernel_tensors
(
layer
,
info
)
info
.
reset
()
logger
.
debug
(
"%s: Processed"
,
layer
.
__class__
.
__name__
)
...
...
@@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for
name
in
get_layer_tensors
(
layer
):
delattr
(
layer
,
name
)
assert
info
.
kernel_tensors
is
not
None
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
layer
.
register_parameter
(
name
,
param
)
...
...
vllm/model_executor/model_loader/reload/meta.py
View file @
648edcf7
...
...
@@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None:
setattr
(
layer
,
name
,
materialize_meta_tensor
(
tensor
))
class
Meta
CopyCounter
(
TorchDispatchMode
):
class
CopyCounter
(
TorchDispatchMode
):
"""
Tracks total number of elements modified with `copy_`.
...
...
@@ -122,7 +122,7 @@ class MetaCopyCounter(TorchDispatchMode):
if
kwargs
is
None
:
kwargs
=
{}
if
func
is
torch
.
ops
.
aten
.
copy_
.
default
and
args
[
0
].
device
.
type
==
"meta"
:
if
func
is
torch
.
ops
.
aten
.
copy_
.
default
:
assert
args
[
0
].
numel
()
==
args
[
1
].
numel
()
self
.
copied_numel
+=
args
[
0
].
numel
()
...
...
@@ -140,7 +140,6 @@ def get_numel_loaded(
:return: number of elements loaded by the weight loader, the return value of the
weight loader
"""
assert
args
.
arguments
[
"param"
].
device
.
type
==
"meta"
with
MetaCopyCounter
()
as
counter
:
with
CopyCounter
()
as
counter
:
return_value
=
weight_loader
(
*
args
.
args
,
**
args
.
kwargs
)
return
counter
.
copied_numel
,
return_value
vllm/model_executor/model_loader/reload/types.py
View file @
648edcf7
...
...
@@ -16,8 +16,8 @@ class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
# kernel format (device)
kernel_tensors
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
# kernel format (device)
, used to copy into when reloading only
kernel_tensors
:
LayerTensors
|
None
=
None
# track how many restored elements are ready for loading
load_numel
:
int
=
0
...
...
@@ -29,5 +29,5 @@ class LayerReloadingInfo:
def
reset
(
self
):
self
.
__init__
(
restore_metadata
=
self
.
restore_metadata
)
# type: ignore[misc]
def
can_
process
(
self
)
->
bool
:
def
can_
load
(
self
)
->
bool
:
return
self
.
load_numel_total
is
not
None
vllm/model_executor/model_loader/weight_utils.py
View file @
648edcf7
...
...
@@ -1323,25 +1323,11 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type.
"""
# Check if any module uses online quantization with meta device weights.
# If so, we'll skip initializing params on meta device since they'll be
# handled in `process_weights_after_loading`.
def
uses_meta_device
(
module
:
torch
.
nn
.
Module
)
->
bool
:
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
return
getattr
(
quant_method
,
"uses_meta_device"
,
False
)
has_online_quant
=
any
(
uses_meta_device
(
m
)
for
m
in
model
.
modules
())
for
param
in
model
.
state_dict
().
values
():
if
has_online_quant
and
param
.
device
==
torch
.
device
(
"meta"
):
# For online quantization, weights are created on meta device and
# dummy weight init will happen in `process_weights_after_loading`.
continue
initialize_single_dummy_weight
(
param
,
low
,
high
,
seed
)
@
torch
.
no_grad
()
def
initialize_single_dummy_weight
(
param
:
torch
.
Tensor
,
low
:
float
=
-
1e-3
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment