Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
648edcf7
Unverified
Commit
648edcf7
authored
Mar 27, 2026
by
Kyle Sayers
Committed by
GitHub
Mar 27, 2026
Browse files
[QeRL] Compose online quantization with quantized reloading (#38032)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
7ba425e9
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
184 additions
and
260 deletions
+184
-260
tests/model_executor/model_loader/test_reload.py
tests/model_executor/model_loader/test_reload.py
+57
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+27
-202
vllm/model_executor/layers/quantization/mxfp8.py
vllm/model_executor/layers/quantization/mxfp8.py
+2
-0
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+19
-7
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+9
-0
vllm/model_executor/model_loader/reload/__init__.py
vllm/model_executor/model_loader/reload/__init__.py
+2
-0
vllm/model_executor/model_loader/reload/layerwise.py
vllm/model_executor/model_loader/reload/layerwise.py
+61
-29
vllm/model_executor/model_loader/reload/meta.py
vllm/model_executor/model_loader/reload/meta.py
+3
-4
vllm/model_executor/model_loader/reload/types.py
vllm/model_executor/model_loader/reload/types.py
+3
-3
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+1
-15
No files found.
tests/model_executor/model_loader/test_reload.py
View file @
648edcf7
...
@@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
...
@@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
add_perp
<
mul_perp
assert
add_perp
<
mul_perp
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"base_model,mul_model,add_model,quantization"
,
[
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
"fp8"
,
),
(
"inference-optimization/DeepSeek-V3-debug-empty"
,
"inference-optimization/DeepSeek-V3-debug-multiply"
,
"inference-optimization/DeepSeek-V3-debug-add"
,
"fp8"
,
),
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
"mxfp8"
,
),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
],
)
def
test_online_quantize_reload
(
base_model
,
mul_model
,
add_model
,
quantization
,
tp_size
,
vllm_runner
):
if
cuda_device_count_stateless
()
<
tp_size
:
pytest
.
skip
(
reason
=
"Not enough CUDA devices"
)
if
quantization
==
"fp8"
and
not
current_platform
.
supports_fp8
():
pytest
.
skip
(
reason
=
"Requires FP8 support"
)
with
vllm_runner
(
model_name
=
base_model
,
quantization
=
quantization
,
tensor_parallel_size
=
tp_size
,
enable_expert_parallel
=
(
tp_size
>
1
and
"DeepSeek"
in
base_model
),
enable_prefix_caching
=
False
,
)
as
llm
:
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
mul_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
mul_perp
<
add_perp
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
add_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
add_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 7"
],
mask
=
[
"3 4 ="
])[
0
]
assert
add_perp
<
mul_perp
vllm/model_executor/layers/quantization/fp8.py
View file @
648edcf7
...
@@ -73,7 +73,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -73,7 +73,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
)
)
from
vllm.model_executor.model_loader.weight_utils
import
initialize_single_dummy_weight
from
vllm.model_executor.model_loader.reload.layerwise
import
(
initialize_online_processing
,
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
BlockQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
...
@@ -496,8 +498,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -496,8 +498,8 @@ class Fp8LinearMethod(LinearMethodBase):
class
Fp8OnlineLinearMethod
(
Fp8LinearMethod
):
class
Fp8OnlineLinearMethod
(
Fp8LinearMethod
):
"""Online version of Fp8LinearMethod
,
loads
the fp16/bf16
checkpoint
"""Online version of Fp8LinearMethod
which
loads
a full precision
checkpoint
and quantize
d the
weights during loading."""
and quantize
s
weights during loading."""
uses_meta_device
:
bool
=
True
uses_meta_device
:
bool
=
True
...
@@ -519,84 +521,25 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
...
@@ -519,84 +521,25 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
layer
.
orig_dtype
=
params_dtype
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
layer
.
weight_block_size
=
None
# WEIGHT
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param
=
layer
.
weight
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return
res
weight
=
ModelWeightParameter
(
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
data
=
torch
.
empty
(
output_size_per_partition
,
output_size_per_partition
,
input_size_per_partition
,
input_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
# materialized and processed during loading
device
=
"meta"
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
input_dim
=
1
,
input_dim
=
1
,
output_dim
=
0
,
output_dim
=
0
,
weight_loader
=
patched_
weight_loader
,
weight_loader
=
weight_loader
,
)
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
initialize_online_processing
(
layer
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
weight
.
device
==
torch
.
device
(
"meta"
):
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
layer
.
weight
.
weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
initialize_single_dummy_weight
(
layer
.
weight
)
# TODO(future): support block_quant in online quant path
# TODO(future): support block_quant in online quant path
assert
not
self
.
block_quant
assert
not
self
.
block_quant
...
@@ -845,9 +788,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -845,9 +788,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# Allow for accessing weights and scales in standard way.
# Allow for accessing weights and scales in standard way.
w13
=
layer
.
w13_weight
w13
=
layer
.
w13_weight
w2
=
layer
.
w2_weight
w2
=
layer
.
w2_weight
...
@@ -892,9 +832,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -892,9 +832,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
w13_input_scale
,
w2_input_scale
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
w13_input_scale
,
w2_input_scale
)
)
# Prevent duplicate processing (e.g., during weight reload)
layer
.
_already_called_process_weights_after_loading
=
True
def
maybe_make_prepare_finalize
(
def
maybe_make_prepare_finalize
(
self
,
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
...
@@ -1013,86 +950,12 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1013,86 +950,12 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer
.
orig_dtype
=
params_dtype
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
layer
.
weight_block_size
=
None
# We are doing online quantization, patch the weight loaded
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
# create a new holder to prevent modifying behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs
=
extra_weight_attrs
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# add a counter to track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer
.
_w13_weight_orig_id
=
id
(
layer
.
w13_weight
)
layer
.
_w2_weight_orig_id
=
id
(
layer
.
w2_weight
)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if
id
(
param
)
==
layer
.
_w13_weight_orig_id
:
param
=
layer
.
w13_weight
elif
id
(
param
)
==
layer
.
_w2_weight_orig_id
:
param
=
layer
.
w2_weight
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
w13_weight
.
numel
()
+
layer
.
w2_weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return
res
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
extra_weight_attrs
=
new_extra_weight_attrs
# WEIGHTS
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
empty
(
num_experts
,
num_experts
,
2
*
intermediate_size_per_partition
,
2
*
intermediate_size_per_partition
,
hidden_size
,
hidden_size
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
device
=
"meta"
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
...
@@ -1106,91 +969,53 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1106,91 +969,53 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts
,
num_experts
,
hidden_size
,
hidden_size
,
intermediate_size_per_partition
,
intermediate_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
# materialized and processed during loading
device
=
"meta"
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
# BIASES (for models like GPT-OSS that have biased MoE)
# BIASES (for models like GPT-OSS that have biased MoE)
if
self
.
moe
.
has_bias
:
if
self
.
moe
.
has_bias
:
# Use the original weight_loader (not patched) for biases
orig_extra_weight_attrs
=
dict
(
extra_weight_attrs
)
orig_extra_weight_attrs
[
"weight_loader"
]
=
weight_loader
w13_bias
=
torch
.
nn
.
Parameter
(
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
torch
.
zeros
(
num_experts
,
num_experts
,
2
*
intermediate_size_per_partition
,
2
*
intermediate_size_per_partition
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
layer
.
orig_dtype
,
dtype
=
layer
.
orig_dtype
,
),
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
orig_extra_weight_attrs
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
w2_bias
=
torch
.
nn
.
Parameter
(
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
layer
.
orig_dtype
),
torch
.
zeros
(
num_experts
,
hidden_size
,
device
=
"meta"
,
# materialized and processed during loading
dtype
=
layer
.
orig_dtype
,
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
orig_extra_weight_attrs
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
layer
.
w13_input_scale
=
None
initialize_online_processing
(
layer
)
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
w13_weight
.
device
==
torch
.
device
(
"meta"
):
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
{
"weight_loader"
:
layer
.
w13_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
initialize_single_dummy_weight
(
layer
.
w13_weight
)
if
layer
.
w2_weight
.
device
==
torch
.
device
(
"meta"
):
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
{
"weight_loader"
:
layer
.
w2_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
initialize_single_dummy_weight
(
layer
.
w2_weight
)
# If checkpoint is fp16, quantize in place.
fp8_dtype
=
current_platform
.
fp8_dtype
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13_scale
=
layer
.
w13_weight_scale
w13_scale
=
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
layer
.
w2_weight_scale
w2_scale
=
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
for
expert
in
range
(
layer
.
local_num_experts
):
for
expert
in
range
(
layer
.
local_num_experts
):
w13
[
expert
,
:,
:],
w13_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
w13
[
expert
,
:,
:],
w13_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
...
@@ -1207,8 +1032,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1207,8 +1032,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
w2
,
w2
,
w13_scale
,
w13_scale
,
w2_scale
,
w2_scale
,
layer
.
w13_input_scale
,
w13_input_scale
=
layer
.
w13_input_scale
,
layer
.
w2_input_scale
,
w2_input_scale
=
layer
.
w2_input_scale
,
)
)
# Prevent duplicate processing (e.g., during weight reload)
# Prevent duplicate processing (e.g., during weight reload)
...
...
vllm/model_executor/layers/quantization/mxfp8.py
View file @
648edcf7
...
@@ -337,6 +337,8 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
...
@@ -337,6 +337,8 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13_scale
=
layer
.
w13_weight_scale
w13_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
w2_scale
=
layer
.
w2_weight_scale
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
w13
,
w13_scale
=
self
.
_quantize_mxfp8_moe_weight
(
layer
.
w13_weight
)
w13
,
w13_scale
=
self
.
_quantize_mxfp8_moe_weight
(
layer
.
w13_weight
)
w2
,
w2_scale
=
self
.
_quantize_mxfp8_moe_weight
(
layer
.
w2_weight
)
w2
,
w2_scale
=
self
.
_quantize_mxfp8_moe_weight
(
layer
.
w2_weight
)
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
648edcf7
...
@@ -9,6 +9,7 @@ import vllm.envs as envs
...
@@ -9,6 +9,7 @@ import vllm.envs as envs
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config.load
import
LoadConfig
from
vllm.config.load
import
LoadConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.reload
import
finalize_layerwise_processing
from
vllm.model_executor.model_loader.utils
import
(
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
initialize_model
,
process_weights_after_loading
,
process_weights_after_loading
,
...
@@ -49,16 +50,13 @@ class BaseModelLoader(ABC):
...
@@ -49,16 +50,13 @@ class BaseModelLoader(ABC):
device_config
.
device
if
load_config
.
device
is
None
else
load_config
.
device
device_config
.
device
if
load_config
.
device
is
None
else
load_config
.
device
)
)
target_device
=
torch
.
device
(
load_device
)
target_device
=
torch
.
device
(
load_device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
),
target_device
:
with
target_device
:
model
=
initialize_model
(
model
=
initialize_model
(
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
vllm_config
=
vllm_config
,
model_config
=
model_config
,
prefix
=
prefix
)
)
log_model_inspection
(
model
)
log_model_inspection
(
model
)
logger
.
debug
(
"Loading weights on %s ..."
,
load_device
)
logger
.
debug
(
"Loading weights on %s ..."
,
load_device
)
# Quantization does not happen in `load_weights` but after it
self
.
load_weights
(
model
,
model_config
)
self
.
load_weights
(
model
,
model_config
)
# Log peak GPU memory after loading weights. This is needed
# Log peak GPU memory after loading weights. This is needed
...
@@ -71,6 +69,11 @@ class BaseModelLoader(ABC):
...
@@ -71,6 +69,11 @@ class BaseModelLoader(ABC):
scope
=
"local"
,
scope
=
"local"
,
)
)
# Process weights into kernel format. Note that when using online
# quantization, weights are (typically) quantized as they are loaded.
if
_has_online_quant
(
model
):
finalize_layerwise_processing
(
model
,
model_config
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
return
model
.
eval
()
...
@@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None:
...
@@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None:
from
vllm.model_inspection
import
format_model_inspection
from
vllm.model_inspection
import
format_model_inspection
logger
.
info
(
"vLLM model structure:
\n
%s"
,
format_model_inspection
(
model
))
logger
.
info
(
"vLLM model structure:
\n
%s"
,
format_model_inspection
(
model
))
def
_has_online_quant
(
model
:
nn
.
Module
):
for
module
in
model
.
modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
getattr
(
quant_method
,
"uses_meta_device"
,
False
):
return
True
return
False
vllm/model_executor/model_loader/dummy_loader.py
View file @
648edcf7
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.config.load
import
LoadConfig
from
vllm.config.load
import
LoadConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.reload.meta
import
materialize_meta_tensor
from
vllm.model_executor.model_loader.reload.utils
import
get_layer_tensors
from
vllm.model_executor.model_loader.weight_utils
import
initialize_dummy_weights
from
vllm.model_executor.model_loader.weight_utils
import
initialize_dummy_weights
...
@@ -23,6 +26,12 @@ class DummyModelLoader(BaseModelLoader):
...
@@ -23,6 +26,12 @@ class DummyModelLoader(BaseModelLoader):
pass
# Nothing to download
pass
# Nothing to download
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
# materialize meta tensors as part of online quantization lifecycle
for
layer
in
model
.
modules
():
for
name
,
param
in
get_layer_tensors
(
layer
).
items
():
if
param
.
device
==
torch
.
device
(
"meta"
):
setattr
(
layer
,
name
,
materialize_meta_tensor
(
param
))
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
initialize_dummy_weights
(
model
,
model_config
)
initialize_dummy_weights
(
model
,
model_config
)
vllm/model_executor/model_loader/reload/__init__.py
View file @
648edcf7
...
@@ -21,12 +21,14 @@ Limitations:
...
@@ -21,12 +21,14 @@ Limitations:
__all__
=
[
__all__
=
[
"record_metadata_for_reloading"
,
"record_metadata_for_reloading"
,
"initialize_layerwise_reload"
,
"initialize_layerwise_reload"
,
"finalize_layerwise_processing"
,
"finalize_layerwise_reload"
,
"finalize_layerwise_reload"
,
"set_torchao_reload_attrs"
,
"set_torchao_reload_attrs"
,
"support_quantized_model_reload_from_hp_weights"
,
"support_quantized_model_reload_from_hp_weights"
,
]
]
from
.layerwise
import
(
from
.layerwise
import
(
finalize_layerwise_processing
,
finalize_layerwise_reload
,
finalize_layerwise_reload
,
initialize_layerwise_reload
,
initialize_layerwise_reload
,
record_metadata_for_reloading
,
record_metadata_for_reloading
,
...
...
vllm/model_executor/model_loader/reload/layerwise.py
View file @
648edcf7
...
@@ -28,6 +28,7 @@ __all__ = [
...
@@ -28,6 +28,7 @@ __all__ = [
"get_layerwise_info"
,
"get_layerwise_info"
,
"record_metadata_for_reloading"
,
"record_metadata_for_reloading"
,
"initialize_layerwise_reload"
,
"initialize_layerwise_reload"
,
"finalize_layerwise_processing"
,
"finalize_layerwise_reload"
,
"finalize_layerwise_reload"
,
]
]
...
@@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module):
...
@@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module):
info
=
get_layerwise_info
(
layer
)
info
=
get_layerwise_info
(
layer
)
# Skip if the layer has already been initialized
# Skip if the layer has already been initialized
if
info
.
can_
process
():
if
info
.
can_
load
():
continue
continue
# Save current tensors for later copying
# Save current tensors for later copying
...
@@ -98,6 +99,12 @@ def initialize_layerwise_reload(model: torch.nn.Module):
...
@@ -98,6 +99,12 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta
(
layer
,
info
)
restore_layer_on_meta
(
layer
,
info
)
initialize_online_processing
(
layer
)
def
initialize_online_processing
(
layer
:
torch
.
nn
.
Module
):
info
=
get_layerwise_info
(
layer
)
# Track loading progress to determine when to process/copy
# Track loading progress to determine when to process/copy
info
.
load_numel
=
0
info
.
load_numel
=
0
info
.
load_numel_total
=
get_layer_size
(
layer
)
info
.
load_numel_total
=
get_layer_size
(
layer
)
...
@@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
...
@@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
@
wraps
(
original_loader
,
assigned
=
(
"__doc__"
,
"__annotations__"
))
@
wraps
(
original_loader
,
assigned
=
(
"__doc__"
,
"__annotations__"
))
def
online_process_loader
(
*
args
,
**
kwargs
):
def
online_process_loader
(
*
args
,
**
kwargs
):
if
not
info
.
can_
process
():
if
not
info
.
can_
load
():
# Unfortunately, some qconfigs are set up to load the same weight
# Unfortunately, some qconfigs are set up to load the same weight
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# each of the qkv partitions. This results in layers loading extra
# each of the qkv partitions. This results in layers loading extra
...
@@ -140,7 +147,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
...
@@ -140,7 +147,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
bound_args
=
loader_signature
.
bind
(
*
args
,
**
kwargs
)
bound_args
=
loader_signature
.
bind
(
*
args
,
**
kwargs
)
bound_args
.
apply_defaults
()
bound_args
.
apply_defaults
()
#
Cache
loaded weights, track loading progress
#
Buffer
loaded weights, track loading progress
info
.
loaded_weights
.
append
((
param_name
,
bound_args
))
info
.
loaded_weights
.
append
((
param_name
,
bound_args
))
num_loaded
,
ret
=
get_numel_loaded
(
original_loader
,
bound_args
)
num_loaded
,
ret
=
get_numel_loaded
(
original_loader
,
bound_args
)
info
.
load_numel
+=
num_loaded
info
.
load_numel
+=
num_loaded
...
@@ -163,19 +170,26 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
...
@@ -163,19 +170,26 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
return
online_process_loader
return
online_process_loader
def
finalize_layerwise_
reload
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
):
def
finalize_layerwise_
processing
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
):
"""
"""
Remove the outermost layer of weight loading wrappers.
Apply processing to any layers which were not layerwise processed during loading.
This includes attention layers and layers which have weight elements which are not
loaded (due to padding).
This function should be applied after `initialize_layerwise_reload` is applied
This function should be applied after `initialize_layerwise_reload` is applied
unwrap the layerwise weight loaders.
unwrap the layerwise weight loaders.
Also processes Attention/MLA layers, which must be processed after all other layers
:param model: model to finalize processing for
:param model_config: config needed for applying processing to attention layers
"""
"""
if
hasattr
(
model
,
"_original_do_torchao_reload"
):
model
.
_do_torchao_reload
=
model
.
_original_do_torchao_reload
model
.
_do_torchao_reload
=
model
.
_original_do_torchao_reload
for
layer
in
model
.
modules
():
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
info
=
get_layerwise_info
(
layer
)
if
not
info
.
can_load
():
info
.
reset
()
continue
# Attention/MLA layers are processed after all other layers
# Attention/MLA layers are processed after all other layers
if
isinstance
(
layer
,
(
Attention
,
MLAAttention
)):
if
isinstance
(
layer
,
(
Attention
,
MLAAttention
)):
...
@@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
...
@@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
)
elif
info
.
kernel_tensors
is
None
:
raise
NotImplementedError
(
"Layerwise loading of Q/K/V scale weights is not implemented yet"
)
else
:
else
:
_place_kernel_tensors
(
layer
,
info
)
_place_kernel_tensors
(
layer
,
info
)
layer
.
process_weights_after_loading
(
model_config
.
dtype
)
layer
.
process_weights_after_loading
(
model_config
.
dtype
)
# No weights were loaded, place kernel tensors back
# No weights were loaded
elif
info
.
can_process
()
and
info
.
load_numel
<=
0
:
elif
info
.
load_numel
<=
0
:
# first load but received no weights. This happens on dummy load
if
info
.
kernel_tensors
is
None
:
materialize_layer
(
layer
)
# reloading: place kernel tensors back as a fallback
else
:
logger
.
warning
(
"%s: Failed to load weights"
,
layer
.
__class__
.
__name__
)
_place_kernel_tensors
(
layer
,
info
)
_place_kernel_tensors
(
layer
,
info
)
# Process non-attention layers which did not load all elements. This can happen
# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# if the created weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to ex
e
cess memory usage
# Having too many of these delayed layers can lead to excess memory usage
# see Limitations(4)
# see Limitations(4)
elif
info
.
load_numel
>
0
and
info
.
load_numel
<
info
.
load_numel_total
:
# type: ignore[operator]
elif
info
.
load_numel
>
0
and
info
.
load_numel
<
info
.
load_numel_total
:
# type: ignore[operator]
logger
.
debug
(
"%s: Delayed processing"
,
layer
.
__class__
.
__name__
)
logger
.
debug
(
"%s: Delayed processing"
,
layer
.
__class__
.
__name__
)
...
@@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
...
@@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
info
.
reset
()
info
.
reset
()
def
finalize_layerwise_reload
(
*
args
,
**
kwargs
):
finalize_layerwise_processing
(
*
args
,
**
kwargs
)
def
_layerwise_process
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
def
_layerwise_process
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
):
"""
"""
Finalize layer loading after all weights have been
cach
ed.
Finalize layer loading after all weights have been
buffer
ed.
This function:
This function:
1. Materializes the layer onto the target device
1. Materializes the layer onto the target device
2. Loads all
cach
ed weights
2. Loads all
buffer
ed weights
3. Runs quantization processing if applicable
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage
4. Copies processed values back to original tensor storage
"""
"""
# Materialize layer tensors onto device
# Materialize layer tensors onto device
materialize_layer
(
layer
)
materialize_layer
(
layer
)
# Reset
FP8
online quantization flag so process_weights_after_loading
# Reset online quantization flag so process_weights_after_loading
# will run again during reload
# will run again during reload
if
hasattr
(
layer
,
"_already_called_process_weights_after_loading"
):
if
hasattr
(
layer
,
"_already_called_process_weights_after_loading"
):
delattr
(
layer
,
"_already_called_process_weights_after_loading"
)
delattr
(
layer
,
"_already_called_process_weights_after_loading"
)
...
@@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
...
@@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
for
param
in
get_layer_tensors
(
layer
).
values
():
for
param
in
get_layer_tensors
(
layer
).
values
():
param
.
weight_loader
=
_get_original_loader
(
param
)
param
.
weight_loader
=
_get_original_loader
(
param
)
# Load all
cach
ed weights into materialized layer (using original loaders)
# Load all
buffer
ed weights into materialized layer (using original loaders)
for
name
,
args
in
info
.
loaded_weights
:
for
name
,
args
in
info
.
loaded_weights
:
param
=
getattr
(
layer
,
name
)
param
=
getattr
(
layer
,
name
)
args
.
arguments
[
"param"
]
=
param
args
.
arguments
[
"param"
]
=
param
...
@@ -239,6 +269,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
...
@@ -239,6 +269,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Copy processed values into original tensor storage (preserves cudagraph refs)
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
# this code is a no-op if not reloading (because kernel tensors is empty)
if
info
.
kernel_tensors
is
not
None
:
parameters
,
buffers
=
info
.
kernel_tensors
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
for
name
,
param
in
parameters
.
items
():
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
param
.
data
.
copy_
(
getattr
(
layer
,
name
))
...
@@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
...
@@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for
name
in
get_layer_tensors
(
layer
):
for
name
in
get_layer_tensors
(
layer
):
delattr
(
layer
,
name
)
delattr
(
layer
,
name
)
assert
info
.
kernel_tensors
is
not
None
parameters
,
buffers
=
info
.
kernel_tensors
parameters
,
buffers
=
info
.
kernel_tensors
for
name
,
param
in
parameters
.
items
():
for
name
,
param
in
parameters
.
items
():
layer
.
register_parameter
(
name
,
param
)
layer
.
register_parameter
(
name
,
param
)
...
...
vllm/model_executor/model_loader/reload/meta.py
View file @
648edcf7
...
@@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None:
...
@@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None:
setattr
(
layer
,
name
,
materialize_meta_tensor
(
tensor
))
setattr
(
layer
,
name
,
materialize_meta_tensor
(
tensor
))
class
Meta
CopyCounter
(
TorchDispatchMode
):
class
CopyCounter
(
TorchDispatchMode
):
"""
"""
Tracks total number of elements modified with `copy_`.
Tracks total number of elements modified with `copy_`.
...
@@ -122,7 +122,7 @@ class MetaCopyCounter(TorchDispatchMode):
...
@@ -122,7 +122,7 @@ class MetaCopyCounter(TorchDispatchMode):
if
kwargs
is
None
:
if
kwargs
is
None
:
kwargs
=
{}
kwargs
=
{}
if
func
is
torch
.
ops
.
aten
.
copy_
.
default
and
args
[
0
].
device
.
type
==
"meta"
:
if
func
is
torch
.
ops
.
aten
.
copy_
.
default
:
assert
args
[
0
].
numel
()
==
args
[
1
].
numel
()
assert
args
[
0
].
numel
()
==
args
[
1
].
numel
()
self
.
copied_numel
+=
args
[
0
].
numel
()
self
.
copied_numel
+=
args
[
0
].
numel
()
...
@@ -140,7 +140,6 @@ def get_numel_loaded(
...
@@ -140,7 +140,6 @@ def get_numel_loaded(
:return: number of elements loaded by the weight loader, the return value of the
:return: number of elements loaded by the weight loader, the return value of the
weight loader
weight loader
"""
"""
assert
args
.
arguments
[
"param"
].
device
.
type
==
"meta"
with
CopyCounter
()
as
counter
:
with
MetaCopyCounter
()
as
counter
:
return_value
=
weight_loader
(
*
args
.
args
,
**
args
.
kwargs
)
return_value
=
weight_loader
(
*
args
.
args
,
**
args
.
kwargs
)
return
counter
.
copied_numel
,
return_value
return
counter
.
copied_numel
,
return_value
vllm/model_executor/model_loader/reload/types.py
View file @
648edcf7
...
@@ -16,8 +16,8 @@ class LayerReloadingInfo:
...
@@ -16,8 +16,8 @@ class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
restore_metadata
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
# kernel format (device)
# kernel format (device)
, used to copy into when reloading only
kernel_tensors
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
kernel_tensors
:
LayerTensors
|
None
=
None
# track how many restored elements are ready for loading
# track how many restored elements are ready for loading
load_numel
:
int
=
0
load_numel
:
int
=
0
...
@@ -29,5 +29,5 @@ class LayerReloadingInfo:
...
@@ -29,5 +29,5 @@ class LayerReloadingInfo:
def
reset
(
self
):
def
reset
(
self
):
self
.
__init__
(
restore_metadata
=
self
.
restore_metadata
)
# type: ignore[misc]
self
.
__init__
(
restore_metadata
=
self
.
restore_metadata
)
# type: ignore[misc]
def
can_
process
(
self
)
->
bool
:
def
can_
load
(
self
)
->
bool
:
return
self
.
load_numel_total
is
not
None
return
self
.
load_numel_total
is
not
None
vllm/model_executor/model_loader/weight_utils.py
View file @
648edcf7
...
@@ -1323,25 +1323,11 @@ def initialize_dummy_weights(
...
@@ -1323,25 +1323,11 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depends on
is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type.
the parameter's number of elements and its data type.
"""
"""
# Check if any module uses online quantization with meta device weights.
# If so, we'll skip initializing params on meta device since they'll be
# handled in `process_weights_after_loading`.
def
uses_meta_device
(
module
:
torch
.
nn
.
Module
)
->
bool
:
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
return
getattr
(
quant_method
,
"uses_meta_device"
,
False
)
has_online_quant
=
any
(
uses_meta_device
(
m
)
for
m
in
model
.
modules
())
for
param
in
model
.
state_dict
().
values
():
for
param
in
model
.
state_dict
().
values
():
if
has_online_quant
and
param
.
device
==
torch
.
device
(
"meta"
):
# For online quantization, weights are created on meta device and
# dummy weight init will happen in `process_weights_after_loading`.
continue
initialize_single_dummy_weight
(
param
,
low
,
high
,
seed
)
initialize_single_dummy_weight
(
param
,
low
,
high
,
seed
)
@
torch
.
no_grad
()
def
initialize_single_dummy_weight
(
def
initialize_single_dummy_weight
(
param
:
torch
.
Tensor
,
param
:
torch
.
Tensor
,
low
:
float
=
-
1e-3
,
low
:
float
=
-
1e-3
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment