Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
03c4c4aa
Unverified
Commit
03c4c4aa
authored
Nov 04, 2025
by
Jerry Zhang
Committed by
GitHub
Nov 04, 2025
Browse files
Support using Int4PreshuffledTensor after loading (#26066)
Signed-off-by:
Jerry Zhang
<
jerryzh168@gmail.com
>
parent
2ec401bc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
208 additions
and
4 deletions
+208
-4
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+144
-2
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+64
-2
No files found.
tests/quantization/test_torchao.py
View file @
03c4c4aa
...
@@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
...
@@ -99,7 +99,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
def
test_on
_the_fly
_quant_config_dict_json
(
vllm_runner
):
def
test_on
line
_quant_config_dict_json
(
vllm_runner
):
"""Testing on the fly quantization, load_weights integration point,
"""Testing on the fly quantization, load_weights integration point,
with config dict serialized to json string
with config dict serialized to json string
"""
"""
...
@@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):
...
@@ -133,7 +133,7 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner):
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
def
test_on
_the_fly
_quant_config_file
(
vllm_runner
):
def
test_on
line
_quant_config_file
(
vllm_runner
):
"""Testing on the fly quantization, load_weights integration point,
"""Testing on the fly quantization, load_weights integration point,
with config file
with config file
"""
"""
...
@@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
...
@@ -252,6 +252,148 @@ def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
)
as
llm
:
)
as
llm
:
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
4
)
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
4
)
assert
output
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own trailing space.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel(vllm_runner, monkeypatch):
    """We load a model with Int4Tensor (plain format) linear weights
    and verify that the weight is updated to Int4PreshuffledTensor
    after loading in vllm.

    The test also checks that the attributes attached to the original weight
    (``requires_grad``, ``input_dim``, ``output_dim`` and ``weight_loader``)
    are still present on the first decoder layer's ``qkv_proj`` weight after
    loading.
    """
    from torchao.quantization import Int4PreshuffledTensor
    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90

    torch._dynamo.reset()
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev"
    # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
    # have meta kernel implemented yet, can remove this flag after that is implemented
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location="cuda:0",
        enforce_eager=True,
    ) as llm:

        def has_int4_preshuffled_tensor_weight(model):
            # True when the first decoder layer's qkv_proj weight is an
            # Int4PreshuffledTensor.
            return isinstance(
                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
                Int4PreshuffledTensor,
            )

        def get_weight_attrs(model):
            # Collect the attributes that must still be on the weight after
            # loading.
            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
            return [
                weight.requires_grad,
                weight.input_dim,
                weight.output_dim,
                hasattr(weight, "weight_loader"),
            ]

        llm_engine = llm.get_llm().llm_engine
        has_int4_preshuffled_tensor = any(
            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
        )
        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]

        # making sure we are using Int4PreshuffledTensor on H100 GPU, when
        # fbgemm_gpu_genai
        # library is installed, otherwise it should be using Int4Tensor
        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
            assert has_int4_preshuffled_tensor
        else:
            assert not has_int4_preshuffled_tensor

        # [requires_grad, input_dim, output_dim, has weight_loader]
        assert weight_attrs == [False, 1, 0, True]
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

        assert output
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own trailing space.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_int4wo_model_running_preshuffled_kernel_online_quant(
    vllm_runner, monkeypatch
):
    """We load a bf16 model and online quantize the model to int4, then verify that
    the weights are updated to Int4PreshuffledTensor after online quantization.

    The quantization config is passed through ``hf_overrides`` as a JSON
    string under the ``quantization_config_dict_json`` key. The test also
    checks that ``requires_grad``, ``input_dim``, ``output_dim`` and
    ``weight_loader`` are still present on the first decoder layer's
    ``qkv_proj`` weight afterwards.
    """
    from torchao.quantization import Int4PreshuffledTensor
    from torchao.utils import _is_fbgemm_gpu_genai_available, is_sm_at_least_90

    torch._dynamo.reset()
    model_name = "facebook/opt-125m"

    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import Int4WeightOnlyConfig

    torchao_quant_config = Int4WeightOnlyConfig(
        group_size=128, int4_packing_format="plain"
    )
    hf_overrides = {
        "quantization_config_dict_json": json.dumps(
            config_to_dict(torchao_quant_config)
        )
    }

    # Note: using enforce_eager=True because the `bf16i4bf16_shuffled` doesn't
    # have meta kernel implemented yet, can remove this flag after that is implemented
    with vllm_runner(
        model_name=model_name,
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location="cuda:0",
        hf_overrides=hf_overrides,
        enforce_eager=True,
    ) as llm:

        def has_int4_preshuffled_tensor_weight(model):
            # True when the first decoder layer's qkv_proj weight is an
            # Int4PreshuffledTensor.
            return isinstance(
                model.model.decoder.layers[0].self_attn.qkv_proj.weight,
                Int4PreshuffledTensor,
            )

        def get_weight_attrs(model):
            # Collect the attributes that must still be on the weight after
            # online quantization.
            weight = model.model.decoder.layers[0].self_attn.qkv_proj.weight
            return [
                weight.requires_grad,
                weight.input_dim,
                weight.output_dim,
                hasattr(weight, "weight_loader"),
            ]

        llm_engine = llm.get_llm().llm_engine
        has_int4_preshuffled_tensor = any(
            llm_engine.apply_model(has_int4_preshuffled_tensor_weight)
        )
        weight_attrs = llm_engine.apply_model(get_weight_attrs)[0]

        # making sure we are using Int4PreshuffledTensor on H100 GPU, when
        # fbgemm_gpu_genai
        # library is installed, otherwise it should be using Int4Tensor
        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
            assert has_int4_preshuffled_tensor
        else:
            assert not has_int4_preshuffled_tensor

        # [requires_grad, input_dim, output_dim, has weight_loader]
        assert weight_attrs == [False, 1, 0, True]
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

        assert output
assert
output
...
...
vllm/model_executor/layers/quantization/torchao.py
View file @
03c4c4aa
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
importlib
import
json
import
json
import
types
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
...
@@ -27,6 +28,39 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -27,6 +28,39 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def _bond_method_to_cls(func, obj):
    """Return ``func`` bound to ``obj`` as a method when binding applies.

    ``func`` is returned unchanged when it already exposes ``__self__`` or
    when it is not callable. Otherwise the result is
    ``types.MethodType(func, obj)``.
    """
    needs_binding = not hasattr(func, "__self__") and callable(func)
    if needs_binding:
        return types.MethodType(func, obj)
    # Already bound to something, or a plain value: pass it through untouched.
    return func
def _get_weight_attrs(param):
    """Snapshot the attributes stored in ``param.__dict__``.

    The result maps each attribute name to its current value so the
    attributes can be re-attached to another object later. A method bound to
    ``param`` itself is recorded as its underlying function
    (``attr.__func__``); every other value is recorded as-is.
    """
    snapshot = {}
    for name in param.__dict__:
        if not hasattr(param, name):
            continue
        value = getattr(param, name)
        # Only a method bound to this very instance is unwrapped, so the
        # snapshot does not keep a reference back to ``param`` through it.
        bound_to_param = (
            callable(value)
            and hasattr(value, "__self__")
            and param is value.__self__
        )
        snapshot[name] = value.__func__ if bound_to_param else value
    return snapshot
def _restore_weight_attrs(param, recorded_weight_attr):
    """Re-attach recorded attributes to ``param`` without overwriting any.

    For each ``name -> value`` entry in ``recorded_weight_attr``, the value
    is set on ``param`` only when ``param`` does not already have an
    attribute of that name. Values are passed through
    ``_bond_method_to_cls`` first, which binds eligible callables to
    ``param``.
    """
    for name, value in recorded_weight_attr.items():
        if hasattr(param, name):
            # Attributes already present on the new parameter take precedence.
            continue
        rebound = _bond_method_to_cls(value, param)
        setattr(param, name, rebound)
def
torchao_version_at_least
(
torchao_version
:
str
)
->
bool
:
def
torchao_version_at_least
(
torchao_version
:
str
)
->
bool
:
if
find_spec
(
"torchao"
):
if
find_spec
(
"torchao"
):
try
:
try
:
...
@@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
...
@@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
return
False
return
False
# Use torchao's converter when the installed torchao reports version 0.15.0 or
# newer; otherwise fall back to an identity function so call sites do not have
# to branch on the torchao version.
if torchao_version_at_least("0.15.0"):
    from torchao.prototype.tensor_conversion.api import (
        convert_to_packed_tensor_based_on_current_hardware,
    )
else:

    def convert_to_packed_tensor_based_on_current_hardware(t):
        """Fallback used when torchao's converter is not imported.

        Returns ``t`` unchanged.
        """
        return t
class
TorchAOConfig
(
QuantizationConfig
):
class
TorchAOConfig
(
QuantizationConfig
):
"""Config class for torchao."""
"""Config class for torchao."""
...
@@ -307,12 +349,32 @@ class TorchAOLinearMethod(LinearMethodBase):
...
@@ -307,12 +349,32 @@ class TorchAOLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
self
.
quant_config
.
is_checkpoint_torchao_serialized
:
if
self
.
quant_config
.
is_checkpoint_torchao_serialized
:
if
not
hasattr
(
layer
,
"weight"
):
return
# record attributes attached to the weight, so we can
# recover later
recorded_weight_attr
=
_get_weight_attrs
(
layer
.
weight
)
layer
.
weight
=
Parameter
(
convert_to_packed_tensor_based_on_current_hardware
(
layer
.
weight
),
requires_grad
=
layer
.
weight
.
requires_grad
,
)
_restore_weight_attrs
(
layer
.
weight
,
recorded_weight_attr
)
return
return
# quantize the weight
on the fly
if the checkpoint is not already
#
online
quantize the weight if the checkpoint is not already
# quantized by torchao
# quantized by torchao
recorded_weight_attr
=
_get_weight_attrs
(
layer
.
weight
)
weight
=
torchao_quantize_param_data
(
weight
=
torchao_quantize_param_data
(
layer
.
weight
,
self
.
quant_config
.
torchao_config
layer
.
weight
,
self
.
quant_config
.
torchao_config
)
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
weight
=
torch
.
nn
.
Parameter
(
convert_to_packed_tensor_based_on_current_hardware
(
weight
),
weight
.
requires_grad
,
)
_restore_weight_attrs
(
weight
,
recorded_weight_attr
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment