"docs/vscode:/vscode.git/clone" did not exist on "2aed2c9fa7f9f751b994c6563893dcb048e0ae7b"
Commit 9cff1203 authored by zhuwenwen's avatar zhuwenwen
Browse files

update pt_weights_iterator

parent b4a253fc
...@@ -512,6 +512,13 @@ def pt_weights_iterator( ...@@ -512,6 +512,13 @@ def pt_weights_iterator(
use_tqdm_on_load: bool, use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]: ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files.""" """Iterate over the weights in the model bin/pt files."""
total_count = 0
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu", weights_only=True)
total_count += len(state)
del state
current_count = 0
for bin_file in tqdm( for bin_file in tqdm(
hf_weights_files, hf_weights_files,
desc="Loading pt checkpoint shards", desc="Loading pt checkpoint shards",
...@@ -519,7 +526,11 @@ def pt_weights_iterator( ...@@ -519,7 +526,11 @@ def pt_weights_iterator(
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
state = torch.load(bin_file, map_location="cpu", weights_only=True) state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items() for name, param in state.items():
current_count += 1
param.current_count = current_count
param.total_count = total_count
yield name, param
del state del state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment