Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b32260ab
Unverified
Commit
b32260ab
authored
Oct 07, 2025
by
liangel-02
Committed by
GitHub
Oct 07, 2025
Browse files
[torchao] safetensors integration (#25969)
Signed-off-by:
Angel Li
<
liangel@meta.com
>
parent
f80e7866
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
0 deletions
+60
-0
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+17
-0
vllm/config/load.py
vllm/config/load.py
+4
-0
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+15
-0
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+5
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+19
-0
No files found.
tests/quantization/test_torchao.py
View file @
b32260ab
...
...
@@ -216,5 +216,22 @@ def test_reload_weights():
# print("-" * 60)
@
pytest
.
mark
.
skipif
(
not
TORCHAO_AVAILABLE
,
reason
=
"torchao is not available"
)
@
pytest
.
mark
.
skip
(
reason
=
"since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def
test_opt_125m_float8_weight_only_safetensors_model_loading_with_params
(
vllm_runner
):
torch
.
_dynamo
.
reset
()
model_name
=
(
"torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
)
with
vllm_runner
(
model_name
=
model_name
,
dtype
=
"bfloat16"
)
as
llm
:
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
32
)
assert
output
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
vllm/config/load.py
View file @
b32260ab
...
...
@@ -59,6 +59,10 @@ class LoadConfig:
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
- "torchao": Weights are loaded in upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors.
Needs torchao >= 0.14.0
"""
model_loader_extra_config
:
Union
[
dict
,
TensorizerConfig
]
=
field
(
default_factory
=
dict
...
...
vllm/model_executor/layers/quantization/torchao.py
View file @
b32260ab
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
json
from
importlib.util
import
find_spec
from
typing
import
Any
,
Optional
import
torch
import
torch.nn.functional
as
F
from
packaging
import
version
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
...
...
@@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
def
torchao_version_at_least
(
torchao_version
:
str
)
->
bool
:
if
find_spec
(
"torchao"
):
try
:
if
version
.
parse
(
importlib
.
metadata
.
version
(
"torchao"
))
>=
version
.
parse
(
torchao_version
):
return
True
except
(
ImportError
,
version
.
InvalidVersion
):
return
False
return
False
def
should_skip
(
prefix
:
str
,
skip_modules
:
list
[
str
])
->
bool
:
"""
Robust skipping logic:
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
b32260ab
...
...
@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
ModelConfig
from
vllm.config.load
import
LoadConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.torchao
import
torchao_version_at_least
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
...
...
@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
)
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
if
model_config
.
quantization
==
"torchao"
and
torchao_version_at_least
(
"0.14.0"
):
self
.
load_config
.
safetensors_load_strategy
=
"torchao"
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
b32260ab
...
...
@@ -54,6 +54,8 @@ except ImportError:
SafeTensorsFileLoader
=
fastsafetensors
.
placeholder_attr
(
"SafeTensorsFileLoader"
)
SingleGroup
=
fastsafetensors
.
placeholder_attr
(
"SingleGroup"
)
from
vllm.model_executor.layers.quantization.torchao
import
torchao_version_at_least
logger
=
init_logger
(
__name__
)
# use system-level temp directory for file locks, so that multiple users
...
...
@@ -602,6 +604,23 @@ def safetensors_weights_iterator(
with
open
(
st_file
,
"rb"
)
as
f
:
state_dict
=
load
(
f
.
read
())
yield
from
state_dict
.
items
()
elif
safetensors_load_strategy
==
"torchao"
:
if
not
torchao_version_at_least
(
"0.14.0"
):
raise
ValueError
(
"Please use torchao version >= 0.14.0
\
to load torchao safetensors checkpoint"
)
from
torchao.prototype.safetensors.safetensors_support
import
(
unflatten_tensor_state_dict
,
)
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
state_dict
=
{}
for
name
in
f
.
keys
():
# noqa: SIM118
state_dict
[
name
]
=
f
.
get_tensor
(
name
)
metadata
=
f
.
metadata
()
updated_state_dict
=
unflatten_tensor_state_dict
(
state_dict
,
metadata
)
yield
from
updated_state_dict
.
items
()
else
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment