Unverified commit c5201240, authored by youkaichao, committed by GitHub
Browse files

[misc] only tqdm for first rank (#6672)

parent 97234be0
...@@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference( ...@@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference(
return hf_weights_files return hf_weights_files
# Progress-bar template passed as `bar_format` to the tqdm calls in the
# weight iterators below.
# It is deliberately a plain-text format that ends with a newline. This gives
# up the in-place animation of the progress bar, but it avoids garbled output
# under Ray or multiprocessing, which wrap each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
def np_cache_weights_iterator( def np_cache_weights_iterator(
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
hf_weights_files: List[str] hf_weights_files: List[str]
...@@ -321,6 +328,8 @@ def np_cache_weights_iterator( ...@@ -321,6 +328,8 @@ def np_cache_weights_iterator(
Will dump the model weights to numpy files if they are not already dumped. Will dump the model weights to numpy files if they are not already dumped.
""" """
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
# Convert the model weights from torch tensors to numpy arrays for # Convert the model weights from torch tensors to numpy arrays for
# faster loading. # faster loading.
np_folder = os.path.join(hf_folder, "np") np_folder = os.path.join(hf_folder, "np")
...@@ -331,8 +340,12 @@ def np_cache_weights_iterator( ...@@ -331,8 +340,12 @@ def np_cache_weights_iterator(
with get_lock(model_name_or_path, cache_dir): with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file): if not os.path.exists(weight_names_file):
weight_names: List[str] = [] weight_names: List[str] = []
for bin_file in tqdm(hf_weights_files, for bin_file in tqdm(
desc="Loading np_cache checkpoint shards"): hf_weights_files,
desc="Loading np_cache checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu")
for name, param in state.items(): for name, param in state.items():
param_path = os.path.join(np_folder, name) param_path = os.path.join(np_folder, name)
...@@ -356,8 +369,14 @@ def safetensors_weights_iterator( ...@@ -356,8 +369,14 @@ def safetensors_weights_iterator(
hf_weights_files: List[str] hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]: ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.""" """Iterate over the weights in the model safetensor files."""
for st_file in tqdm(hf_weights_files, enable_tqdm = not torch.distributed.is_initialized(
desc="Loading safetensors checkpoint shards"): ) or torch.distributed.get_rank() == 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f: with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118 for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name) param = f.get_tensor(name)
...@@ -368,8 +387,14 @@ def pt_weights_iterator( ...@@ -368,8 +387,14 @@ def pt_weights_iterator(
hf_weights_files: List[str] hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]: ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files.""" """Iterate over the weights in the model bin/pt files."""
for bin_file in tqdm(hf_weights_files, enable_tqdm = not torch.distributed.is_initialized(
desc="Loading pt checkpoint shards"): ) or torch.distributed.get_rank() == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu")
for name, param in state.items(): for name, param in state.items():
yield name, param yield name, param
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment