Commit 145b4eac authored by zhuwenwen's avatar zhuwenwen
Browse files

update pt_weights_iterator

parent 9bc81d6d
......@@ -693,6 +693,7 @@ def pt_weights_iterator(
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
if os.environ.get('LLAMA_NN') == '1':
total_count = 0
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True)
......@@ -713,6 +714,18 @@ def pt_weights_iterator(
param.total_count = total_count
yield name, param
del state
else:
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file,
map_location=pt_load_map_location,
weights_only=True)
yield from state.items()
del state
def multi_thread_pt_weights_iterator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment