Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
145b4eac
Commit
145b4eac
authored
Jan 20, 2026
by
zhuwenwen
Browse files
update pt_weights_iterator
parent
9bc81d6d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
20 deletions
+33
-20
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+33
-20
No files found.
vllm/model_executor/model_loader/weight_utils.py
View file @
145b4eac
...
@@ -693,26 +693,39 @@ def pt_weights_iterator(
...
@@ -693,26 +693,39 @@ def pt_weights_iterator(
pt_load_map_location
:
Union
[
str
,
dict
[
str
,
str
]]
=
"cpu"
,
pt_load_map_location
:
Union
[
str
,
dict
[
str
,
str
]]
=
"cpu"
,
)
->
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model bin/pt files."""
"""Iterate over the weights in the model bin/pt files."""
total_count
=
0
if
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
:
for
bin_file
in
hf_weights_files
:
total_count
=
0
state
=
torch
.
load
(
bin_file
,
map_location
=
pt_load_map_location
,
weights_only
=
True
)
for
bin_file
in
hf_weights_files
:
total_count
+=
len
(
state
)
state
=
torch
.
load
(
bin_file
,
map_location
=
pt_load_map_location
,
weights_only
=
True
)
del
state
total_count
+=
len
(
state
)
del
state
current_count
=
0
for
bin_file
in
tqdm
(
current_count
=
0
hf_weights_files
,
for
bin_file
in
tqdm
(
desc
=
"Loading pt checkpoint shards"
,
hf_weights_files
,
disable
=
not
enable_tqdm
(
use_tqdm_on_load
),
desc
=
"Loading pt checkpoint shards"
,
bar_format
=
_BAR_FORMAT
,
disable
=
not
enable_tqdm
(
use_tqdm_on_load
),
):
bar_format
=
_BAR_FORMAT
,
state
=
torch
.
load
(
bin_file
,
map_location
=
pt_load_map_location
,
weights_only
=
True
)
):
for
name
,
param
in
state
.
items
():
state
=
torch
.
load
(
bin_file
,
map_location
=
pt_load_map_location
,
weights_only
=
True
)
current_count
+=
1
for
name
,
param
in
state
.
items
():
param
.
current_count
=
current_count
current_count
+=
1
param
.
total_count
=
total_count
param
.
current_count
=
current_count
yield
name
,
param
param
.
total_count
=
total_count
del
state
yield
name
,
param
del
state
else
:
for
bin_file
in
tqdm
(
hf_weights_files
,
desc
=
"Loading pt checkpoint shards"
,
disable
=
not
enable_tqdm
(
use_tqdm_on_load
),
bar_format
=
_BAR_FORMAT
,
):
state
=
torch
.
load
(
bin_file
,
map_location
=
pt_load_map_location
,
weights_only
=
True
)
yield
from
state
.
items
()
del
state
def
multi_thread_pt_weights_iterator
(
def
multi_thread_pt_weights_iterator
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment