Commit 145b4eac authored by zhuwenwen's avatar zhuwenwen
Browse files

update pt_weights_iterator

parent 9bc81d6d
...@@ -693,26 +693,39 @@ def pt_weights_iterator( ...@@ -693,26 +693,39 @@ def pt_weights_iterator(
pt_load_map_location: Union[str, dict[str, str]] = "cpu", pt_load_map_location: Union[str, dict[str, str]] = "cpu",
) -> Generator[tuple[str, torch.Tensor], None, None]: ) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files.""" """Iterate over the weights in the model bin/pt files."""
total_count = 0 if os.environ.get('LLAMA_NN') == '1':
for bin_file in hf_weights_files: total_count = 0
state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True) for bin_file in hf_weights_files:
total_count += len(state) state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True)
del state total_count += len(state)
del state
current_count = 0
for bin_file in tqdm( current_count = 0
hf_weights_files, for bin_file in tqdm(
desc="Loading pt checkpoint shards", hf_weights_files,
disable=not enable_tqdm(use_tqdm_on_load), desc="Loading pt checkpoint shards",
bar_format=_BAR_FORMAT, disable=not enable_tqdm(use_tqdm_on_load),
): bar_format=_BAR_FORMAT,
state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True) ):
for name, param in state.items(): state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True)
current_count += 1 for name, param in state.items():
param.current_count = current_count current_count += 1
param.total_count = total_count param.current_count = current_count
yield name, param param.total_count = total_count
del state yield name, param
del state
else:
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file,
map_location=pt_load_map_location,
weights_only=True)
yield from state.items()
del state
def multi_thread_pt_weights_iterator( def multi_thread_pt_weights_iterator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment