Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da94c7c0
Unverified
Commit
da94c7c0
authored
Nov 18, 2025
by
Jerry Zhang
Committed by
GitHub
Nov 18, 2025
Browse files
Move online quantization to `model.load_weights` (#26327)
Signed-off-by:
Jerry Zhang
<
jerryzh168@gmail.com
>
parent
1395461f
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
314 additions
and
113 deletions
+314
-113
examples/offline_inference/rlhf.py
examples/offline_inference/rlhf.py
+1
-1
examples/offline_inference/rlhf_online_quant.py
examples/offline_inference/rlhf_online_quant.py
+162
-0
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+11
-35
vllm/model_executor/model_loader/online_quantization.py
vllm/model_executor/model_loader/online_quantization.py
+128
-77
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+8
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+4
-0
No files found.
examples/offline_inference/rlhf.py
View file @
da94c7c0
...
@@ -62,7 +62,7 @@ ray.init()
...
@@ -62,7 +62,7 @@ ray.init()
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-group
s
.html
# https://docs.ray.io/en/latest/
ray-core/scheduling/
placement-group.html
pg_inference
=
placement_group
([{
"GPU"
:
1
,
"CPU"
:
0
}]
*
2
)
pg_inference
=
placement_group
([{
"GPU"
:
1
,
"CPU"
:
0
}]
*
2
)
ray
.
get
(
pg_inference
.
ready
())
ray
.
get
(
pg_inference
.
ready
())
scheduling_inference
=
PlacementGroupSchedulingStrategy
(
scheduling_inference
=
PlacementGroupSchedulingStrategy
(
...
...
examples/offline_inference/rlhf_online_quant.py
0 → 100644
View file @
da94c7c0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 1–2.
The example performs the following steps:
* Load the training model on GPU 0.
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group. Note that
for demonstration purposes we simply zero out the weights.
For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import
json
import
os
import
ray
import
torch
from
ray.util.placement_group
import
placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
rlhf_utils
import
stateless_init_process_group
from
torchao.core.config
import
config_to_dict
from
torchao.quantization
import
(
Float8DynamicActivationFloat8WeightConfig
,
PerRow
,
)
from
transformers
import
AutoModelForCausalLM
from
vllm
import
LLM
,
SamplingParams
from
vllm.utils.network_utils
import
get_ip
,
get_open_port
class
MyLLM
(
LLM
):
"""Configure the vLLM worker for Ray placement group execution."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os
.
environ
.
pop
(
"CUDA_VISIBLE_DEVICES"
,
None
)
super
().
__init__
(
*
args
,
**
kwargs
)
# Load the OPT-125M model onto GPU 0 for the training workload.
train_model
=
AutoModelForCausalLM
.
from_pretrained
(
"facebook/opt-125m"
)
train_model
.
to
(
"cuda:0"
)
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"1,2"
ray
.
init
()
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference
=
placement_group
([{
"GPU"
:
1
,
"CPU"
:
0
}]
*
2
)
ray
.
get
(
pg_inference
.
ready
())
scheduling_inference
=
PlacementGroupSchedulingStrategy
(
placement_group
=
pg_inference
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
0
,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# generate torchao quantization config for RL rollout
# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
# use serialized config files instead of passing around json string
config
=
Float8DynamicActivationFloat8WeightConfig
(
granularity
=
PerRow
())
json_str
=
json
.
dumps
(
config_to_dict
(
config
))
llm
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
0
,
scheduling_strategy
=
scheduling_inference
,
)(
MyLLM
).
remote
(
model
=
"facebook/opt-125m"
,
hf_overrides
=
{
"quantization_config_dict_json"
:
json_str
},
enforce_eager
=
True
,
worker_extension_cls
=
"rlhf_utils.WorkerExtension"
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
"ray"
,
)
# Generate text from the prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
outputs
=
ray
.
get
(
llm
.
generate
.
remote
(
prompts
,
sampling_params
))
print
(
"-"
*
50
)
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
\n
Generated text:
{
generated_text
!
r
}
"
)
print
(
"-"
*
50
)
# Set up the communication channel between the training process and the
# inference engine.
master_address
=
get_ip
()
master_port
=
get_open_port
()
handle
=
llm
.
collective_rpc
.
remote
(
"init_weight_update_group"
,
args
=
(
master_address
,
master_port
,
1
,
3
)
)
model_update_group
=
stateless_init_process_group
(
master_address
,
master_port
,
0
,
3
,
torch
.
device
(
"cuda:0"
)
)
ray
.
get
(
handle
)
# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for
name
,
p
in
train_model
.
named_parameters
():
p
.
data
.
zero_
()
# Synchronize the updated weights to the inference engine.
for
name
,
p
in
train_model
.
named_parameters
():
dtype_name
=
str
(
p
.
dtype
).
split
(
"."
)[
-
1
]
handle
=
llm
.
collective_rpc
.
remote
(
"update_weight"
,
args
=
(
name
,
dtype_name
,
p
.
shape
)
)
model_update_group
.
broadcast
(
p
,
src
=
0
,
stream
=
torch
.
cuda
.
current_stream
())
ray
.
get
(
handle
)
# Verify that the inference weights have been updated.
assert
all
(
ray
.
get
(
llm
.
collective_rpc
.
remote
(
"check_weights_changed"
)))
# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated
=
ray
.
get
(
llm
.
generate
.
remote
(
prompts
,
sampling_params
))
print
(
"-"
*
50
)
for
output
in
outputs_updated
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
\n
Generated text:
{
generated_text
!
r
}
"
)
print
(
"-"
*
50
)
vllm/model_executor/model_loader/default_loader.py
View file @
da94c7c0
...
@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
fastsafetensors_weights_iterator
,
fastsafetensors_weights_iterator
,
filter_duplicate_safetensors_files
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
filter_files_not_needed_for_inference
,
get_quant_config
,
maybe_download_from_modelscope
,
maybe_download_from_modelscope
,
multi_thread_pt_weights_iterator
,
multi_thread_pt_weights_iterator
,
multi_thread_safetensors_weights_iterator
,
multi_thread_safetensors_weights_iterator
,
...
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
)
)
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
if
model_config
.
quantization
==
"torchao"
and
torchao_version_at_least
(
if
model_config
.
quantization
==
"torchao"
:
"0.14.0"
quant_config
=
get_quant_config
(
model_config
,
self
.
load_config
)
if
(
hasattr
(
quant_config
,
"is_checkpoint_torchao_serialized"
)
and
quant_config
.
is_checkpoint_torchao_serialized
and
torchao_version_at_least
(
"0.14.0"
)
):
):
self
.
load_config
.
safetensors_load_strategy
=
"torchao"
self
.
load_config
.
safetensors_load_strategy
=
"torchao"
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization
=
not
getattr
(
model
,
"weight_metadata_and_attr_saved"
,
False
)
if
model_config
.
quantization
is
None
:
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
# model is not quantized
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
))
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
)
)
elif
offline_quantization_or_first_run_of_online_quantization
:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
)
)
else
:
# to avoid circular dependency
from
vllm.model_executor.model_loader.online_quantization
import
(
load_weights_and_online_quantize
,
)
# subsequent runs of weight loading with online
# quantization
loaded_weights
=
load_weights_and_online_quantize
(
self
,
model
,
model_config
)
self
.
counter_after_loading_weights
=
time
.
perf_counter
()
self
.
counter_after_loading_weights
=
time
.
perf_counter
()
logger
.
info_once
(
logger
.
info_once
(
...
...
vllm/model_executor/model_loader/online_quantization.py
View file @
da94c7c0
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
types
import
types
from
collections.abc
import
Iterable
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.default_loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.utils
import
process_weights_after_loading
from
vllm.model_executor.model_loader.utils
import
process_weights_after_loading
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
...
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
# R4. quantize weights (by calling process_weights_after_loading),
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# True to stop it from running again
# R5. (workaround for cudagraph), we restore the weight params to original quantized
# weights params, and use original_weight_param.copy_(updated_weight_param) so that
# the weight update work well with cudagraph
# process_weights_after_loading (if called):
# process_weights_after_loading (if called):
# this will be skipped since it's already ran in
# this will be skipped since it's already ran in
# load_weights
# load_weights
...
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
...
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
if
model_config
.
quantization
!=
"torchao"
:
if
model_config
.
quantization
!=
"torchao"
:
return
return
if
getattr
(
model
,
"process_weights_after_loading_already_called"
,
False
):
# In case `process_weights_after_loading` is called multiple times
# we'll skip it at later times
logger
.
warning
(
"process_weights_after_loading already called for model %s"
,
model
)
return
from
vllm.model_executor.model_loader.weight_utils
import
get_quant_config
from
vllm.model_executor.model_loader.weight_utils
import
get_quant_config
quant_config
=
get_quant_config
(
model_config
,
None
)
quant_config
=
get_quant_config
(
model_config
,
None
)
...
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
...
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
else
:
else
:
model
.
recorded_weight_attr
[
name
][
key
]
=
attr
model
.
recorded_weight_attr
[
name
][
key
]
=
attr
# mark the metadata and attributes saved so we don't run it again
# mark the metadata and attributes saved so we don't run it again
model
.
_model_config
=
model_config
model
.
weight_metadata_and_attr_saved
=
True
model
.
weight_metadata_and_attr_saved
=
True
...
@@ -148,12 +144,35 @@ def _bond_method_to_cls(func, obj):
...
@@ -148,12 +144,35 @@ def _bond_method_to_cls(func, obj):
return
types
.
MethodType
(
func
,
obj
)
return
types
.
MethodType
(
func
,
obj
)
def
load_weights_and_online_quantize
(
def
support_quantized_model_reload_from_hp_weights
(
original_load_weights
):
model_loader
:
DefaultModelLoader
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
)
->
set
[
str
]:
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights
"""
# online quantization, right now only enabled for
# online quantization, right now only enabled for
# torchao
# torchao
# R1, R2, R3, R4 in the Notes
# R1, R2, R3, R4, R5 in the Notes
def
patched_model_load_weights
(
auto_weight_loader
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
*
,
mapper
=
None
)
->
set
[
str
]:
model
=
auto_weight_loader
.
module
offline_quantization_or_first_run_of_online_quantization
=
not
getattr
(
model
,
"weight_metadata_and_attr_saved"
,
False
)
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see Notes in this file for more details
if
offline_quantization_or_first_run_of_online_quantization
:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
return
original_load_weights
(
auto_weight_loader
,
weights
,
mapper
=
mapper
)
model_config
=
model
.
_model_config
# TODO: Add fp8 support
# TODO: Add fp8 support
assert
model_config
.
quantization
==
"torchao"
,
(
assert
model_config
.
quantization
==
"torchao"
,
(
...
@@ -164,11 +183,13 @@ def load_weights_and_online_quantize(
...
@@ -164,11 +183,13 @@ def load_weights_and_online_quantize(
# Step R1: First restore the quantized weights to original bfloat16
# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
# and attributes, so that bfloat16 weights can be loaded properly
existing_param_names
=
dict
(
model
.
named_parameters
(
remove_duplicate
=
False
)).
keys
()
# TODO: maybe set remove_duplicate to True?
original_quantized_weight_dict
=
dict
(
model
.
named_parameters
(
remove_duplicate
=
False
)
)
named_modules
=
dict
(
model
.
named_modules
(
remove_duplicate
=
False
))
named_modules
=
dict
(
model
.
named_modules
(
remove_duplicate
=
False
))
model_device
=
None
model_device
=
None
# Step R2: recover the parameter to the state before first loading
for
name
,
d
in
model
.
original_weights_rebuild_keys
.
items
():
for
name
,
d
in
model
.
original_weights_rebuild_keys
.
items
():
_shape
=
d
[
"shape"
]
_shape
=
d
[
"shape"
]
_dtype
=
d
[
"dtype"
]
_dtype
=
d
[
"dtype"
]
...
@@ -182,15 +203,19 @@ def load_weights_and_online_quantize(
...
@@ -182,15 +203,19 @@ def load_weights_and_online_quantize(
else
:
else
:
model_device
=
_device
model_device
=
_device
if
name
in
existing_param_names
:
if
name
in
original_quantized_weight_dict
:
module_name
,
weight_name
=
name
.
rsplit
(
"."
,
1
)
module_name
,
weight_name
=
name
.
rsplit
(
"."
,
1
)
module
=
named_modules
[
module_name
]
module
=
named_modules
[
module_name
]
setattr
(
setattr
(
module
,
module
,
weight_name
,
weight_name
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
_shape
,
dtype
=
_dtype
,
device
=
_device
)),
torch
.
nn
.
Parameter
(
torch
.
empty
(
_shape
,
dtype
=
_dtype
,
device
=
_device
),
requires_grad
=
False
,
),
)
)
# Step R2: recover the weight attributes to the state before first loading
# recorded_weight_attr is
# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# e.g.
...
@@ -211,14 +236,40 @@ def load_weights_and_online_quantize(
...
@@ -211,14 +236,40 @@ def load_weights_and_online_quantize(
if
not
hasattr
(
weight
,
attr_name
):
if
not
hasattr
(
weight
,
attr_name
):
setattr
(
weight
,
attr_name
,
_bond_method_to_cls
(
attr
,
weight
))
setattr
(
weight
,
attr_name
,
_bond_method_to_cls
(
attr
,
weight
))
# Step
I1
: reload bfloat16 / high precision weights
# Step
R3
: reload bfloat16 / high precision weights
loaded_weights
=
model
.
load_weights
(
updated_params
=
original_
load_weights
(
model_loader
.
get_all_weights
(
model_config
,
model
)
auto_weight_loader
,
weights
,
mapper
=
mapper
)
)
# Step
I2
: online quantize the weights
# Step
R4
: online quantize the weights
# manually process weights after loading
# manually process weights after loading
model
.
process_weights_after_loading_already_called
=
False
model
.
process_weights_after_loading_already_called
=
False
if
model_device
is
not
None
:
process_weights_after_loading
(
model
,
model_config
,
model_device
)
process_weights_after_loading
(
model
,
model_config
,
model_device
)
else
:
logger
.
warning_once
(
"model_device is None, skip calling process_weights_after_loading"
)
# Step R5 (workaround for cudagraph): restore the original quantized weights
# and do a copy_ of the currents weights to the original weights
updated_quantized_weights
=
dict
(
model
.
named_parameters
(
remove_duplicate
=
False
))
for
name
in
model
.
original_weights_rebuild_keys
:
if
name
in
original_quantized_weight_dict
:
original_quantized_weight
=
original_quantized_weight_dict
[
name
]
updated_quantized_weight
=
updated_quantized_weights
[
name
]
module_name
,
weight_name
=
name
.
rsplit
(
"."
,
1
)
module
=
named_modules
[
module_name
]
setattr
(
module
,
weight_name
,
original_quantized_weight
)
with
torch
.
no_grad
():
original_quantized_weight
.
copy_
(
updated_quantized_weight
)
del
original_quantized_weight_dict
del
named_modules
del
updated_quantized_weight
model
.
process_weights_after_loading_already_called
=
True
model
.
process_weights_after_loading_already_called
=
True
return
loaded_weights
return
updated_params
return
patched_model_load_weights
vllm/model_executor/model_loader/utils.py
View file @
da94c7c0
...
@@ -88,6 +88,14 @@ def initialize_model(
...
@@ -88,6 +88,14 @@ def initialize_model(
def
process_weights_after_loading
(
def
process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
)
->
None
:
)
->
None
:
if
getattr
(
model
,
"process_weights_after_loading_already_called"
,
False
):
# In case `process_weights_after_loading` is called multiple times
# we'll skip it at later times
logger
.
debug_once
(
"process_weights_after_loading already called for model %s"
,
model
)
return
# to avoid circular dependency
# to avoid circular dependency
from
vllm.model_executor.model_loader.online_quantization
import
(
from
vllm.model_executor.model_loader.online_quantization
import
(
maybe_save_metadata_and_attributes_for_weight_reloading
,
maybe_save_metadata_and_attributes_for_weight_reloading
,
...
...
vllm/model_executor/models/utils.py
View file @
da94c7c0
...
@@ -21,6 +21,9 @@ from vllm.logger import init_logger
...
@@ -21,6 +21,9 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizationConfig
,
)
)
from
vllm.model_executor.model_loader.online_quantization
import
(
support_quantized_model_reload_from_hp_weights
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
supports_any_eagle
from
vllm.model_executor.models.interfaces
import
supports_any_eagle
from
vllm.multimodal
import
NestedTensors
from
vllm.multimodal
import
NestedTensors
...
@@ -316,6 +319,7 @@ class AutoWeightsLoader:
...
@@ -316,6 +319,7 @@ class AutoWeightsLoader:
)
)
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
@
support_quantized_model_reload_from_hp_weights
def
load_weights
(
def
load_weights
(
self
,
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment