Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f857a03f
Unverified
Commit
f857a03f
authored
Jan 30, 2026
by
Kyle Sayers
Committed by
GitHub
Jan 30, 2026
Browse files
[QeRL] Layerwise Reloading (#32133)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
74898a70
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
923 additions
and
314 deletions
+923
-314
tests/conftest.py
tests/conftest.py
+22
-4
tests/model_executor/model_loader/test_reload.py
tests/model_executor/model_loader/test_reload.py
+150
-0
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+15
-4
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+1
-1
vllm/model_executor/model_loader/online_quantization.py
vllm/model_executor/model_loader/online_quantization.py
+0
-275
vllm/model_executor/model_loader/reload/__init__.py
vllm/model_executor/model_loader/reload/__init__.py
+37
-0
vllm/model_executor/model_loader/reload/layerwise.py
vllm/model_executor/model_loader/reload/layerwise.py
+270
-0
vllm/model_executor/model_loader/reload/meta.py
vllm/model_executor/model_loader/reload/meta.py
+146
-0
vllm/model_executor/model_loader/reload/sanitize.py
vllm/model_executor/model_loader/reload/sanitize.py
+50
-0
vllm/model_executor/model_loader/reload/torchao_decorator.py
vllm/model_executor/model_loader/reload/torchao_decorator.py
+58
-0
vllm/model_executor/model_loader/reload/types.py
vllm/model_executor/model_loader/reload/types.py
+33
-0
vllm/model_executor/model_loader/reload/utils.py
vllm/model_executor/model_loader/reload/utils.py
+31
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+16
-17
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+1
-1
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+1
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+90
-8
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-2
No files found.
tests/conftest.py
View file @
f857a03f
...
...
@@ -27,7 +27,7 @@ import threading
from
collections.abc
import
Generator
from
contextlib
import
nullcontext
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
TypedDict
,
TypeVar
,
cast
,
TYPE_CHECKING
from
typing
import
Any
,
Callable
,
TypedDict
,
TypeVar
,
cast
,
TYPE_CHECKING
,
Optional
import
numpy
as
np
import
pytest
...
...
@@ -1024,7 +1024,9 @@ class VllmRunner:
**
kwargs
,
)
def
generate_prompt_perplexity
(
self
,
prompts
:
list
[
str
])
->
list
[
float
]:
def
generate_prompt_perplexity
(
self
,
prompts
:
list
[
str
],
mask
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
float
]:
"""
Return the perplexity score associated with generating the prompts
...
...
@@ -1035,13 +1037,20 @@ class VllmRunner:
prompts
,
max_tokens
=
1
,
num_logprobs
=
None
,
num_prompt_logprobs
=
0
)
mask_prefix_lens
=
(
[
len
(
self
.
llm
.
get_tokenizer
()(
prefix
)[
"input_ids"
])
for
prefix
in
mask
]
if
mask
is
not
None
else
[
0
for
_
in
range
(
len
(
prompts
))]
)
perplexities
=
[]
for
output
in
outputs
:
for
output
,
mask_prefix_len
in
zip
(
outputs
,
mask_prefix_lens
)
:
output
=
cast
(
TokensTextLogprobsPromptLogprobs
,
output
)
token_datas
=
cast
(
list
[
dict
[
int
,
Logprob
]
|
None
],
output
[
3
])
assert
token_datas
[
0
]
is
None
token_log_probs
=
[]
for
token_data
in
token_datas
[
1
:]:
for
token_data
in
token_datas
[
mask_prefix_len
+
1
:]:
assert
token_data
is
not
None
assert
len
(
token_data
)
==
1
token_log_prob
=
list
(
token_data
.
values
())[
0
].
logprob
...
...
@@ -1122,6 +1131,9 @@ class VllmRunner:
def
get_llm
(
self
)
->
LLM
:
return
self
.
llm
def
collective_rpc
(
self
,
*
args
,
**
kwargs
):
return
self
.
llm
.
collective_rpc
(
*
args
,
**
kwargs
)
def
__enter__
(
self
):
return
self
...
...
@@ -1532,3 +1544,9 @@ def use_fresh_inductor_cache():
"""
with
fresh_cache
():
yield
@
pytest
.
fixture
(
scope
=
"function"
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
tests/model_executor/model_loader/test_reload.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
inspect
from
weakref
import
WeakKeyDictionary
,
ref
import
pytest
import
torch
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
from
vllm.model_executor.model_loader.reload.meta
import
(
capture_layer_to_meta
,
get_numel_loaded
,
materialize_layer
,
materialize_meta_tensor
,
restore_layer_on_meta
,
to_meta_tensor
,
)
from
vllm.model_executor.model_loader.reload.types
import
LayerReloadingInfo
from
vllm.model_executor.model_loader.reload.utils
import
get_layer_tensors
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
def
test_move_metatensors
():
tensor
=
torch
.
empty
((
1
,
2
,
3
))
meta_tensor
=
to_meta_tensor
(
tensor
)
materialized_tensor
=
materialize_meta_tensor
(
meta_tensor
)
assert
meta_tensor
.
device
.
type
==
"meta"
assert
tensor
.
device
==
materialized_tensor
.
device
assert
tensor
.
dtype
==
meta_tensor
.
dtype
==
materialized_tensor
.
dtype
assert
tensor
.
shape
==
meta_tensor
.
shape
==
materialized_tensor
.
shape
assert
tensor
.
__class__
==
meta_tensor
.
__class__
==
materialized_tensor
.
__class__
assert
tensor
.
__dict__
==
meta_tensor
.
__dict__
==
materialized_tensor
.
__dict__
def
test_reload_lifecycle
():
layer
=
torch
.
nn
.
Linear
(
2
,
3
)
info
=
LayerReloadingInfo
(
restore_metadata
=
capture_layer_to_meta
(
layer
))
restore_layer_on_meta
(
layer
,
info
)
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
meta_tensor
=
getattr
(
layer
,
name
)
assert
tensor
.
dtype
==
meta_tensor
.
dtype
assert
tensor
.
shape
==
meta_tensor
.
shape
assert
tensor
.
__class__
==
meta_tensor
.
__class__
assert
tensor
.
__dict__
==
meta_tensor
.
__dict__
materialize_layer
(
layer
)
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
materialized_tensor
=
getattr
(
layer
,
name
)
assert
tensor
.
dtype
==
materialized_tensor
.
dtype
assert
tensor
.
shape
==
materialized_tensor
.
shape
assert
tensor
.
__class__
==
materialized_tensor
.
__class__
assert
tensor
.
__dict__
==
materialized_tensor
.
__dict__
def
test_model_cleanup
(
dist_init
,
default_vllm_config
):
layer
=
QKVParallelLinear
(
2
,
3
,
4
)
assert
layer
.
weight
.
weight_loader
.
__self__
is
layer
info
=
LayerReloadingInfo
(
restore_metadata
=
capture_layer_to_meta
(
layer
))
mock_info_dict
:
WeakKeyDictionary
[
torch
.
nn
.
Module
,
LayerReloadingInfo
]
=
(
WeakKeyDictionary
()
)
mock_info_dict
[
layer
]
=
info
layer_ref
=
ref
(
layer
)
del
layer
gc
.
collect
()
assert
layer_ref
()
is
None
assert
len
(
mock_info_dict
)
==
0
def
test_get_numel_loaded
():
param
=
torch
.
empty
(
10
,
device
=
"meta"
)
loaded_weight
=
torch
.
empty
(
10
)
def
complex_weight_loader
(
param
,
loaded_weight
):
param
[:
3
]
=
loaded_weight
[:
3
]
param
[
5
:
8
]
=
loaded_weight
[
5
:
8
]
return
"value"
args
=
inspect
.
signature
(
complex_weight_loader
).
bind
(
param
,
loaded_weight
)
num_loaded
,
ret
=
get_numel_loaded
(
complex_weight_loader
,
args
)
assert
num_loaded
==
6
assert
ret
==
"value"
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"base_model,mul_model,add_model"
,
[
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
),
(
"inference-optimization/Qwen3-0.6B-FP8_BLOCK"
,
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK"
,
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK"
,
),
(
"inference-optimization/Qwen3-0.6B-W4A16-G128"
,
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128"
,
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128"
,
),
(
"inference-optimization/DeepSeek-V3-debug-empty"
,
"inference-optimization/DeepSeek-V3-debug-multiply"
,
"inference-optimization/DeepSeek-V3-debug-add"
,
),
(
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC"
,
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC"
,
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC"
,
),
(
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16"
,
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16"
,
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16"
,
),
],
)
def
test_reload_weights
(
base_model
,
mul_model
,
add_model
,
tp_size
,
vllm_runner
):
if
cuda_device_count_stateless
()
<
tp_size
:
pytest
.
skip
(
reason
=
"Not enough CUDA devices"
)
if
"FP8"
in
base_model
and
not
current_platform
.
supports_fp8
():
pytest
.
skip
(
reason
=
"Requires FP8 support"
)
with
vllm_runner
(
model_name
=
base_model
,
tensor_parallel_size
=
tp_size
,
enable_expert_parallel
=
(
tp_size
>
1
and
"DeepSeek"
in
base_model
),
enable_prefix_caching
=
False
,
)
as
llm
:
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
mul_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
mul_perp
<
add_perp
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
add_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
add_perp
<
mul_perp
tests/quantization/test_torchao.py
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib.metadata
import
importlib.util
import
pytest
import
torch
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
DTYPE
=
[
"bfloat16"
]
...
...
@@ -105,8 +105,8 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
def
test_online_quant_config_dict_json
(
vllm_runner
):
"""Testing on
the fly
quantization, load_weights integration point,
def
test_online_quant_config_dict_json
(
vllm_runner
,
enable_pickle
):
"""Testing on
line
quantization, load_weights integration point,
with config dict serialized to json string
"""
torch
.
_dynamo
.
reset
()
...
...
@@ -135,7 +135,18 @@ def test_online_quant_config_dict_json(vllm_runner):
)
as
llm
:
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
4
)
assert
output
load_config
=
llm
.
llm
.
llm_engine
.
vllm_config
.
load_config
model_config
=
llm
.
llm
.
llm_engine
.
vllm_config
.
model_config
def
load_weights
(
model
):
model_loader
=
get_model_loader
(
load_config
)
weights_iterator
=
model_loader
.
get_all_weights
(
model_config
,
model
)
model
.
load_weights
(
weights_iterator
)
llm
.
apply_model
(
load_weights
)
reload_output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
4
)
assert
output
[
0
][
0
]
==
reload_output
[
0
][
0
]
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
f857a03f
...
...
@@ -543,7 +543,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
def
test_reload_weights_before_load_model
(
model_runner
):
with
pytest
.
raises
(
Assertion
Error
):
with
pytest
.
raises
(
Value
Error
):
model_runner
.
reload_weights
()
...
...
vllm/model_executor/model_loader/online_quantization.py
deleted
100644 → 0
View file @
74898a70
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
types
from
collections.abc
import
Iterable
import
torch
from
torch
import
nn
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.utils
import
process_weights_after_loading
logger
=
init_logger
(
__name__
)
# Notes for Online Quantization
# In terms of state of checkpoints, quantization config and their
# correspondance to online quantization:
# | Use Case | Checkpoints | model_config.quantization |
# | no quant | high precision | None |
# | offline quant | quantized | fp8, torchao etc. |
# | online quant | high precision | torchao etc. |
#
# The process for loading non-quantized checkpoint
# 1. load non-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
#
# The process for loading offline quantized checkpoint
# 1. load offline-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
# The process for unquantized model reloading
# (repeated run in RL training loop)
# first run
# UI1. load_weights: load bfloat16 weights
# UI2. process_weights_after_loading: any additional post processing
# subsequent run
# UC1: load_weights: load bfloat16 weights
# (shouldn't be any issues since we didn't change any attributes
# of the weights)
# UC2: process_weights_after_loading: any additional post processing
# The process for weight reloading with online quantization
# (repeated run in RL training loop)
# first run
# I1. load_weights: load bfloat16 weights
# I2. process_weights_after_loading:
# record weight metadata and attributes for R1 and R2
# quantize weights to fp8
# subsequent run
# (beginning model weight is in fp8)
# load_weights:
# R1. restore bfloat16 model weight metadata
# R2. restore the model weight attributes
# R3. reload bfloat16 weights
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# R5. (workaround for cudagraph), we restore the weight params to original quantized
# weights params, and use original_weight_param.copy_(updated_weight_param) so that
# the weight update work well with cudagraph
# process_weights_after_loading (if called):
# this will be skipped since it's already ran in
# load_weights
def
maybe_save_metadata_and_attributes_for_weight_reloading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
):
# following is to support on the fly quantization, currently only supported
# for torchao
if
model_config
.
quantization
!=
"torchao"
:
return
from
vllm.model_executor.model_loader.weight_utils
import
get_quant_config
quant_config
=
get_quant_config
(
model_config
,
None
)
# If checkpoint is already torchao serialized, this means it's
# pre-quantized quantization case, we'll skip saving the metadata
# Otherwise, this is Step I2 of initialization steps of
# online quantization
# This step record the weights metadata and weight attributes so we can
# restore the bfloat16 model weights during the relad step (R1 and R2)
# see Notes in online_quantization.py for more details
if
not
(
hasattr
(
quant_config
,
"is_checkpoint_torchao_serialized"
)
and
not
quant_config
.
is_checkpoint_torchao_serialized
):
return
# This is the I2 step of online quantiztion that saves
# metadata and attributes of weights so they can be used in R1 and
# R2 step, note that we only save these during initialization
# Includes two things
# 1. save floating point metadata (shape, dtype, device) for init
# 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
if
getattr
(
model
,
"weight_metadata_and_attr_saved"
,
False
):
return
# save the dtype, shape and device for model parameter, used for
# restoring the model high precision parameters before
# reloading the weights
assert
not
hasattr
(
model
,
"original_weights_rebuild_keys"
)
model
.
original_weights_rebuild_keys
=
{}
for
name
,
p
in
model
.
named_parameters
():
model
.
original_weights_rebuild_keys
[
name
]
=
{
"shape"
:
p
.
shape
,
"dtype"
:
p
.
dtype
,
"device"
:
p
.
device
,
}
# record the weight attributes (loader functions etc.)
# so these can be recovered later when we reload the weights
# structure: {"weight_name": {"weight_attr_key": attr}}
assert
not
hasattr
(
model
,
"recorded_weight_attr"
)
model
.
recorded_weight_attr
=
{}
for
name
,
param
in
model
.
named_parameters
():
model
.
recorded_weight_attr
[
name
]
=
{}
for
key
in
param
.
__dict__
:
if
hasattr
(
param
,
key
):
attr
=
getattr
(
param
,
key
)
if
not
callable
(
attr
):
model
.
recorded_weight_attr
[
name
][
key
]
=
attr
elif
hasattr
(
attr
,
"__self__"
)
and
param
is
attr
.
__self__
:
# if attr is a bonded method for an instance, and
# attr.__self__ points to the instance (param)
# we'll record the underlying function object
model
.
recorded_weight_attr
[
name
][
key
]
=
attr
.
__func__
else
:
model
.
recorded_weight_attr
[
name
][
key
]
=
attr
# mark the metadata and attributes saved so we don't run it again
model
.
_model_config
=
model_config
model
.
weight_metadata_and_attr_saved
=
True
def
_bond_method_to_cls
(
func
,
obj
):
if
hasattr
(
func
,
"__self__"
)
or
not
callable
(
func
):
# If the function is already bound to an instance, return it as is
return
func
else
:
return
types
.
MethodType
(
func
,
obj
)
def
support_quantized_model_reload_from_hp_weights
(
original_load_weights
):
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights
"""
# online quantization, right now only enabled for
# torchao
# R1, R2, R3, R4, R5 in the Notes
def
patched_model_load_weights
(
auto_weight_loader
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
*
,
mapper
=
None
)
->
set
[
str
]:
model
=
auto_weight_loader
.
module
offline_quantization_or_first_run_of_online_quantization
=
not
getattr
(
model
,
"weight_metadata_and_attr_saved"
,
False
)
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see Notes in this file for more details
if
offline_quantization_or_first_run_of_online_quantization
:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
return
original_load_weights
(
auto_weight_loader
,
weights
,
mapper
=
mapper
)
model_config
=
model
.
_model_config
# TODO: Add fp8 support
assert
model_config
.
quantization
==
"torchao"
,
(
"online quantization is only enabled for torchao currently"
)
# TODO: use create_weights to restore the weights to original state
# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
# TODO: maybe set remove_duplicate to True?
original_quantized_weight_dict
=
dict
(
model
.
named_parameters
(
remove_duplicate
=
False
)
)
named_modules
=
dict
(
model
.
named_modules
(
remove_duplicate
=
False
))
model_device
=
None
for
name
,
d
in
model
.
original_weights_rebuild_keys
.
items
():
_shape
=
d
[
"shape"
]
_dtype
=
d
[
"dtype"
]
_device
=
d
[
"device"
]
if
model_device
is
not
None
:
assert
model_device
==
_device
,
(
"Expecting all weights "
"to be in the same device for now, got both: "
f
"
{
model_device
}
and
{
_device
}
"
)
else
:
model_device
=
_device
if
name
in
original_quantized_weight_dict
:
module_name
,
weight_name
=
name
.
rsplit
(
"."
,
1
)
module
=
named_modules
[
module_name
]
setattr
(
module
,
weight_name
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
_shape
,
dtype
=
_dtype
,
device
=
_device
),
requires_grad
=
False
,
),
)
# Step R2: recover the weight attributes to the state before first loading
# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# {
# {
# "layer.0.weight": {
# "weight_loader": weight_loader_function_object,
# "input_dim": 0, ...
# },
# "layer.1.weight": ...,
# }
# }
for
full_weight_name
,
weight_attr_dict
in
model
.
recorded_weight_attr
.
items
():
for
attr_name
,
attr
in
weight_attr_dict
.
items
():
module_name
,
weight_name
=
full_weight_name
.
rsplit
(
"."
,
1
)
module
=
named_modules
[
module_name
]
weight
=
getattr
(
module
,
weight_name
)
if
not
hasattr
(
weight
,
attr_name
):
setattr
(
weight
,
attr_name
,
_bond_method_to_cls
(
attr
,
weight
))
# Step R3: reload bfloat16 / high precision weights
updated_params
=
original_load_weights
(
auto_weight_loader
,
weights
,
mapper
=
mapper
)
# Step R4: online quantize the weights
# manually process weights after loading
model
.
process_weights_after_loading_already_called
=
False
if
model_device
is
not
None
:
process_weights_after_loading
(
model
,
model_config
,
model_device
)
else
:
logger
.
warning_once
(
"model_device is None, skip calling process_weights_after_loading"
)
# Step R5 (workaround for cudagraph): restore the original quantized weights
# and do a copy_ of the currents weights to the original weights
updated_quantized_weights
=
dict
(
model
.
named_parameters
(
remove_duplicate
=
False
))
for
name
in
model
.
original_weights_rebuild_keys
:
if
name
in
original_quantized_weight_dict
:
original_quantized_weight
=
original_quantized_weight_dict
[
name
]
updated_quantized_weight
=
updated_quantized_weights
[
name
]
module_name
,
weight_name
=
name
.
rsplit
(
"."
,
1
)
module
=
named_modules
[
module_name
]
setattr
(
module
,
weight_name
,
original_quantized_weight
)
with
torch
.
no_grad
():
original_quantized_weight
.
copy_
(
updated_quantized_weight
)
del
original_quantized_weight_dict
del
named_modules
del
updated_quantized_weight
model
.
process_weights_after_loading_already_called
=
True
return
updated_params
return
patched_model_load_weights
vllm/model_executor/model_loader/reload/__init__.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layerwise weight reloading utilities for vLLM.
This module provides functionality to reload model weights layer-by-layer,
which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
3. Tied parameters will only reflect processing from one of the parent layers (for
example, only processing from embed_tokens will have an effect)
4. This design assumes that the number of weights loaded from disk is the same as the
number of weights created at model init time. This is not true for quant methods
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
cases are non-issues for today's quant methods, but future quantizations may cause
reloading to fail
"""
__all__
=
[
"record_metadata_for_reloading"
,
"initialize_layerwise_reload"
,
"finalize_layerwise_reload"
,
"set_torchao_reload_attrs"
,
"support_quantized_model_reload_from_hp_weights"
,
]
from
.layerwise
import
(
finalize_layerwise_reload
,
initialize_layerwise_reload
,
record_metadata_for_reloading
,
)
from
.torchao_decorator
import
(
set_torchao_reload_attrs
,
support_quantized_model_reload_from_hp_weights
,
)
vllm/model_executor/model_loader/reload/layerwise.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
inspect
from
collections.abc
import
Callable
from
functools
import
wraps
from
weakref
import
WeakKeyDictionary
import
torch
from
vllm.attention.layer
import
Attention
,
MLAAttention
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
QuantizeMethodBase
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.meta
import
(
capture_layer_to_meta
,
get_numel_loaded
,
materialize_layer
,
restore_layer_on_meta
,
)
from
.types
import
LayerReloadingInfo
from
.utils
import
get_layer_params_buffers
,
get_layer_size
,
get_layer_tensors
logger
=
init_logger
(
__name__
)
__all__
=
[
"get_layerwise_info"
,
"record_metadata_for_reloading"
,
"initialize_layerwise_reload"
,
"finalize_layerwise_reload"
,
]
# Global dict storing information used for layerwise restoring, loading, and processing.
# For more information regarding what info is stored when, see `LayerReloadingInfo`
#
# Use a weak ref dictionary so that modules can be freed when the model is freed.
# Values are sanitized from references to the layer key in order to avoid circular refs
LAYERWISE_INFO
:
WeakKeyDictionary
[
torch
.
nn
.
Module
,
LayerReloadingInfo
]
=
(
WeakKeyDictionary
()
)
def
get_layerwise_info
(
layer
:
torch
.
nn
.
Module
)
->
LayerReloadingInfo
:
"""
Get information related to restoring and layerwise processing. If no previous
information existed, a new entry is constructed
"""
if
layer
not
in
LAYERWISE_INFO
:
LAYERWISE_INFO
[
layer
]
=
LayerReloadingInfo
()
return
LAYERWISE_INFO
[
layer
]
def
record_metadata_for_reloading
(
model
:
torch
.
nn
.
Module
):
"""
Record layer metadata needed for later reloading.
Stores parameter and buffer metadata as meta tensors for restoration.
Must be called before `initialize_layerwise_reload`.
"""
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
info
.
restore_metadata
=
capture_layer_to_meta
(
layer
)
@
torch
.
no_grad
()
def
initialize_layerwise_reload
(
model
:
torch
.
nn
.
Module
):
"""
Set up layerwise weight loading with deferred processing.
Must be called after `record_metadata_for_reloading`. This function:
1. Saves current kernel tensors for later copying
2. Restores layer parameters/buffers from metadata (on meta device)
3. Wraps weight loaders to defer processing until all weights are loaded
When all weights for a layer are loaded, the wrapped loaders will:
1. Materialize the layer onto the target device
2. Load all cached weights
3. Run quantization processing if applicable
4. Copy processed values back to original tensor storage
"""
# disable torchao reloading to avoid infinite recursion
model
.
_original_do_torchao_reload
=
getattr
(
model
,
"_do_torchao_reload"
,
False
)
model
.
_do_torchao_reload
=
False
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
# Skip if the layer has already been initialized
if
info
.
can_process
():
continue
# Save current tensors for later copying
info
.
kernel_tensors
=
get_layer_params_buffers
(
layer
)
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta
(
layer
,
info
)
# Track loading progress to determine when to process/copy
info
.
load_numel
=
0
info
.
load_numel_total
=
get_layer_size
(
layer
)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
if
_get_weight_loader
(
tensor
).
__name__
!=
"online_process_loader"
:
tensor
.
weight_loader
=
make_online_process_loader
(
layer
,
name
)
def
make_online_process_loader
(
layer
:
torch
.
nn
.
Module
,
param_name
:
str
)
->
Callable
:
"""Create a wrapped weight loader that defers processing."""
info
=
get_layerwise_info
(
layer
)
param
=
getattr
(
layer
,
param_name
)
original_loader
=
_get_original_loader
(
param
)
loader_signature
=
inspect
.
signature
(
original_loader
)
@
wraps
(
original_loader
,
assigned
=
(
"__doc__"
,
"__annotations__"
))
def
online_process_loader
(
*
args
,
**
kwargs
):
if
not
info
.
can_process
():
# Unfortunately, some qconfigs are set up to load the same weight
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# each of the qkv partitions. This results in layers loading extra
# weights (beyond load_numel_total) after it's already processed.
#
# Best solution is to ensure that `load_numel_total` reflects the
# actual number of weights loaded, either by modifying qconfigs to
# create as many weights as loaded (see padding issue as well)
# or maybe capturing how many weights are loaded on first pass
#
# For now, `load_numel_total` is still safe to use as long as
# there's no way to reach `load_numel_total` without loading all
# necessary weights. `weight_shape` is very small, so this is safe.
# see Limitations(4)
logger
.
debug
(
"%s: Excessive loading"
,
layer
.
__class__
.
__name__
)
return
# Bind and normalize arguments
bound_args
=
loader_signature
.
bind
(
*
args
,
**
kwargs
)
bound_args
.
apply_defaults
()
# Cache loaded weights, track loading progress
info
.
loaded_weights
.
append
((
param_name
,
bound_args
))
num_loaded
,
ret
=
get_numel_loaded
(
original_loader
,
bound_args
)
info
.
load_numel
+=
num_loaded
logger
.
debug
(
"%s: %d / %d"
,
layer
.
__class__
.
__name__
,
info
.
load_numel
,
info
.
load_numel_total
,
)
# Process and copy when all weights are loaded
if
info
.
load_numel
>=
info
.
load_numel_total
and
not
isinstance
(
# type: ignore[operator]
layer
,
(
Attention
,
MLAAttention
)
):
_layerwise_process
(
layer
,
info
)
return
ret
return
online_process_loader
def
finalize_layerwise_reload
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
):
"""
Remove the outermost layer of weight loading wrappers.
This function should be applied after `initialize_layerwise_reload` is applied
unwrap the layerwise weight loaders.
Also processes Attention/MLA layers, which must be processed after all other layers
"""
model
.
_do_torchao_reload
=
model
.
_original_do_torchao_reload
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
# Attention/MLA layers are processed after all other layers
if
isinstance
(
layer
,
(
Attention
,
MLAAttention
)):
if
info
.
load_numel
>
0
:
raise
NotImplementedError
(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
else
:
_place_kernel_tensors
(
layer
,
info
)
layer
.
process_weights_after_loading
(
model_config
.
dtype
)
# No weights were loaded, place kernel tensors back
elif
info
.
can_process
()
and
info
.
load_numel
<=
0
:
_place_kernel_tensors
(
layer
,
info
)
# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to execess memory usage
# see Limitations(4)
elif
info
.
load_numel
>
0
and
info
.
load_numel
<
info
.
load_numel_total
:
# type: ignore[operator]
logger
.
debug
(
"%s: Delayed processing"
,
layer
.
__class__
.
__name__
)
_layerwise_process
(
layer
,
info
)
info
.
reset
()
def
_layerwise_process
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
"""
Finalize layer loading after all weights have been cached.
This function:
1. Materializes the layer onto the target device
2. Loads all cached weights
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer
(
layer
)
# Unwrap layerwise loading wrappers
for
param
in
get_layer_tensors
(
layer
).
values
():
param
.
weight_loader
=
_get_original_loader
(
param
)
# Load all cached weights into materialized layer (using original loaders)
for
name
,
args
in
info
.
loaded_weights
:
param
=
getattr
(
layer
,
name
)
args
.
arguments
[
"param"
]
=
param
param
.
weight_loader
(
*
args
.
args
,
**
args
.
kwargs
)
# Process weights (quantization, repacking, etc.)
# Attention/MLA are processed in `finalize_layerwise_reload`
quant_method
=
getattr
(
layer
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
quant_method
.
process_weights_after_loading
(
layer
)
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
for
name
,
buffer
in
buffers
.
items
():
buffer
.
data
.
copy_
(
getattr
(
layer
,
name
))
_place_kernel_tensors
(
layer
,
info
)
info
.
reset
()
logger
.
debug
(
"%s: Processed"
,
layer
.
__class__
.
__name__
)
def
_get_original_loader
(
tensor
:
torch
.
Tensor
)
->
Callable
:
"""Return the weight loader with any layerwise wrappers removed"""
loader
=
_get_weight_loader
(
tensor
)
while
loader
.
__name__
==
"online_process_loader"
:
loader
=
loader
.
__wrapped__
# type: ignore[union-attr]
return
loader
def
_get_weight_loader
(
tensor
:
torch
.
Tensor
):
return
getattr
(
tensor
,
"weight_loader"
,
default_weight_loader
)
def
_place_kernel_tensors
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
for
name
in
get_layer_tensors
(
layer
):
delattr
(
layer
,
name
)
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
layer
.
register_parameter
(
name
,
param
)
for
name
,
buffer
in
buffers
.
items
():
layer
.
register_buffer
(
name
,
buffer
)
vllm/model_executor/model_loader/reload/meta.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
inspect
from
collections.abc
import
Callable
import
torch
from
torch.utils._python_dispatch
import
TorchDispatchMode
from
.sanitize
import
restore_layer_refs
,
sanitize_layer_refs
from
.types
import
LayerReloadingInfo
,
LayerTensors
from
.utils
import
get_layer_params_buffers
,
get_layer_tensors
__all__
=
[
"to_meta_tensor"
,
"materialize_meta_tensor"
,
"capture_layer_to_meta"
,
"restore_layer_on_meta"
,
"materialize_layer"
,
"get_numel_loaded"
,
]
SKIP_MODULES
:
set
[
str
]
=
{
"HadamardTransform"
}
SKIP_TENSORS
:
set
[
str
]
=
{
"_expert_map"
,
"expert_mask"
,
"expert_global_to_physical"
,
"expert_physical_to_global"
,
"expert_local_to_global"
,
}
def
to_meta_tensor
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Convert a tensor to a meta tensor while preserving class and attributes."""
meta_tensor
=
tensor
.
data
.
to
(
"meta"
)
meta_tensor
.
__class__
=
tensor
.
__class__
meta_tensor
.
__dict__
=
tensor
.
__dict__
.
copy
()
return
meta_tensor
def
materialize_meta_tensor
(
meta_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Materialize a meta tensor into an actual tensor on the current device.
Should be called within the torch device context for the given rank.
"""
tensor
=
torch
.
empty_strided
(
size
=
tuple
(
meta_tensor
.
size
()),
stride
=
tuple
(
meta_tensor
.
stride
()),
dtype
=
meta_tensor
.
dtype
,
requires_grad
=
False
,
)
tensor
.
__class__
=
meta_tensor
.
__class__
tensor
.
__dict__
=
meta_tensor
.
__dict__
.
copy
()
return
tensor
def
capture_layer_to_meta
(
layer
:
torch
.
nn
.
Module
)
->
LayerTensors
:
if
layer
.
__class__
.
__name__
in
SKIP_MODULES
:
return
({},
{})
params
,
buffers
=
get_layer_params_buffers
(
layer
)
return
(
{
name
:
sanitize_layer_refs
(
to_meta_tensor
(
param
),
layer
)
for
name
,
param
in
params
.
items
()
if
name
not
in
SKIP_TENSORS
},
{
name
:
sanitize_layer_refs
(
to_meta_tensor
(
buffer
),
layer
)
for
name
,
buffer
in
buffers
.
items
()
if
name
not
in
SKIP_TENSORS
},
)
def
restore_layer_on_meta
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
"""Restore a layer to model format with tensors on the meta device"""
if
layer
.
__class__
.
__name__
in
SKIP_MODULES
:
return
for
name
in
get_layer_tensors
(
layer
):
if
name
not
in
SKIP_TENSORS
:
delattr
(
layer
,
name
)
restore_params
,
restore_buffers
=
info
.
restore_metadata
for
name
,
param
in
restore_params
.
items
():
if
name
not
in
SKIP_TENSORS
:
param
=
restore_layer_refs
(
param
,
layer
)
layer
.
register_parameter
(
name
,
param
)
for
name
,
buffer
in
restore_buffers
.
items
():
if
name
not
in
SKIP_TENSORS
:
buffer
=
restore_layer_refs
(
buffer
,
layer
)
layer
.
register_buffer
(
name
,
buffer
)
def
materialize_layer
(
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Materialize all meta tensors in a layer to actual tensors."""
if
layer
.
__class__
.
__name__
in
SKIP_MODULES
:
return
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
if
name
not
in
SKIP_TENSORS
:
setattr
(
layer
,
name
,
materialize_meta_tensor
(
tensor
))
class
MetaCopyCounter
(
TorchDispatchMode
):
"""
Tracks total number of elements modified with `copy_`.
Useful for keeping track of weight loading where underlying weights can be
arbitrarily transformed (such as with `narrow`) before calling copy.
Note: Assumes that copy kwargs are not used.
"""
def
__init__
(
self
):
super
().
__init__
()
self
.
copied_numel
=
0
def
__torch_dispatch__
(
self
,
func
,
types
,
args
=
(),
kwargs
=
None
):
if
kwargs
is
None
:
kwargs
=
{}
if
func
is
torch
.
ops
.
aten
.
copy_
.
default
and
args
[
0
].
device
.
type
==
"meta"
:
assert
args
[
0
].
numel
()
==
args
[
1
].
numel
()
self
.
copied_numel
+=
args
[
0
].
numel
()
return
func
(
*
args
,
**
kwargs
)
def
get_numel_loaded
(
weight_loader
:
Callable
,
args
:
inspect
.
BoundArguments
)
->
tuple
[
int
,
object
]:
"""
Determine how many elements would be loaded by a weight loader call.
:param weight loader: used to load weights
:param args: bound arguments to weight loader
:return: number of elements loaded by the weight loader, the return value of the
weight loader
"""
assert
args
.
arguments
[
"param"
].
device
.
type
==
"meta"
with
MetaCopyCounter
()
as
counter
:
return_value
=
weight_loader
(
*
args
.
args
,
**
args
.
kwargs
)
return
counter
.
copied_numel
,
return_value
vllm/model_executor/model_loader/reload/sanitize.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
types
import
MethodType
import
torch
__all__
=
[
"sanitize_layer_refs"
,
"restore_layer_refs"
]
layer_ref_sentinel
=
object
()
def
sanitize_layer_refs
(
tensor
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
"""
Removes references to layer held by tensor attributes. Specifically, removes the
`__self__` attribute of weight loader methods attached to the tensor.
Used by `capture_layer_to_meta` to avoid circular references to layers in
`LAYERWISE_INFO`, leading to modules never being cleaned up. Without sanitation,
tensors will reference layers, and the WeakKeyDictionary will never evict entries,
even when the model is deleted.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
"""
for
key
,
value
in
tensor
.
__dict__
.
items
():
if
isinstance
(
value
,
MethodType
)
and
value
.
__self__
is
layer
:
tensor
.
__dict__
[
key
]
=
value
.
__func__
.
__get__
(
layer_ref_sentinel
)
return
tensor
def
restore_layer_refs
(
tensor
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
"""
Restores references to layer held by tensor attributes.
Used by `restore_layer_on_meta` to add back layer references, allowing for proper
weight loading.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
"""
for
key
,
value
in
tensor
.
__dict__
.
items
():
if
isinstance
(
value
,
MethodType
)
and
value
.
__self__
is
layer_ref_sentinel
:
tensor
.
__dict__
[
key
]
=
value
.
__func__
.
__get__
(
layer
)
return
tensor
vllm/model_executor/model_loader/reload/torchao_decorator.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
functools
import
wraps
from
types
import
FunctionType
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.config
import
ModelConfig
from
.layerwise
import
(
finalize_layerwise_reload
,
initialize_layerwise_reload
,
)
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
__all__
=
[
"set_torchao_reload_attrs"
,
"support_quantized_model_reload_from_hp_weights"
]
def
set_torchao_reload_attrs
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
):
model
.
_do_torchao_reload
=
True
model
.
_model_config
=
model_config
def
support_quantized_model_reload_from_hp_weights
(
original_load_weights
:
FunctionType
):
"""
Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights.
Only applies to torchao quantized models. Assumes that all model weights are
loaded within a single weights iterator (cannot perform batched updates)
"""
@
wraps
(
original_load_weights
)
def
patched_model_load_weights
(
self
:
"AutoWeightsLoader"
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
*
args
,
**
kwargs
,
):
model
=
self
.
module
if
not
getattr
(
model
,
"_do_torchao_reload"
,
False
):
return
original_load_weights
(
self
,
weights
,
*
args
,
**
kwargs
)
initialize_layerwise_reload
(
model
)
loaded_weights
=
original_load_weights
(
self
,
weights
,
*
args
,
**
kwargs
)
finalize_layerwise_reload
(
model
,
model
.
_model_config
)
return
loaded_weights
return
patched_model_load_weights
vllm/model_executor/model_loader/reload/types.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
inspect
import
BoundArguments
import
torch
__all__
=
[
"LayerTensors"
,
"LayerReloadingInfo"
]
# encodes both parameters and buffers separately
LayerTensors
=
tuple
[
dict
[
str
,
torch
.
Tensor
],
dict
[
str
,
torch
.
Tensor
]]
@
dataclass
class
LayerReloadingInfo
:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
# kernel format (device)
kernel_tensors
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
# track how many restored elements are ready for loading
load_numel
:
int
=
0
load_numel_total
:
int
|
None
=
None
# stores arguments and tensors ready for loading
loaded_weights
:
list
[
tuple
[
str
,
BoundArguments
]]
=
field
(
default_factory
=
list
)
def
reset
(
self
):
self
.
__init__
(
restore_metadata
=
self
.
restore_metadata
)
# type: ignore[misc]
def
can_process
(
self
)
->
bool
:
return
self
.
load_numel_total
is
not
None
vllm/model_executor/model_loader/reload/utils.py
0 → 100644
View file @
f857a03f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
.types
import
LayerTensors
__all__
=
[
"get_layer_tensors"
,
"get_layer_params_buffers"
,
"get_layer_size"
,
]
def
get_layer_tensors
(
layer
:
torch
.
nn
.
Module
)
->
dict
[
str
,
torch
.
Tensor
]:
"""Get all parameters and buffers from a module as a dict."""
params
,
buffers
=
get_layer_params_buffers
(
layer
)
return
params
|
buffers
def
get_layer_params_buffers
(
layer
:
torch
.
nn
.
Module
)
->
LayerTensors
:
"""Get all parameters and buffers of a module as a tuple of dicts."""
return
(
{
name
:
param
for
name
,
param
in
layer
.
_parameters
.
items
()
if
param
is
not
None
},
{
name
:
buffer
for
name
,
buffer
in
layer
.
_buffers
.
items
()
if
buffer
is
not
None
},
)
def
get_layer_size
(
layer
:
torch
.
nn
.
Module
)
->
int
:
"""Calculate total number of elements across all tensors in a layer."""
return
sum
(
tensor
.
numel
()
for
tensor
in
get_layer_tensors
(
layer
).
values
())
vllm/model_executor/model_loader/utils.py
View file @
f857a03f
...
...
@@ -18,6 +18,10 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.model_loader.reload
import
(
record_metadata_for_reloading
,
set_torchao_reload_attrs
,
)
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.utils.platform_utils
import
is_pin_memory_available
...
...
@@ -45,7 +49,9 @@ def initialize_model(
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
# new-style model class
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
,
prefix
=
prefix
):
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
model
=
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
record_metadata_for_reloading
(
model
)
return
model
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
...
...
@@ -75,27 +81,15 @@ def initialize_model(
if
"scheduler_config"
in
all_params
:
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
,
prefix
=
prefix
):
return
model_class
(
**
kwargs
)
model
=
model_class
(
**
kwargs
)
record_metadata_for_reloading
(
model
)
return
model
def
process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
)
->
None
:
if
getattr
(
model
,
"process_weights_after_loading_already_called"
,
False
):
# In case `process_weights_after_loading` is called multiple times
# we'll skip it at later times
logger
.
debug_once
(
"process_weights_after_loading already called for model %s"
,
model
)
return
# to avoid circular dependency
from
vllm.model_executor.model_loader.online_quantization
import
(
maybe_save_metadata_and_attributes_for_weight_reloading
,
)
maybe_save_metadata_and_attributes_for_weight_reloading
(
model
,
model_config
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
...
...
@@ -117,6 +111,11 @@ def process_weights_after_loading(
# of process_weights_after_loading
module
.
process_weights_after_loading
(
model_config
.
dtype
)
# Needed for torchao model reloading via model.reload_weights
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
if
model_config
.
quantization
==
"torchao"
:
set_torchao_reload_attrs
(
model
,
model_config
)
@
contextmanager
def
device_loading_context
(
module
:
torch
.
nn
.
Module
,
target_device
:
torch
.
device
):
...
...
vllm/model_executor/models/utils.py
View file @
f857a03f
...
...
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
)
from
vllm.model_executor.model_loader.
online_quantization
import
(
from
vllm.model_executor.model_loader.
reload
import
(
support_quantized_model_reload_from_hp_weights
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
vllm/model_executor/parameter.py
View file @
f857a03f
...
...
@@ -522,8 +522,7 @@ class SharedWeightParameter(BasevLLMParameter):
@
property
def
data
(
self
):
raise
ValueError
(
"Accessing `data` of a "
"`PartitionedModelWeightParameter` is not allowed. "
"Accessing `data` of a `SharedWeightParameter` is not allowed. "
"Instead, use `get_partition` to get the weight of "
"the particular partition you want to access"
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
f857a03f
...
...
@@ -7,7 +7,7 @@ import itertools
import
threading
import
time
from
collections
import
defaultdict
from
collections.abc
import
Iterator
,
Sequence
from
collections.abc
import
Iterable
,
Iterator
,
Sequence
from
contextlib
import
contextmanager
from
copy
import
copy
,
deepcopy
from
dataclasses
import
dataclass
...
...
@@ -59,6 +59,10 @@ from vllm.model_executor.layers.rotary_embedding import (
XDRotaryEmbedding
,
)
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.model_loader.reload
import
(
finalize_layerwise_reload
,
initialize_layerwise_reload
,
)
from
vllm.model_executor.models.interfaces
import
(
MultiModalEmbeddings
,
SupportsMRoPE
,
...
...
@@ -2524,8 +2528,10 @@ class GPUModelRunner(
return
mm_embeds
,
is_mm_embed
def
get_model
(
self
)
->
nn
.
Module
:
# get raw model out of the cudagraph wrapper.
if
not
hasattr
(
self
,
"model"
):
raise
ValueError
(
"Cannot get model before model has been initialized"
)
if
isinstance
(
self
.
model
,
(
CUDAGraphWrapper
,
UBatchWrapper
)):
# get raw model out of the cudagraph wrapper.
return
self
.
model
.
unwrap
()
return
self
.
model
...
...
@@ -4270,13 +4276,89 @@ class GPUModelRunner(
return
None
def
reload_weights
(
self
)
->
None
:
assert
getattr
(
self
,
"model"
,
None
)
is
not
None
,
(
"Cannot reload weights before model is loaded."
def
reload_weights
(
self
,
weights_iterator
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
weights_path
:
str
|
None
=
None
,
is_checkpoint_format
:
bool
=
True
,
)
->
None
:
"""
Reload weights from a weights iterator or from disk
:param weights_iterator: weights to load into model
:param weights_path: path to load weights from if weights_iterator is not
provided. Use path of original model if neither is provided.
:param is_checkpoint_format: set to False if weights have already been processed
into kernel format (repacking, renaming, ect.)
"""
# TODO(@kylesayrs): generalize to all runners and loaders
# argument validation
if
weights_iterator
is
None
and
not
is_checkpoint_format
:
logger
.
warning
(
"Reloading from disk means that weights will be in checkpoint format. "
"Please use `is_checkpoint_format=True` "
"to avoid weight reloading errors"
)
model
=
self
.
get_model
()
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
counter_before_reloading
=
time
.
perf_counter
()
# load weights from disk if none are provided
if
weights_iterator
is
None
:
model_loader
=
get_model_loader
(
self
.
load_config
)
logger
.
info
(
"Reloading weights inplace..."
)
model_loader
.
load_weights
(
self
.
get_model
(),
model_config
=
self
.
model_config
)
if
not
hasattr
(
model_loader
,
"get_all_weights"
):
raise
NotImplementedError
(
f
"Model reloading with `
{
self
.
load_config
.
load_format
}
` format"
)
if
weights_path
is
not
None
:
self
.
model_config
.
model
=
weights_path
weights_iterator
=
model_loader
.
get_all_weights
(
self
.
model_config
,
model
)
weights_iterator
=
cast
(
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
weights_iterator
)
# begin loading weights
logger
.
info_once
(
"Reloading weights inplace..."
,
scope
=
"local"
)
load_device
=
(
self
.
vllm_config
.
load_config
.
device
or
self
.
vllm_config
.
device_config
.
device
)
with
torch
.
device
(
load_device
):
if
is_checkpoint_format
:
# load weights from checkpoint/ original model format
initialize_layerwise_reload
(
model
)
loaded_weights
=
model
.
load_weights
(
weights_iterator
)
finalize_layerwise_reload
(
model
,
self
.
model_config
)
else
:
# load weights from kernel format
logger
.
warning_once
(
"Reloading with `is_checkpoint_format=True` requires that "
"weights be in kernel format and already sharded"
,
scope
=
"local"
,
)
loaded_weights
=
set
()
for
name
,
loaded_weight
in
weights_iterator
:
param
=
model
.
get_parameter
(
name
)
# TODO: buffers?
param
.
copy_
(
loaded_weight
)
loaded_weights
.
add
(
name
)
# logging and validation
counter_after_reloading
=
time
.
perf_counter
()
diff_seconds
=
counter_after_reloading
-
counter_before_reloading
logger
.
info_once
(
"Reloading and processing weights took %.2f seconds"
,
diff_seconds
,
scope
=
"local"
,
)
if
self
.
model_config
.
quantization
is
None
and
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
if
weights_not_loaded
:
logger
.
warning
(
"Following weights were not loaded from checkpoint: %s"
,
weights_not_loaded
,
)
def
save_tensorized_model
(
self
,
...
...
vllm/v1/worker/gpu_worker.py
View file @
f857a03f
...
...
@@ -280,8 +280,8 @@ class Worker(WorkerBase):
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
self
.
model_runner
.
update_config
(
overrides
)
def
reload_weights
(
self
)
->
None
:
self
.
model_runner
.
reload_weights
()
def
reload_weights
(
self
,
*
args
,
**
kwargs
)
->
None
:
self
.
model_runner
.
reload_weights
(
*
args
,
**
kwargs
)
@
torch
.
inference_mode
()
def
determine_available_memory
(
self
)
->
int
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment