Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3124680
Unverified
Commit
c3124680
authored
Oct 01, 2025
by
Jerry Zhang
Committed by
GitHub
Oct 01, 2025
Browse files
Support RL online quantization with torchao (#23014)
Signed-off-by:
Jerry Zhang
<
jerryzh168@gmail.com
>
parent
4134312b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
465 additions
and
16 deletions
+465
-16
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+121
-5
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+64
-8
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+29
-2
vllm/model_executor/model_loader/online_quantization.py
vllm/model_executor/model_loader/online_quantization.py
+217
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+8
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+26
-0
No files found.
tests/quantization/test_torchao.py
View file @
c3124680
...
...
@@ -20,7 +20,6 @@ def test_pre_quantized_model(vllm_runner):
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
32
)
assert
output
print
(
output
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
...
...
@@ -42,7 +41,6 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
max_tokens
=
32
)
assert
output
print
(
output
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
...
...
@@ -57,7 +55,6 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
max_tokens
=
32
)
assert
output
print
(
output
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
...
...
@@ -72,7 +69,6 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
max_tokens
=
32
)
assert
output
print
(
output
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
...
...
@@ -92,7 +88,127 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
max_tokens
=
32
)
assert
output
print
(
output
)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_on_the_fly_quant_config_dict_json(vllm_runner):
    """Exercise on-the-fly torchao quantization through the load_weights
    integration point, with the torchao config handed over as a JSON string
    in `hf_overrides`.
    """
    torch._dynamo.reset()

    # torchao is optional, so it is imported only once the test really runs
    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig, PerRow)

    fp8_per_row = Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow())
    serialized_config = json.dumps(config_to_dict(fp8_per_row))

    runner_kwargs = dict(
        model_name="facebook/opt-125m",
        dtype="bfloat16",
        pt_load_map_location="cuda:0",
        quantization="torchao",
        hf_overrides={"quantization_config_dict_json": serialized_config},
    )
    with vllm_runner(**runner_kwargs) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

        # only checks that generation produced something
        assert output
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_on_the_fly_quant_config_file(vllm_runner):
    """Testing on the fly quantization, load_weights integration point,
    with config file

    The torchao config is serialized to a temporary JSON file whose path is
    passed through `hf_overrides["quantization_config_file"]`. The file is
    removed again once the test finishes, whether it passed or not.
    """
    torch._dynamo.reset()
    model_name = "facebook/opt-125m"
    import json
    import os
    from tempfile import NamedTemporaryFile

    from torchao.core.config import config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig, PerRow)

    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

    # delete=False keeps the file on disk after the `with` block so that it
    # can be re-opened by path; leaving the block closes (and flushes) it.
    with NamedTemporaryFile(mode="w", delete=False) as f:
        f.write(json.dumps(config_to_dict(config)))
        config_file_name = str(f.name)

    try:
        hf_overrides = {"quantization_config_file": config_file_name}
        with vllm_runner(model_name=model_name,
                         dtype="bfloat16",
                         pt_load_map_location="cuda:0",
                         quantization="torchao",
                         hf_overrides=hf_overrides) as llm:
            output = llm.generate_greedy(["The capital of France is"],
                                         max_tokens=32)

            assert output
    finally:
        # delete=False means nobody else cleans this file up
        os.unlink(config_file_name)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_reload_weights():
    """Build an engine with `dummy` weights and torchao quantization, switch
    the load format to `auto`, reload the weights in place and check that
    generation yields non-empty text for every prompt.
    """
    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig, PerRow)

    from vllm import LLM, SamplingParams

    fp8_per_row = Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow())
    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        dtype="bfloat16",
        load_format="dummy",
        enforce_eager=True,
        quantization="torchao",
        hf_overrides={
            "quantization_config_dict_json":
            json.dumps(config_to_dict(fp8_per_row)),
        },
    )

    # Flip the load format from `dummy` to `auto` ...
    new_load_config = {"load_config": {"load_format": "auto"}}
    llm.collective_rpc("update_config", args=(new_load_config, ))
    # ... and then reload the real weights in place
    llm.collective_rpc("reload_weights")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    # Smoke check only: every request must come back with some text. To
    # eyeball the quality locally, print `request_output.prompt` next to
    # the generated text.
    for request_output in llm.generate(prompts, sampling_params):
        generated_text = request_output.outputs[0].text
        assert generated_text
if
__name__
==
"__main__"
:
...
...
vllm/model_executor/layers/quantization/torchao.py
View file @
c3124680
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):
def
__init__
(
self
,
torchao_config
,
skip_modules
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
skip_modules
:
Optional
[
list
[
str
]]
=
None
,
is_checkpoint_torchao_serialized
:
bool
=
False
)
->
None
:
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
...
...
@@ -58,9 +60,11 @@ class TorchAOConfig(QuantizationConfig):
super
().
__init__
()
self
.
torchao_config
=
torchao_config
self
.
skip_modules
=
skip_modules
or
[]
self
.
is_checkpoint_torchao_serialized
=
is_checkpoint_torchao_serialized
def
__repr__
(
self
)
->
str
:
return
f
"TorchAOConfig(
{
self
.
torchao_config
}
)"
return
f
"TorchAOConfig(
{
self
.
torchao_config
=
}
,
{
self
.
skip_modules
=
}
, "
\
f
"
{
self
.
is_checkpoint_torchao_serialized
=
}
)"
def
get_name
(
self
)
->
QuantizationMethods
:
return
"torchao"
...
...
@@ -74,7 +78,10 @@ class TorchAOConfig(QuantizationConfig):
@
staticmethod
def
get_config_filenames
()
->
list
[
str
]:
return
[
"config.json"
]
"""torchao doesn't require additional config files, we use
`config.json` from huggingface: `model_config.hf_config`
"""
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"TorchAOConfig"
:
...
...
@@ -87,6 +94,10 @@ class TorchAOConfig(QuantizationConfig):
"`pip install torchao>=0.10.0` to use torchao quantization."
)
from
err
quant_method
=
cls
.
get_from_keys_or
(
config
,
[
"quant_method"
],
None
)
is_checkpoint_torchao_serialized
=
(
quant_method
is
not
None
and
"torchao"
in
quant_method
)
hf_config
=
cls
.
get_from_keys_or
(
config
,
[
"quant_type"
],
None
)
assert
hf_config
is
not
None
,
"quant_type must be specified"
assert
len
(
hf_config
)
==
1
and
"default"
in
hf_config
,
(
...
...
@@ -110,7 +121,38 @@ class TorchAOConfig(QuantizationConfig):
if
layer_cfg
is
None
:
skip_modules
.
append
(
layer
)
return
cls
(
ao_config
,
skip_modules
)
return
cls
(
ao_config
,
skip_modules
,
is_checkpoint_torchao_serialized
)
@classmethod
def from_config_file(cls, config_file: str) -> "TorchAOConfig":
    """Initialize class from a config file. Example:
    ```
    config = (
        Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    )
    fn = "torchao_config.json"
    with open(fn, "w") as f:
        f.write(json.dumps(config_to_dict(config)))
    ```

    Args:
        config_file: path to a JSON file holding the serialized torchao
            config dict.

    Returns:
        Whatever `cls.from_config` builds from the loaded dict.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # A freshly opened file is already positioned at offset 0, so the JSON
    # can be parsed straight from the handle. The encoding is pinned so the
    # result does not depend on the locale default.
    with open(config_file, encoding="utf-8") as f:
        config_dict = json.load(f)

    # wrap the dict in the layout that `from_config` reads ("quant_type"
    # with a single "default" entry)
    hf_config = {"quant_type": {"default": config_dict}}
    return cls.from_config(hf_config)
@classmethod
def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
    """Initialize class from a JSON string of a torchao config dict.

    The string is expected to come from
        json.dumps(config_to_dict(torchao_config_object))
    where `torchao_config_object` is some AOBaseConfig object.
    """
    # same {"quant_type": {"default": ...}} wrapping that `from_config` reads
    wrapped = {"quant_type": {"default": json.loads(config_dict_json)}}
    return cls.from_config(wrapped)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
...
...
@@ -128,7 +170,9 @@ class TorchAOConfig(QuantizationConfig):
c
=
module_fqn_to_config
.
get
(
module_fqn
)
or
module_fqn_to_config
.
get
(
"_default"
,
None
)
if
c
is
not
None
:
current_torchao_config
=
TorchAOConfig
(
c
,
self
.
skip_modules
)
current_torchao_config
=
TorchAOConfig
(
c
,
self
.
skip_modules
,
self
.
is_checkpoint_torchao_serialized
)
return
TorchAOLinearMethod
(
current_torchao_config
)
else
:
return
UnquantizedLinearMethod
()
...
...
@@ -197,8 +241,9 @@ class TorchAOLinearMethod(LinearMethodBase):
),
requires_grad
=
False
,
)
weight
=
torchao_quantize_param_data
(
weight
,
self
.
quant_config
.
torchao_config
)
if
self
.
quant_config
.
is_checkpoint_torchao_serialized
:
weight
=
torchao_quantize_param_data
(
weight
,
self
.
quant_config
.
torchao_config
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
...
...
@@ -212,3 +257,14 @@ class TorchAOLinearMethod(LinearMethodBase):
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Quantize `layer.weight` on the fly and re-register it on the layer.

    Nothing happens when the checkpoint is already torchao serialized.
    """
    cfg = self.quant_config
    if not cfg.is_checkpoint_torchao_serialized:
        # high precision checkpoint: quantize now, then put the result back
        # under the same parameter name with the same weight attributes
        quantized = torchao_quantize_param_data(layer.weight,
                                                cfg.torchao_config)
        set_weight_attrs(quantized, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", quantized)
vllm/model_executor/model_loader/default_loader.py
View file @
c3124680
...
...
@@ -261,8 +261,35 @@ class DefaultModelLoader(BaseModelLoader):
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization
=
not
getattr
(
model
,
"weight_metadata_and_attr_saved"
,
False
)
if
model_config
.
quantization
is
None
:
# model is not quantized
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
))
elif
offline_quantization_or_first_run_of_online_quantization
:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
))
else
:
# to avoid circular dependency
from
vllm.model_executor.model_loader.online_quantization
import
(
load_weights_and_online_quantize
)
# subsequent runs of weight loading with online
# quantization
loaded_weights
=
load_weights_and_online_quantize
(
self
,
model
,
model_config
)
self
.
counter_after_loading_weights
=
time
.
perf_counter
()
logger
.
info
(
"Loading weights took %.2f seconds"
,
...
...
vllm/model_executor/model_loader/online_quantization.py
0 → 100644
View file @
c3124680
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
types
import
torch
from
torch
import
nn
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.default_loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.utils
import
(
process_weights_after_loading
)
logger
=
init_logger
(
__name__
)
# Notes for Online Quantization
# In terms of state of checkpoints, quantization config and their
# correspondence to online quantization:
# | Use Case | Checkpoints | model_config.quantization |
# | no quant | high precision | None |
# | offline quant | quantized | fp8, torchao etc. |
# | online quant | high precision | torchao etc. |
#
# The process for loading non-quantized checkpoint
# 1. load non-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
#
# The process for loading offline quantized checkpoint
# 1. load offline-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
# The process for unquantized model reloading
# (repeated run in RL training loop)
# first run
# UI1. load_weights: load bfloat16 weights
# UI2. process_weights_after_loading: any additional post processing
# subsequent run
# UC1: load_weights: load bfloat16 weights
# (shouldn't be any issues since we didn't change any attributes
# of the weights)
# UC2: process_weights_after_loading: any additional post processing
# The process for weight reloading with online quantization
# (repeated run in RL training loop)
# first run
# I1. load_weights: load bfloat16 weights
# I2. process_weights_after_loading:
# record weight metadata and attributes for R1 and R2
# quantize weights to fp8
# subsequent run
# (beginning model weight is in fp8)
# load_weights:
# R1. restore bfloat16 model weight metadata
# R2. restore the model weight attributes
# R3. reload bfloat16 weights
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# process_weights_after_loading (if called):
# this will be skipped since it has already run in
# load_weights
def maybe_save_metadata_and_attributes_for_weight_reloading(
        model: nn.Module, model_config: ModelConfig):
    """Record what is needed to rebuild the high precision parameters later.

    This is step I2 of online quantization (see the notes at the top of this
    file). It stores on `model`:
      * `original_weights_rebuild_keys`: shape, dtype and device of every
        parameter
      * `recorded_weight_attr`: the entries found in every parameter's
        `__dict__`
      * `weight_metadata_and_attr_saved`: set to True once both are stored

    It returns without saving anything when the quantization method is not
    torchao, when `process_weights_after_loading_already_called` is set on
    the model, when the quant config has no
    `is_checkpoint_torchao_serialized` attribute or has it set (pre-quantized
    checkpoint), or when the data was already saved.
    """
    # on the fly quantization is currently only supported for torchao
    if model_config.quantization != "torchao":
        return

    if getattr(model, "process_weights_after_loading_already_called", False):
        # `process_weights_after_loading` may be called more than once,
        # later calls are skipped here
        logger.warning(
            "process_weights_after_loading already called for model %s",
            model)
        return

    from vllm.model_executor.model_loader.weight_utils import get_quant_config
    quant_config = get_quant_config(model_config, None)

    # A checkpoint that is already torchao serialized is the pre-quantized
    # case, so there is nothing to save. Otherwise the metadata and
    # attributes recorded below are used to restore the bfloat16 weights
    # during the reload steps (R1 and R2).
    if (not hasattr(quant_config, "is_checkpoint_torchao_serialized")
            or quant_config.is_checkpoint_torchao_serialized):
        return

    # only save during initialization
    if getattr(model, "weight_metadata_and_attr_saved", False):
        return

    # 1. floating point metadata (shape, dtype, device) of every parameter,
    #    used to re-create the high precision parameters before reloading
    assert not hasattr(model, "original_weights_rebuild_keys")
    model.original_weights_rebuild_keys = {
        param_name: {
            "shape": tensor.shape,
            "dtype": tensor.dtype,
            "device": tensor.device,
        }
        for param_name, tensor in model.named_parameters()
    }

    # 2. weight attributes (loader functions etc.) so they can be put back
    #    when the weights are reloaded
    #    structure: {"weight_name": {"weight_attr_key": attr}}
    assert not hasattr(model, "recorded_weight_attr")
    model.recorded_weight_attr = {}
    for param_name, param in model.named_parameters():
        per_param = model.recorded_weight_attr[param_name] = {}
        for key in param.__dict__:
            if not hasattr(param, key):
                continue
            value = getattr(param, key)
            # a method bound to this very parameter is stored as its
            # underlying function object; everything else is stored as is
            bound_to_param = (callable(value) and hasattr(value, "__self__")
                              and value.__self__ is param)
            per_param[key] = value.__func__ if bound_to_param else value

    # mark the metadata and attributes saved so this does not run again
    model.weight_metadata_and_attr_saved = True
def
_bond_method_to_cls
(
func
,
obj
):
if
hasattr
(
func
,
"__self__"
)
or
not
callable
(
func
):
# If the function is already bound to an instance, return it as is
return
func
else
:
return
types
.
MethodType
(
func
,
obj
)
def load_weights_and_online_quantize(model_loader: DefaultModelLoader,
                                     model: nn.Module,
                                     model_config: ModelConfig) -> set[str]:
    """Reload high precision weights into a model and quantize them again.

    Implements steps R1-R4 of the notes at the top of this file, for the
    subsequent runs of weight loading with online quantization (right now
    only enabled for torchao):
      R1. re-create empty parameters from the recorded shape/dtype/device
      R2. put the recorded weight attributes back on those parameters
      R3. load the weights through `model.load_weights`
      R4. run `process_weights_after_loading` and set
          `model.process_weights_after_loading_already_called` to True

    Args:
        model_loader: loader whose `get_all_weights` supplies the weights.
        model: model carrying `original_weights_rebuild_keys` and
            `recorded_weight_attr`.
        model_config: config of the model; its quantization must be
            "torchao".

    Returns:
        The value returned by `model.load_weights`.

    Raises:
        AssertionError: if the quantization is not "torchao" or the recorded
            parameter devices differ.
    """
    # TODO: Add fp8 support
    assert model_config.quantization == "torchao", "online " \
        "quantization is only enabled for torchao currently"
    # TODO: use create_weights to restore the weights to original state

    existing_param_names = dict(
        model.named_parameters(remove_duplicate=False)).keys()
    named_modules = dict(model.named_modules(remove_duplicate=False))
    model_device = None

    # Step R1: restore the quantized weights to the original bfloat16
    # weights, with original metadata (shape, dtype, device), so that
    # bfloat16 weights can be loaded properly
    for name, d in model.original_weights_rebuild_keys.items():
        _shape = d["shape"]
        _dtype = d["dtype"]
        _device = d["device"]
        if model_device is not None:
            assert model_device == _device, "Expecting all weights " \
                "to be in the same device for now, got both: " \
                f"{model_device} and {_device}"
        else:
            model_device = _device

        if name in existing_param_names:
            # rpartition also copes with a parameter registered directly on
            # the root module: a name without "." gives module_name == "",
            # which is the key `named_modules()` uses for the root module
            module_name, _, weight_name = name.rpartition(".")
            module = named_modules[module_name]
            setattr(
                module, weight_name,
                torch.nn.Parameter(
                    torch.empty(_shape, dtype=_dtype, device=_device)))

    # Step R2: recover the parameter attributes to the state before the
    # first loading
    # recorded_weight_attr is
    # {"weight_name": {"weight_attr_key": attr}}
    # e.g.
    # {
    #     "layer.0.weight": {
    #         "weight_loader": weight_loader_function_object,
    #         "input_dim": 0, ...
    #     },
    #     "layer.1.weight": ...,
    # }
    for full_weight_name, weight_attr_dict in \
        model.recorded_weight_attr.items():
        if not weight_attr_dict:
            # nothing recorded, so the parameter is not even looked up
            continue
        # the module and weight do not depend on the attribute, so look
        # them up once per weight
        module_name, _, weight_name = full_weight_name.rpartition(".")
        weight = getattr(named_modules[module_name], weight_name)
        for attr_name, attr in weight_attr_dict.items():
            if not hasattr(weight, attr_name):
                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))

    # Step R3: reload bfloat16 / high precision weights
    loaded_weights = model.load_weights(
        model_loader.get_all_weights(model_config, model))

    # Step R4: online quantize the weights
    # The flag is cleared before processing the weights manually and set
    # afterwards, marking the processing as already done.
    model.process_weights_after_loading_already_called = False
    process_weights_after_loading(model, model_config, model_device)
    model.process_weights_after_loading_already_called = True
    return loaded_weights
vllm/model_executor/model_loader/utils.py
View file @
c3124680
...
...
@@ -95,6 +95,13 @@ def initialize_model(
def
process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
)
->
None
:
# to avoid circular dependency
from
vllm.model_executor.model_loader.online_quantization
import
(
maybe_save_metadata_and_attributes_for_weight_reloading
)
maybe_save_metadata_and_attributes_for_weight_reloading
(
model
,
model_config
)
for
_
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
QKVCrossParallelLinear
):
# NOTE(Isotr0py): special case for cross QKV layer because
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
c3124680
...
...
@@ -246,8 +246,34 @@ def get_quant_config(model_config: ModelConfig,
# compressed-tensors uses a compression_config
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"compression_config"
,
None
)
if
hf_quant_config
is
not
None
:
return
quant_cls
.
from_config
(
hf_quant_config
)
# if hf_quant_config is None, we will try to get config from
# hf_overrides
hf_overrides
=
model_config
.
hf_overrides
quantization_config_file
=
hf_overrides
.
get
(
"quantization_config_file"
,
None
)
if
quantization_config_file
is
not
None
:
if
hasattr
(
quant_cls
,
"from_config_file"
):
return
quant_cls
.
from_config_file
(
quantization_config_file
)
else
:
raise
NotImplementedError
(
"from_config_file is specified in hf_override config, "
"but quant_cls.from_config_file is not implemented in "
f
"
{
quant_cls
}
"
)
quantization_config_json
=
hf_overrides
.
get
(
"quantization_config_dict_json"
,
None
)
if
quantization_config_json
is
not
None
:
if
hasattr
(
quant_cls
,
"from_config_dict_json"
):
return
quant_cls
.
from_config_dict_json
(
quantization_config_json
)
else
:
raise
NotImplementedError
(
"from_config_dict_json is specified in hf_override config, "
"but quant_cls.from_config_dict_json is not implemented in "
f
"
{
quant_cls
}
"
)
# Inflight BNB quantization
if
model_config
.
quantization
==
"bitsandbytes"
:
return
quant_cls
.
from_config
({})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment