Unverified Commit c79bbc3b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix multiple deletions of the same files in save_pretrained (#16947)

* Fix multiple deletions of the same files in save_pretrained

* Add is_main_process argument
parent bfbec177
...@@ -444,7 +444,11 @@ def main(): ...@@ -444,7 +444,11 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
)
if accelerator.is_main_process: if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir) feature_extractor.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -490,7 +494,9 @@ def main(): ...@@ -490,7 +494,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir) feature_extractor.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -506,7 +512,9 @@ def main(): ...@@ -506,7 +512,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir) feature_extractor.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -580,7 +580,9 @@ def main(): ...@@ -580,7 +580,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -596,7 +598,9 @@ def main(): ...@@ -596,7 +598,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -627,7 +627,9 @@ def main(): ...@@ -627,7 +627,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -643,7 +645,9 @@ def main(): ...@@ -643,7 +645,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -588,7 +588,9 @@ def main(): ...@@ -588,7 +588,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -604,7 +606,9 @@ def main(): ...@@ -604,7 +606,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -817,7 +817,9 @@ def main(): ...@@ -817,7 +817,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -961,7 +963,9 @@ def main(): ...@@ -961,7 +963,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -832,7 +832,9 @@ def main(): ...@@ -832,7 +832,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -930,7 +932,9 @@ def main(): ...@@ -930,7 +932,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -553,7 +553,11 @@ def main(): ...@@ -553,7 +553,11 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
)
if accelerator.is_main_process: if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir) feature_extractor.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -613,7 +617,9 @@ def main(): ...@@ -613,7 +617,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir) feature_extractor.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -629,7 +635,9 @@ def main(): ...@@ -629,7 +635,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
feature_extractor.save_pretrained(args.output_dir) feature_extractor.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -669,7 +669,9 @@ def main(): ...@@ -669,7 +669,9 @@ def main():
if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None: if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process: if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
repo.push_to_hub( repo.push_to_hub(
...@@ -720,7 +722,9 @@ def main(): ...@@ -720,7 +722,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
......
...@@ -687,7 +687,9 @@ def main(): ...@@ -687,7 +687,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -703,7 +705,9 @@ def main(): ...@@ -703,7 +705,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -539,7 +539,9 @@ def main(): ...@@ -539,7 +539,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -555,7 +557,9 @@ def main(): ...@@ -555,7 +557,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -691,7 +691,9 @@ def main(): ...@@ -691,7 +691,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -707,7 +709,9 @@ def main(): ...@@ -707,7 +709,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -662,7 +662,9 @@ def main(): ...@@ -662,7 +662,9 @@ def main():
if args.push_to_hub and epoch < args.num_train_epochs - 1: if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub( repo.push_to_hub(
...@@ -678,7 +680,9 @@ def main(): ...@@ -678,7 +680,9 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process: if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import re import re
import shutil import shutil
import tempfile import tempfile
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
...@@ -1347,7 +1348,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1347,7 +1348,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
save_config: bool = True, is_main_process: bool = True,
state_dict: Optional[dict] = None, state_dict: Optional[dict] = None,
save_function: Callable = torch.save, save_function: Callable = torch.save,
push_to_hub: bool = False, push_to_hub: bool = False,
...@@ -1361,10 +1362,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1361,10 +1362,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist. Directory to which to save. Will be created if it doesn't exist.
save_config (`bool`, *optional*, defaults to `True`): is_main_process (`bool`, *optional*, defaults to `True`):
Whether or not to save the config of the model. Useful when in distributed training like TPUs and need Whether the process calling this is the main process or not. Useful when in distributed training like
to call this function on all processes. In this case, set `save_config=True` only on the main process TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
to avoid race conditions. the main process to avoid race conditions.
state_dict (nested dictionary of `torch.Tensor`): state_dict (nested dictionary of `torch.Tensor`):
The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
save parts of the model or if special precautions need to be taken when recovering the state dictionary save parts of the model or if special precautions need to be taken when recovering the state dictionary
...@@ -1397,6 +1398,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1397,6 +1398,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
kwargs: kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
if "save_config" in kwargs:
warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
)
is_main_process = kwargs.pop("save_config")
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return return
...@@ -1424,7 +1431,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1424,7 +1431,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
custom_object_save(self, save_directory, config=self.config) custom_object_save(self, save_directory, config=self.config)
# Save the config # Save the config
if save_config: if is_main_process:
model_to_save.config.save_pretrained(save_directory) model_to_save.config.save_pretrained(save_directory)
# Save the model # Save the model
...@@ -1443,7 +1450,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1443,7 +1450,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Clean the folder from a previous save # Clean the folder from a previous save
for filename in os.listdir(save_directory): for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename) full_filename = os.path.join(save_directory, filename)
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename): # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
if (
filename.startswith(WEIGHTS_NAME[:-4])
and os.path.isfile(full_filename)
and filename not in shards.keys()
and is_main_process
):
os.remove(full_filename) os.remove(full_filename)
# Save the model # Save the model
......
...@@ -2173,7 +2173,7 @@ class Trainer: ...@@ -2173,7 +2173,7 @@ class Trainer:
if isinstance(unwrap_model(self.model), PreTrainedModel): if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained( unwrap_model(self.model).save_pretrained(
output_dir, output_dir,
save_config=self.args.should_save, is_main_process=self.args.should_save,
state_dict=self.model.state_dict(), state_dict=self.model.state_dict(),
save_function=xm.save, save_function=xm.save,
) )
...@@ -2182,7 +2182,7 @@ class Trainer: ...@@ -2182,7 +2182,7 @@ class Trainer:
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else: else:
self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save) self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
if self.tokenizer is not None and self.args.should_save: if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment