Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8134bea
Unverified
Commit
c8134bea
authored
Jun 05, 2025
by
Jerry Zhang
Committed by
GitHub
Jun 05, 2025
Browse files
Fix AOPerModuleConfig name changes (#18869)
Signed-off-by:
Jerry Zhang
<
jerryzh168@gmail.com
>
parent
cb6d572e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
5 deletions
+25
-5
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-0
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+3
-3
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+19
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
c8134bea
...
...
@@ -424,6 +424,9 @@ steps:
-
vllm/model_executor/layers/quantization
-
tests/quantization
commands
:
# temporary install here since we need nightly, will move to requirements/test.in
# after torchao 0.12 release
-
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
-
VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
-
label
:
LM Eval Small Models
# 53min
...
...
tests/quantization/test_torchao.py
View file @
c8134bea
...
...
@@ -13,7 +13,7 @@ TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
def
test_pre_quantized_model
(
vllm_runner
):
with
vllm_runner
(
"drisspg/f
loat8_dynamic_act_float8_weight
-opt-125m"
,
with
vllm_runner
(
"drisspg/f
p8
-opt-125m"
,
quantization
=
"torchao"
,
dtype
=
"bfloat16"
,
enforce_eager
=
True
)
as
llm
:
...
...
@@ -30,10 +30,10 @@ def test_pre_quantized_model(vllm_runner):
"cuda:0"
,
# {"": "cuda"},
])
def
test_opt_125m_int
4
wo_model_loading_with_params
(
vllm_runner
,
def
test_opt_125m_int
8
wo_model_loading_with_params
(
vllm_runner
,
pt_load_map_location
):
torch
.
_dynamo
.
reset
()
model_name
=
"jerryzh168/opt-125m-int
4
wo"
model_name
=
"jerryzh168/opt-125m-int
8
wo
-partial-quant
"
with
vllm_runner
(
model_name
=
model_name
,
quantization
=
"torchao"
,
dtype
=
"bfloat16"
,
...
...
vllm/model_executor/layers/quantization/torchao.py
View file @
c8134bea
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
...
@@ -13,12 +14,28 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class
TorchAOConfig
(
QuantizationConfig
):
"""Config class for torchao."""
def
__init__
(
self
,
torchao_config
)
->
None
:
self
.
torchao_config
=
torchao_config
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
if is_torch_equal_or_newer("2.8.0"):
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
logger.info(
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
# TODO: remove after the torch dependency is updated to 2.8
if is_torch_equal_or_newer(
"2.7.0") and not is_torch_equal_or_newer("2.8.0"):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
def
__repr__
(
self
)
->
str
:
return
f
"TorchAOConfig(
{
self
.
torchao_config
}
)"
...
...
@@ -61,10 +78,10 @@ class TorchAOConfig(QuantizationConfig):
if
not
isinstance
(
layer
,
LinearBase
):
return
None
from
torchao.quantization
import
AOPer
ModuleConfig
from
torchao.quantization
import
Module
FqnTo
Config
module_fqn
=
prefix
if
isinstance
(
self
.
torchao_config
,
AOPer
ModuleConfig
):
if
isinstance
(
self
.
torchao_config
,
Module
FqnTo
Config
):
module_fqn_to_config
=
self
.
torchao_config
.
module_fqn_to_config
c
=
module_fqn_to_config
.
get
(
module_fqn
)
or
module_fqn_to_config
.
get
(
"_default"
,
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment