Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d642872
Unverified
Commit
1d642872
authored
Nov 19, 2025
by
liangel-02
Committed by
GitHub
Nov 19, 2025
Browse files
[torchao] fix safetensors for sharding (#28169)
Signed-off-by:
Angel Li
<
liangel@meta.com
>
parent
9ccef8e3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
11 deletions
+23
-11
tests/quantization/test_torchao.py
tests/quantization/test_torchao.py
+4
-5
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+1
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+18
-5
No files found.
tests/quantization/test_torchao.py
View file @
1d642872
...
...
@@ -225,13 +225,12 @@ def test_reload_weights():
@
pytest
.
mark
.
skip
(
reason
=
"since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.1
4
.0.dev+) for now"
"torchao tests that requires newer versions (0.1
5
.0.dev+) for now"
)
def
test_
opt_125m_float8_weight_only_
safetensors_model_loading_with_params
(
vllm_runner
):
def
test_safetensors_model_loading_with_params
(
vllm_runner
):
torch
.
_dynamo
.
reset
()
model_name
=
(
"torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
)
# using this model to test safetensors loading with file sharding
model_name
=
"torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
with
vllm_runner
(
model_name
=
model_name
,
dtype
=
"bfloat16"
)
as
llm
:
output
=
llm
.
generate_greedy
([
"The capital of France is"
],
max_tokens
=
4
)
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
1d642872
...
...
@@ -279,7 +279,7 @@ class DefaultModelLoader(BaseModelLoader):
if
(
hasattr
(
quant_config
,
"is_checkpoint_torchao_serialized"
)
and
quant_config
.
is_checkpoint_torchao_serialized
and
torchao_version_at_least
(
"0.1
4
.0"
)
and
torchao_version_at_least
(
"0.1
5
.0"
)
):
self
.
load_config
.
safetensors_load_strategy
=
"torchao"
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
1d642872
...
...
@@ -595,6 +595,9 @@ def safetensors_weights_iterator(
if
safetensors_load_strategy
==
"eager"
:
loading_desc
+=
" (eager)"
state_dict
=
{}
leftover_state_dict
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
st_file
in
tqdm
(
hf_weights_files
,
desc
=
loading_desc
,
...
...
@@ -606,9 +609,11 @@ def safetensors_weights_iterator(
state_dict
=
load
(
f
.
read
())
yield
from
state_dict
.
items
()
elif
safetensors_load_strategy
==
"torchao"
:
if
not
torchao_version_at_least
(
"0.14.0"
):
# We can't load flattened torchao tensor subclasses directly into the model;
# instead, we reconstruct the subclasses here before yielding them.
if
not
torchao_version_at_least
(
"0.15.0"
):
raise
ValueError
(
"Please use torchao version >= 0.1
4
.0
\
"Please use torchao version >= 0.1
5
.0
\
to load torchao safetensors checkpoint"
)
from
torchao.prototype.safetensors.safetensors_support
import
(
...
...
@@ -616,12 +621,20 @@ def safetensors_weights_iterator(
)
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
state_dict
=
{}
for
name
in
f
.
keys
():
# noqa: SIM118
state_dict
[
name
]
=
f
.
get_tensor
(
name
)
# update with leftover tensor data from previous iteration, if any
state_dict
.
update
(
leftover_state_dict
)
metadata
=
f
.
metadata
()
updated_state_dict
=
unflatten_tensor_state_dict
(
state_dict
,
metadata
)
yield
from
updated_state_dict
.
items
()
# Because checkpoints may be sharded, we are not guaranteed to have all of a
# tensor subclass's data in one file.
# leftover_state_dict holds the data left over from this step; it is merged
# back in on a later iteration, once the missing pieces have been read.
unflattened_state_dict
,
leftover_state_dict
=
(
unflatten_tensor_state_dict
(
state_dict
,
metadata
)
)
yield
from
unflattened_state_dict
.
items
()
else
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment