Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96846bb3
Unverified
Commit
96846bb3
authored
Jun 12, 2025
by
mobicham
Committed by
GitHub
Jun 12, 2025
Browse files
Fix TorchAOConfig skip layers (#19265)
Signed-off-by:
mobicham
<
hicham@mobiuslabs.com
>
parent
b6efafd9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
7 deletions
+72
-7
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+15
-0
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+57
-7
No files found.
tests/quantization/test_torchao.py
View file @
96846bb3
...
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
...
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
print
(
output
)
print
(
output
)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
    """Smoke-test loading a torchao-quantized checkpoint and generating text.

    The test passes when greedy generation for a single prompt returns a
    truthy result. It is skipped when ``TORCHAO_AVAILABLE`` is falsy.
    """
    # Reset TorchDynamo state before building the runner.
    torch._dynamo.reset()

    runner_kwargs = dict(
        model_name="mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao",
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location="cuda:0",
    )
    prompts = ["The capital of France is"]

    with vllm_runner(**runner_kwargs) as llm:
        completions = llm.generate_greedy(prompts, max_tokens=32)

    # An empty result means nothing was generated.
    assert completions
    print(completions)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/model_executor/layers/quantization/torchao.py
View file @
96846bb3
...
@@ -17,11 +17,30 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -17,11 +17,30 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def should_skip(prefix: str, skip_modules: list[str]) -> bool:
    """
    Return True when ``prefix`` is matched by any entry of ``skip_modules``.

    An entry matches if it equals ``prefix`` exactly, or if it lines up with
    one or more consecutive dot-separated components of ``prefix``. Partial
    component matches do not count.

    Examples:
        should_skip("model.model.layers.1.q_proj",
                    ["model.model.layers.1.q_proj"])                 -> True
        should_skip("model.model.layers.10.o_proj", ["o_proj"])      -> True
        should_skip("visual.model.layers.1.q_proj", ["visual"])      -> True
        should_skip("model.model.layers.1.q_proj", ["layers.1"])     -> True
        should_skip("model.model.layers.11.q_proj", ["layers.1"])    -> False
    """
    # Wrapping both sides in dots forces matches onto component boundaries,
    # so "layers.1" cannot match inside "layers.11".
    dotted_prefix = f".{prefix}."
    return any(
        entry == prefix or f".{entry}." in dotted_prefix
        for entry in skip_modules
    )
class
TorchAOConfig
(
QuantizationConfig
):
class
TorchAOConfig
(
QuantizationConfig
):
"""Config class for torchao."""
"""Config class for torchao."""
def
__init__
(
self
,
torchao_config
)
->
None
:
def
__init__
(
self
,
self
.
torchao_config
=
torchao_config
torchao_config
,
skip_modules
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
"""
"""
# TorchAO quantization relies on tensor subclasses. In order,
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
# to enable proper caching this needs standalone compile
...
@@ -36,6 +55,8 @@ class TorchAOConfig(QuantizationConfig):
...
@@ -36,6 +55,8 @@ class TorchAOConfig(QuantizationConfig):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
"""
self
.
torchao_config
=
torchao_config
self
.
skip_modules
=
skip_modules
or
[]
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
f
"TorchAOConfig(
{
self
.
torchao_config
}
)"
return
f
"TorchAOConfig(
{
self
.
torchao_config
}
)"
...
@@ -67,11 +88,28 @@ class TorchAOConfig(QuantizationConfig):
...
@@ -67,11 +88,28 @@ class TorchAOConfig(QuantizationConfig):
hf_config
=
cls
.
get_from_keys_or
(
config
,
[
"quant_type"
],
None
)
hf_config
=
cls
.
get_from_keys_or
(
config
,
[
"quant_type"
],
None
)
assert
hf_config
is
not
None
,
"quant_type must be specified"
assert
hf_config
is
not
None
,
"quant_type must be specified"
assert
(
len
(
hf_config
)
==
1
and
"default"
in
hf_config
assert
len
(
hf_config
)
==
1
and
"default"
in
hf_config
,
(
),
"Expected only one key 'default' in quant_type dictionary"
"Expected only one key 'default' in quant_type dictionary"
)
quant_type
=
hf_config
[
"default"
]
quant_type
=
hf_config
[
"default"
]
ao_config
=
config_from_dict
(
quant_type
)
ao_config
=
config_from_dict
(
quant_type
)
return
cls
(
ao_config
)
# Adds skipped modules defined in "modules_to_not_convert"
skip_modules
=
config
.
get
(
"modules_to_not_convert"
,
[])
or
[]
# Adds skipped modules defined in "module_fqn_to_config"
_data
=
quant_type
.
get
(
"_data"
,
{})
if
not
isinstance
(
_data
,
dict
):
_data
=
{}
module_fqn
=
_data
.
get
(
"module_fqn_to_config"
,
{})
if
not
isinstance
(
module_fqn
,
dict
):
module_fqn
=
{}
for
layer
,
layer_cfg
in
module_fqn
.
items
():
if
layer_cfg
is
None
:
skip_modules
.
append
(
layer
)
return
cls
(
ao_config
,
skip_modules
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
...
@@ -80,13 +118,16 @@ class TorchAOConfig(QuantizationConfig):
...
@@ -80,13 +118,16 @@ class TorchAOConfig(QuantizationConfig):
from
torchao.quantization
import
ModuleFqnToConfig
from
torchao.quantization
import
ModuleFqnToConfig
if
should_skip
(
prefix
,
self
.
skip_modules
):
return
UnquantizedLinearMethod
()
module_fqn
=
prefix
module_fqn
=
prefix
if
isinstance
(
self
.
torchao_config
,
ModuleFqnToConfig
):
if
isinstance
(
self
.
torchao_config
,
ModuleFqnToConfig
):
module_fqn_to_config
=
self
.
torchao_config
.
module_fqn_to_config
module_fqn_to_config
=
self
.
torchao_config
.
module_fqn_to_config
c
=
module_fqn_to_config
.
get
(
c
=
module_fqn_to_config
.
get
(
module_fqn
)
or
module_fqn_to_config
.
get
(
"_default"
,
None
)
module_fqn
)
or
module_fqn_to_config
.
get
(
"_default"
,
None
)
if
c
is
not
None
:
if
c
is
not
None
:
current_torchao_config
=
TorchAOConfig
(
c
)
current_torchao_config
=
TorchAOConfig
(
c
,
self
.
skip_modules
)
return
TorchAOLinearMethod
(
current_torchao_config
)
return
TorchAOLinearMethod
(
current_torchao_config
)
else
:
else
:
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
...
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
...
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""
"""
from
torchao.core.config
import
AOBaseConfig
from
torchao.core.config
import
AOBaseConfig
from
torchao.quantization
import
quantize_
from
torchao.quantization
import
quantize_
assert
isinstance
(
torchao_config
,
AOBaseConfig
),
f
"
{
torchao_config
}
"
assert
isinstance
(
torchao_config
,
AOBaseConfig
),
f
"
{
torchao_config
}
"
dummy_linear
=
torch
.
nn
.
Linear
(
param
.
shape
[
1
],
param
.
shape
[
0
],
bias
=
False
)
"""
Avoid real weight allocation for faster load, since we will
end up setting it to param.
"""
with
torch
.
device
(
"meta"
):
dummy_linear
=
torch
.
nn
.
Linear
(
param
.
shape
[
1
],
param
.
shape
[
0
],
bias
=
False
)
dummy_linear
.
weight
=
param
dummy_linear
.
weight
=
param
quantize_
(
dummy_linear
,
torchao_config
)
quantize_
(
dummy_linear
,
torchao_config
)
return
dummy_linear
.
weight
return
dummy_linear
.
weight
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment