Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fccd5325
Unverified
Commit
fccd5325
authored
Dec 09, 2025
by
Kyle Sayers
Committed by
GitHub
Dec 09, 2025
Browse files
[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
00e5cbb9
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
207 additions
and
87 deletions
+207
-87
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+88
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+72
-79
vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/kv_cache.py
+7
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+4
-3
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+7
-4
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+4
-1
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+25
-0
No files found.
tests/quantization/test_fp8.py
View file @
fccd5325
...
...
@@ -10,10 +10,14 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8Config
,
Fp8KVCacheMethod
,
Fp8LinearMethod
,
Fp8MoEMethod
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.platforms
import
current_platform
MODELS
=
[
...
...
@@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
torch
.
narrow
(
y_nc_pad
,
0
,
0
,
x_nc
.
shape
[
0
]),
inv_scale_nc
,
dtype
),
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True])  # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# any postprocessing that is applied to the weights such as padding and repacking
# (excluding device sharding) must also be applied to the reloaded weights
#
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False])  # skip True
def test_fp8_reloading(
    method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
):
    """
    Check that FP8 weights can be loaded a second time after post-processing.

    The test creates the weights of a minimal layer, loads zeros through each
    parameter's weight loader, runs `process_weights_after_loading`, and then
    repeats the load. The second pass asserts that every parameter still
    exposes the same weight loader that it had before processing.
    """
    if is_checkpoint_fp8_serialized is False:
        pytest.skip("FP8 weight reloading does not support online quantization")

    if method_cls is Fp8MoEMethod and weight_block_size is None:
        pytest.skip(
            "FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
            "If this is your use case, consider using a restore function like #26327"
        )

    with torch.device("cuda:0"):
        config = Fp8Config(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            weight_block_size=weight_block_size,
        )

        if method_cls is Fp8LinearMethod:
            layer = torch.nn.Linear(1, 1)
            method = method_cls(config)
            method.create_weights(
                layer=layer,
                input_size_per_partition=1,
                output_partition_sizes=[1],
                input_size=1,
                output_size=1,
                params_dtype=torch.bfloat16,
                weight_loader=default_weight_loader,
            )

        else:
            layer = FusedMoE(
                num_experts=1,
                top_k=1,
                hidden_size=1,
                intermediate_size=1,
            )
            method = method_cls(config, layer)
            method.create_weights(
                layer=layer,
                num_experts=1,
                hidden_size=1,
                intermediate_size_per_partition=1,
                params_dtype=torch.bfloat16,
                weight_loader=default_weight_loader,
            )

    method.use_marlin = use_marlin

    # capture weights format during loading
    original_metadata = [
        (name, param.shape, getattr(param, "weight_loader", default_weight_loader))
        for name, param in layer.named_parameters()
    ]

    # test loading
    for name, shape, _ in original_metadata:
        param = getattr(layer, name)
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, torch.zeros(shape))  # cannot use empty

    method.process_weights_after_loading(layer)

    # test reloading works after loading
    # assuming that no reshaping occurred
    for name, shape, original_weight_loader in original_metadata:
        param = getattr(layer, name)
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        # the loader must survive process_weights_after_loading unchanged
        assert weight_loader is original_weight_loader
        weight_loader(param, torch.zeros(shape))  # cannot use empty

    method.process_weights_after_loading(layer)
vllm/model_executor/layers/quantization/fp8.py
View file @
fccd5325
...
...
@@ -94,7 +94,7 @@ from vllm.model_executor.parameter import (
ModelWeightParameter
,
PerTensorScaleParameter
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.deep_gemm
import
(
...
...
@@ -548,46 +548,50 @@ class Fp8LinearMethod(LinearMethodBase):
assert
not
self
.
act_q_static
size_k_first
=
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
weight
,
weight_scale
_inv
=
process_fp8_weight_block_strategy
(
layer
.
weight
,
layer
.
weight_scale_inv
)
# Delete the weight_scale_inv parameter to avoid confusion
# with the weight_scale parameter
del
layer
.
weight_scale_inv
# Update layer with new values
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
replace_parameter
(
layer
,
"weight_scale_inv"
,
weight_scale_inv
.
data
)
# If checkpoint not serialized fp8, quantize the weights.
elif
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
weight
=
qweight
.
t
()
else
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
weight
=
qweight
.
t
()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else
:
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if
not
self
.
use_marlin
:
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_tensor_strategy
(
weight
,
weight_scale
,
layer
.
logical_widths
,
getattr
(
layer
,
"input_scale"
,
None
),
)
)
if
self
.
act_q_static
:
assert
input_scale
is
not
None
input_scale
=
input_scale
.
max
()
weight
=
weight
.
t
()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
# Update layer with new values.
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
if
input_scale
is
not
None
:
replace_parameter
(
layer
,
"input_scale"
,
input_scale
)
else
:
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if
not
self
.
use_marlin
:
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_tensor_strategy
(
weight
,
weight_scale
,
layer
.
logical_widths
,
getattr
(
layer
,
"input_scale"
,
None
),
)
if
self
.
act_q_static
:
assert
input_scale
is
not
None
input_scale
=
input_scale
.
max
()
weight
=
weight
.
t
()
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
input_scale
=
(
Parameter
(
input_scale
,
requires_grad
=
False
)
if
input_scale
is
not
None
else
None
)
layer
.
input_scale
=
None
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
...
...
@@ -614,7 +618,7 @@ class Fp8LinearMethod(LinearMethodBase):
return
self
.
w8a8_block_fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
_inv
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
...
...
@@ -643,10 +647,15 @@ class Fp8LinearMethod(LinearMethodBase):
return
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
.
t
(),
bias
)
if
self
.
use_marlin
:
if
self
.
block_quant
:
weight_scale
=
layer
.
weight_scale_inv
else
:
weight_scale
=
layer
.
weight_scale
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
...
...
@@ -660,7 +669,7 @@ class Fp8LinearMethod(LinearMethodBase):
return
self
.
w8a8_block_fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
_inv
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
...
...
@@ -937,22 +946,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight_scale_inv
=
layer
.
w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses.
layer
.
w13_weight
=
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale_inv
=
Parameter
(
w13_weight_scale_inv
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale_inv
=
Parameter
(
w2_weight_scale_inv
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w13_weight_scale_inv"
,
w13_weight_scale_inv
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
replace_parameter
(
layer
,
"w2_weight_scale_inv"
,
w2_weight_scale_inv
)
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
rocm_aiter_ops
.
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"
w13_weight
"
,
shuffled_w13
)
replace_parameter
(
layer
,
"
w2_weight
"
,
shuffled_w2
)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
...
...
@@ -990,13 +995,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
replace_parameter
(
layer
,
"w13_weight_scale"
,
torch
.
ones
(
layer
.
local_num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
,
),
requires_grad
=
False
,
)
for
expert
in
range
(
layer
.
local_num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_weight_scale
[
expert
]
=
(
...
...
@@ -1005,16 +1011,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight
[
expert
,
:,
:],
layer
.
w2_weight_scale
[
expert
]
=
(
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
if
self
.
rocm_aiter_moe_enabled
:
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
rocm_aiter_ops
.
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"
w13_weight
"
,
shuffled_w13
)
replace_parameter
(
layer
,
"
w2_weight
"
,
shuffled_w2
)
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
...
...
@@ -1035,12 +1042,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_input_scale"
,
layer
.
w13_input_scale
.
max
())
replace_parameter
(
layer
,
"w2_input_scale"
,
layer
.
w2_input_scale
.
max
())
if
current_platform
.
is_fp8_fnuz
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
(
...
...
@@ -1054,22 +1057,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_weight_scale
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_weight"
,
w13_weight
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_weight_scale
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_weight_scale
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_input_scale"
,
w13_input_scale
)
replace_parameter
(
layer
,
"w2_weight"
,
w2_weight
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_weight_scale
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w2_input_scale"
,
w2_input_scale
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
...
...
@@ -1093,12 +1088,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"
w13_weight
"
,
shuffled_w13
)
replace_parameter
(
layer
,
"
w2_weight
"
,
shuffled_w2
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
replace_parameter
(
layer
,
"w13_weight_scale"
,
max_w13_scales
)
if
self
.
flashinfer_moe_backend
is
not
None
:
# NOTE: weights have to be swapped since the activation is
...
...
vllm/model_executor/layers/quantization/kv_cache.py
View file @
fccd5325
...
...
@@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
raise
RuntimeError
(
f
"
{
self
.
__class__
.
__name__
}
.apply should not be called."
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# skip if there are no weights to process (for example, weight reloading)
if
not
hasattr
(
layer
,
"q_scale"
):
assert
not
hasattr
(
layer
,
"k_scale"
)
assert
not
hasattr
(
layer
,
"v_scale"
)
assert
not
hasattr
(
layer
,
"prob_scale"
)
return
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
fccd5325
...
...
@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter
,
PerTensorScaleParameter
,
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
...
...
@@ -1404,12 +1405,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
if
should_use_deepgemm
:
dg_weight
,
dg_weight_scale
=
deepgemm_post_process_fp8_weight_block
(
wq
=
layer
.
weight
.
data
,
ws
=
layer
.
weight_scale
.
data
,
ws
=
layer
.
weight_scale
_inv
.
data
,
quant_block_shape
=
tuple
(
layer
.
weight_block_size
),
use_e8m0
=
is_deep_gemm_e8m0_used
(),
)
layer
.
weight
=
torch
.
nn
.
P
arameter
(
dg_
weight
,
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
P
arameter
(
dg_
weight_scale
,
requires_grad
=
F
al
s
e
)
replace_p
arameter
(
layer
,
"
weight
"
,
dg_weight
)
replace_p
arameter
(
layer
,
"
weight_scale
_inv"
,
dg_weight_sc
ale
)
def
expert_weight_is_col_major
(
x
:
torch
.
Tensor
)
->
bool
:
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
fccd5325
...
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_quant_input
,
should_use_atomic_add_reduce
,
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
...
...
@@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
size_n
=
part_size_n
,
num_bits
=
8
,
)
layer
.
weight
=
torch
.
nn
.
P
arameter
(
marlin_q
weight
,
requires_grad
=
False
)
replace_p
arameter
(
layer
,
"
weight
"
,
marlin_qweight
)
# WEIGHT SCALES
# Permute scales
...
...
@@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
scales
=
layer
.
weight_scale
.
to
(
layer
.
orig_dtype
)
elif
"weight_scale_inv"
in
dir
(
layer
):
scales
=
layer
.
weight_scale_inv
.
to
(
layer
.
orig_dtype
)
del
layer
.
weight_scale_inv
group_size
=
-
1
if
weight_block_size
is
None
else
weight_block_size
[
1
]
...
...
@@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
)
if
input_dtype
!=
torch
.
float8_e4m3fn
:
marlin_scales
=
fp8_fused_exponent_bias_into_scales
(
marlin_scales
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
if
hasattr
(
layer
,
"weight_scale"
):
replace_parameter
(
layer
,
"weight_scale"
,
marlin_scales
)
elif
hasattr
(
layer
,
"weight_scale_inv"
):
replace_parameter
(
layer
,
"weight_scale_inv"
,
marlin_scales
)
if
hasattr
(
layer
,
"bias"
)
and
layer
.
bias
is
not
None
:
assert
layer
.
bias
.
shape
==
(
part_size_n
,)
bias
=
marlin_permute_bias
(
layer
.
bias
)
layer
.
bias
=
torch
.
nn
.
P
arameter
(
bias
,
requires_grad
=
False
)
replace_p
arameter
(
layer
,
"
bias
"
,
bias
)
def
prepare_moe_fp8_layer_for_marlin
(
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
fccd5325
...
...
@@ -118,8 +118,11 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case, since
# we are already quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
#
# Extra note: upon weight reloading weight_scale.ndim == 0
unfused_module_in_checkpoint
=
(
weight_scale
[
-
1
]
>
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
weight_scale
.
ndim
!=
0
and
weight_scale
[
-
1
]
>
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
)
# If unfused checkpoint, need to requantize with the single scale.
...
...
vllm/model_executor/utils.py
View file @
fccd5325
...
...
@@ -50,6 +50,31 @@ def set_weight_attrs(
setattr
(
weight
,
key
,
value
)
def replace_parameter(
    layer: torch.nn.Module, param_name: str, new_data: torch.Tensor
):
    """
    Replace a parameter of a layer while maintaining the ability to reload the weight.

    Called within implementations of the `process_weights_after_loading` method.
    This function should not be called on weights which are tied/shared.

    The replacement is always a plain `torch.nn.Parameter` created with
    `requires_grad=False`. If the attribute currently stored under
    `param_name` has a `weight_loader` attribute, that loader is copied onto
    the replacement via `set_weight_attrs`; no other attribute is carried over.

    Args:
        layer: Layer containing parameter to replace
        param_name: Name of parameter to replace
        new_data: New data of the new parameter. A `torch.nn.Parameter` is
            also accepted, in which case only its `.data` is used.
    """
    # Unwrap Parameter inputs so the replacement wraps the underlying tensor
    if isinstance(new_data, torch.nn.Parameter):
        new_data = new_data.data

    new_param = torch.nn.Parameter(new_data, requires_grad=False)

    # Carry the weight loader over from the parameter being replaced, if any
    old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
    if old_param is not None and hasattr(old_param, "weight_loader"):
        weight_loader = old_param.weight_loader
        set_weight_attrs(new_param, {"weight_loader": weight_loader})

    setattr(layer, param_name, new_param)
def
get_packed_modules_mapping
(
model
:
torch
.
nn
.
Module
)
->
dict
[
str
,
list
[
str
]]:
parent_map
=
getattr
(
model
,
"packed_modules_mapping"
,
None
)
parent_map
=
copy
.
deepcopy
(
parent_map
)
if
parent_map
is
not
None
else
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment