Commit 9cff1203 authored by zhuwenwen
Browse files

update pt_weights_iterator

parent b4a253fc
......@@ -512,6 +512,13 @@ def pt_weights_iterator(
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
total_count = 0
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu", weights_only=True)
total_count += len(state)
del state
current_count = 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
......@@ -519,7 +526,11 @@ def pt_weights_iterator(
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items()
for name, param in state.items():
current_count += 1
param.current_count = current_count
param.total_count = total_count
yield name, param
del state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment