Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b32260ab
Unverified
Commit
b32260ab
authored
Oct 07, 2025
by
liangel-02
Committed by
GitHub
Oct 07, 2025
Browse files
[torchao] safetensors integration (#25969)
Signed-off-by:
Angel Li
<
liangel@meta.com
>
parent
f80e7866
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
0 deletions
+60
-0
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+17
-0
vllm/config/load.py
vllm/config/load.py
+4
-0
vllm/model_executor/layers/quantization/torchao.py
vllm/model_executor/layers/quantization/torchao.py
+15
-0
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+5
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+19
-0
No files found.
tests/quantization/test_torchao.py
View file @
b32260ab
...
@@ -216,5 +216,22 @@ def test_reload_weights():
...
@@ -216,5 +216,22 @@ def test_reload_weights():
# print("-" * 60)
# print("-" * 60)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated verbatim, so each fragment must carry
    # its own trailing space; the first one previously did not.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(
    vllm_runner,
):
    """Load a torchao Float8 weight-only safetensors checkpoint and generate.

    The test passes when greedy generation on a single prompt returns a
    truthy result. It is currently skipped unconditionally (see the skip
    reason above).
    """
    # Clear dynamo state before loading the model.
    torch._dynamo.reset()
    model_name = (
        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
    )
    with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)

        assert output
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/config/load.py
View file @
b32260ab
...
@@ -59,6 +59,10 @@ class LoadConfig:
...
@@ -59,6 +59,10 @@ class LoadConfig:
This is recommended for models on network filesystems (e.g., Lustre, NFS)
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
initialization. However, it uses more CPU RAM.
- "torchao": Weights are loaded upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors.
Requires torchao >= 0.14.0.
"""
"""
model_loader_extra_config
:
Union
[
dict
,
TensorizerConfig
]
=
field
(
model_loader_extra_config
:
Union
[
dict
,
TensorizerConfig
]
=
field
(
default_factory
=
dict
default_factory
=
dict
...
...
vllm/model_executor/layers/quantization/torchao.py
View file @
b32260ab
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
json
import
json
from
importlib.util
import
find_spec
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
packaging
import
version
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def torchao_version_at_least(torchao_version: str) -> bool:
    """Return whether the installed torchao is at least ``torchao_version``.

    Args:
        torchao_version: Minimum required version string, e.g. ``"0.14.0"``.

    Returns:
        ``True`` only when torchao is importable and its installed
        distribution version compares ``>=`` to ``torchao_version``.
        ``False`` when torchao is not installed, its distribution metadata
        cannot be found, or either version string cannot be parsed.

    NOTE(review): the comparison uses ``packaging.version`` ordering, under
    which pre-releases such as ``0.14.0.dev...`` sort *below* ``0.14.0`` --
    confirm this is the intended behavior for nightly builds.
    """
    # ``import importlib`` alone does not guarantee that the ``metadata``
    # submodule is loaded; import it explicitly so the attribute access below
    # cannot fail with an (uncaught) AttributeError.
    import importlib.metadata

    if find_spec("torchao") is None:
        return False

    try:
        installed = version.parse(importlib.metadata.version("torchao"))
        required = version.parse(torchao_version)
    except (ImportError, version.InvalidVersion):
        # PackageNotFoundError derives from ImportError, so missing
        # distribution metadata is reported as "not new enough".
        return False

    return installed >= required
def
should_skip
(
prefix
:
str
,
skip_modules
:
list
[
str
])
->
bool
:
def
should_skip
(
prefix
:
str
,
skip_modules
:
list
[
str
])
->
bool
:
"""
"""
Robust skipping logic:
Robust skipping logic:
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
b32260ab
...
@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
...
@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.config.load
import
LoadConfig
from
vllm.config.load
import
LoadConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.torchao
import
torchao_version_at_least
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_safetensors_index_file_from_hf
,
...
@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
)
)
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
if
model_config
.
quantization
==
"torchao"
and
torchao_version_at_least
(
"0.14.0"
):
self
.
load_config
.
safetensors_load_strategy
=
"torchao"
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# if we don't have `model.weight_metadata_and_attr_saved` defined and
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
b32260ab
...
@@ -54,6 +54,8 @@ except ImportError:
...
@@ -54,6 +54,8 @@ except ImportError:
SafeTensorsFileLoader
=
fastsafetensors
.
placeholder_attr
(
"SafeTensorsFileLoader"
)
SafeTensorsFileLoader
=
fastsafetensors
.
placeholder_attr
(
"SafeTensorsFileLoader"
)
SingleGroup
=
fastsafetensors
.
placeholder_attr
(
"SingleGroup"
)
SingleGroup
=
fastsafetensors
.
placeholder_attr
(
"SingleGroup"
)
from
vllm.model_executor.layers.quantization.torchao
import
torchao_version_at_least
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# use system-level temp directory for file locks, so that multiple users
# use system-level temp directory for file locks, so that multiple users
...
@@ -602,6 +604,23 @@ def safetensors_weights_iterator(
...
@@ -602,6 +604,23 @@ def safetensors_weights_iterator(
with
open
(
st_file
,
"rb"
)
as
f
:
with
open
(
st_file
,
"rb"
)
as
f
:
state_dict
=
load
(
f
.
read
())
state_dict
=
load
(
f
.
read
())
yield
from
state_dict
.
items
()
yield
from
state_dict
.
items
()
elif
safetensors_load_strategy
==
"torchao"
:
if
not
torchao_version_at_least
(
"0.14.0"
):
raise
ValueError
(
"Please use torchao version >= 0.14.0
\
to load torchao safetensors checkpoint"
)
from
torchao.prototype.safetensors.safetensors_support
import
(
unflatten_tensor_state_dict
,
)
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
state_dict
=
{}
for
name
in
f
.
keys
():
# noqa: SIM118
state_dict
[
name
]
=
f
.
get_tensor
(
name
)
metadata
=
f
.
metadata
()
updated_state_dict
=
unflatten_tensor_state_dict
(
state_dict
,
metadata
)
yield
from
updated_state_dict
.
items
()
else
:
else
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
for
name
in
f
.
keys
():
# noqa: SIM118
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment