Unverified Commit cffa2b9c authored by kallewoof's avatar kallewoof Committed by GitHub
Browse files

save_pretrained: use tqdm when saving checkpoint shards from offloaded params (#31856)

parent 350aed70
...@@ -2657,7 +2657,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2657,7 +2657,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
): ):
os.remove(full_filename) os.remove(full_filename)
# Save the model # Save the model
for shard_file, tensors in state_dict_split.filename_to_tensors.items(): filename_to_tensors = state_dict_split.filename_to_tensors.items()
if module_map:
filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors} shard = {tensor: state_dict[tensor] for tensor in tensors}
# remake shard with onloaded parameters if necessary # remake shard with onloaded parameters if necessary
if module_map: if module_map:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment