Unverified Commit 3081d386 authored by Sylvain Gugger, committed by GitHub

Push to hub when saving checkpoints (#13503)

* Push to hub when saving checkpoints

* Add model card

* Revert partial model card

* Small fix for checkpoint

* Add tests

* Add documentation

* Fix tests

* Bump huggingface_hub

* Fix test
parent 51e5eca6
@@ -119,6 +119,29 @@ TFTrainingArguments
     :members:
 
+Checkpoints
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+By default, :class:`~transformers.Trainer` will save all checkpoints in the :obj:`output_dir` you set in the
+:class:`~transformers.TrainingArguments` you are using. Those will go in a subfolder named :obj:`checkpoint-xxx`, with
+xxx being the step at which the checkpoint was saved.
+
+Resuming training from a checkpoint can be done when calling :meth:`~transformers.Trainer.train` with either:
+
+- :obj:`resume_from_checkpoint=True`, which will resume training from the latest checkpoint
+- :obj:`resume_from_checkpoint=checkpoint_dir`, which will resume training from the specific checkpoint in the
+  directory passed.
+
+In addition, you can easily save your checkpoints on the Model Hub when using :obj:`push_to_hub=True`. By default, the
+model saved at each intermediate checkpoint is pushed in a separate commit, but not the optimizer state. You can adapt
+the :obj:`hub_strategy` value of your :class:`~transformers.TrainingArguments` to either:
+
+- :obj:`"checkpoint"`: the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to
+  resume training easily with :obj:`trainer.train(resume_from_checkpoint="output_dir/last-checkpoint")`.
+- :obj:`"all_checkpoints"`: all checkpoints are pushed as they appear in the output folder (so you will get one
+  checkpoint folder per saved checkpoint in your final repository).
 
 Logging
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
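To make the new checkpoint options concrete, here is a minimal usage sketch (the `model`, `train_dataset` and "my-model" names are illustrative placeholders, not taken from this diff):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="my-model",
    push_to_hub=True,            # clone/create the Hub repo inside output_dir
    hub_strategy="checkpoint",   # also sync the latest checkpoint to last-checkpoint/
    save_strategy="steps",
    save_steps=500,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

# After an interruption, resume from the checkpoint that was synced to the Hub:
trainer.train(resume_from_checkpoint="my-model/last-checkpoint")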
@@ -100,7 +100,7 @@ _deps = [
     "flax>=0.3.4",
     "fugashi>=1.0",
     "GitPython<3.1.19",
-    "huggingface-hub>=0.0.12",
+    "huggingface-hub>=0.0.17",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
......
@@ -18,7 +18,7 @@ deps = {
    "flax": "flax>=0.3.4",
    "fugashi": "fugashi>=1.0",
    "GitPython": "GitPython<3.1.19",
-   "huggingface-hub": "huggingface-hub>=0.0.12",
+   "huggingface-hub": "huggingface-hub>=0.0.17",
    "importlib_metadata": "importlib_metadata",
    "ipadic": "ipadic>=1.0.0,<2.0",
    "isort": "isort>=5.5.4",
......
@@ -110,6 +110,8 @@ from .trainer_utils import (
     EvalLoopOutput,
     EvalPrediction,
     HPSearchBackend,
+    HubStrategy,
+    IntervalStrategy,
     PredictionOutput,
     ShardedDDPOption,
     TrainerMemoryTracker,
@@ -180,6 +182,14 @@ if TYPE_CHECKING:
 
 logger = logging.get_logger(__name__)
 
+# Names of the files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+
 
 class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
@@ -389,6 +399,12 @@ class Trainer:
         # Create clone of distant repo and output directory if needed
         if self.args.push_to_hub:
             self.init_git_repo()
+            # In case of pull, we need to make sure every process has the latest.
+            if is_torch_tpu_available():
+                xm.rendezvous("init git repo")
+            elif args.local_rank != -1:
+                dist.barrier()
+
         if self.args.should_save:
             os.makedirs(self.args.output_dir, exist_ok=True)
@@ -901,9 +917,9 @@ class Trainer:
         output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
         self.save_model(output_dir)
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
-            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
     def call_model_init(self, trial=None):
         model_init_argcount = number_of_arguments(self.model_init)
@@ -1183,9 +1199,9 @@ class Trainer:
         # Check if continuing training from a checkpoint
         if resume_from_checkpoint is not None and os.path.isfile(
-            os.path.join(resume_from_checkpoint, "trainer_state.json")
+            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
         ):
-            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
+            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             if not args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
@@ -1520,9 +1536,9 @@ class Trainer:
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
-            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
-                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
             if smp.dp_rank() == 0:
@@ -1530,20 +1546,20 @@ class Trainer:
             opt_state_dict = self.optimizer.state_dict()
             # Save it and the scheduler on the main process
             if self.args.should_save:
-                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
+                torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                 reissue_pt_warnings(caught_warnings)
                 if self.use_amp:
-                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
+                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
-            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
             if self.use_amp:
-                torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
+                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -1563,7 +1579,7 @@ class Trainer:
         # Save the Trainer state
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
 
         # Save RNG state in non-distributed training
         rng_states = {
@@ -1590,6 +1606,9 @@ class Trainer:
         else:
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
 
+        if self.args.push_to_hub:
+            self._push_from_checkpoint(output_dir)
+
         # Maybe delete some older checkpoints.
         if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
@@ -1603,15 +1622,15 @@ class Trainer:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return
 
-        if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
-            os.path.join(checkpoint, "scheduler.pt")
+        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
+            os.path.join(checkpoint, SCHEDULER_NAME)
         ):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
+                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
+                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)
 
                 xm.send_cpu_data_to_device(optimizer_state, self.args.device)
@@ -1622,13 +1641,13 @@ class Trainer:
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                 self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
+                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                 )
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
-            if self.use_amp and os.path.isfile(os.path.join(checkpoint, "scaler.pt")):
-                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, "scaler.pt")))
+            if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
 
     def hyperparameter_search(
         self,
@@ -1908,7 +1927,7 @@ class Trainer:
         if xm.is_master_ordinal():
             os.makedirs(output_dir, exist_ok=True)
-            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
+            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
@@ -1953,7 +1972,7 @@ class Trainer:
             self.tokenizer.save_pretrained(output_dir)
 
         # Good practice: save your training arguments together with the trained model
-        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
     def store_flos(self):
         # Storing the number of floating-point operations that went into the model
@@ -2476,9 +2495,9 @@ class Trainer:
 
     def init_git_repo(self):
         """
-        Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
+        Initializes a git repo in :obj:`self.args.hub_model_id`.
         """
-        if not self.args.should_save:
+        if not self.is_world_process_zero():
             return
         use_auth_token = True if self.args.hub_token is None else self.args.hub_token
         if self.args.hub_model_id is None:
@@ -2486,17 +2505,36 @@ class Trainer:
         else:
             repo_name = self.args.hub_model_id
 
-        self.repo = Repository(
-            self.args.output_dir,
-            clone_from=repo_name,
-            use_auth_token=use_auth_token,
-        )
+        try:
+            self.repo = Repository(
+                self.args.output_dir,
+                clone_from=repo_name,
+                use_auth_token=use_auth_token,
+            )
+        except EnvironmentError:
+            if self.args.overwrite_output_dir:
+                # Try again after wiping output_dir
+                shutil.rmtree(self.args.output_dir)
+                self.repo = Repository(
+                    self.args.output_dir,
+                    clone_from=repo_name,
+                    use_auth_token=use_auth_token,
+                )
+            else:
+                raise
+
+        self.repo.git_pull()
 
         # By default, ignore the checkpoint folders
-        if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
+        if (
+            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
+            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
+        ):
             with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                 writer.writelines(["checkpoint-*/"])
 
+        self.push_in_progress = None
+
     def create_model_card(
         self,
         language: Optional[str] = None,
@@ -2525,18 +2563,61 @@ class Trainer:
         with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
             f.write(model_card)
 
-    def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
+    def _push_from_checkpoint(self, checkpoint_folder):
+        # Only push from one node.
+        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
+            return
+        # If we haven't finished the last push, we don't do this one.
+        if self.push_in_progress is not None and not self.push_in_progress.is_done:
+            return
+
+        output_dir = self.args.output_dir
+        # To avoid a new synchronization of all model weights, we just copy the files from the checkpoint folder
+        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
+        for modeling_file in modeling_files:
+            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
+                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
+        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+        # Same for the training arguments
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+        try:
+            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
+                # Temporarily move the checkpoint just saved for the push
+                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
+                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
+                # subfolder.
+                if os.path.isdir(tmp_checkpoint):
+                    shutil.rmtree(tmp_checkpoint)
+                shutil.move(checkpoint_folder, tmp_checkpoint)
+
+            if self.args.save_strategy == IntervalStrategy.STEPS:
+                commit_message = f"Training in progress, step {self.state.global_step}"
+            else:
+                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
+            _, self.push_in_progress = self.repo.push_to_hub(commit_message=commit_message, blocking=False)
+        finally:
+            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
+                # Move back the checkpoint to its place
+                shutil.move(tmp_checkpoint, checkpoint_folder)
+
+    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
         """
         Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
 
         Parameters:
-            commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
+            commit_message (:obj:`str`, `optional`, defaults to :obj:`"End of training"`):
                 Message to commit while pushing.
+            blocking (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether the function should return only when the :obj:`git push` has finished.
             kwargs:
                 Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
 
         Returns:
-            The url of the commit of your model in the given repository.
+            The url of the commit of your model in the given repository if :obj:`blocking=True`, or a tuple with the
+            url of the commit and an object to track the progress of the push if :obj:`blocking=False`.
         """
 
         if self.args.should_save:
@@ -2553,7 +2634,7 @@ class Trainer:
         if not self.is_world_process_zero():
             return
 
-        return self.repo.push_to_hub(commit_message=commit_message)
+        return self.repo.push_to_hub(commit_message=commit_message, blocking=blocking)
 
 #
 # Deprecated code
......
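A short sketch of the two call patterns enabled by the new blocking argument (assuming a trainer constructed with push_to_hub=True; the tuple return for a non-blocking push mirrors huggingface_hub's Repository.push_to_hub, which the Trainer delegates to):

# Blocking (default): wait for the `git push` to finish and get the commit url back.
url = trainer.push_to_hub(commit_message="End of training")

# Non-blocking: return immediately with (commit_url, progress_object); the progress
# object exposes `is_done`, like the internal `push_in_progress` tracker above.
url, progress = trainer.push_to_hub(commit_message="End of training", blocking=False)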
@@ -125,6 +125,13 @@ class EvaluationStrategy(ExplicitEnum):
     EPOCH = "epoch"
 
 
+class HubStrategy(ExplicitEnum):
+    END = "end"
+    EVERY_SAVE = "every_save"
+    CHECKPOINT = "checkpoint"
+    ALL_CHECKPOINTS = "all_checkpoints"
+
+
 class BestRun(NamedTuple):
     """
     The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
......
@@ -32,7 +32,7 @@ from .file_utils import (
     is_torch_tpu_available,
     torch_required,
 )
-from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
+from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
 from .utils import logging
@@ -343,6 +343,22 @@ class TrainingArguments:
             Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
             :obj:`output_dir`.
+        hub_strategy (:obj:`str` or :class:`~transformers.trainer_utils.HubStrategy`, `optional`, defaults to :obj:`"every_save"`):
+            Defines the scope of what is pushed to the Hub and when. Possible values are:
+
+            - :obj:`"end"`: push the model, its configuration, the tokenizer (if passed along to the
+              :class:`~transformers.Trainer`) and a draft of a model card at the end of training.
+            - :obj:`"every_save"`: push the model, its configuration, the tokenizer (if passed along to the
+              :class:`~transformers.Trainer`) and a draft of a model card each time the model is saved. The pushes
+              are asynchronous so they do not block training, and if saves are very frequent, a new push is only
+              attempted once the previous one has finished. A last push is made with the final model at the end of
+              training.
+            - :obj:`"checkpoint"`: like :obj:`"every_save"`, but the latest checkpoint is also pushed in a subfolder
+              named last-checkpoint, allowing you to resume training easily with
+              :obj:`trainer.train(resume_from_checkpoint="last-checkpoint")`.
+            - :obj:`"all_checkpoints"`: like :obj:`"checkpoint"`, but all checkpoints are pushed as they appear in
+              the output folder (so you will get one checkpoint folder per saved checkpoint in your final repository).
         hub_token (:obj:`str`, `optional`):
             The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
             :obj:`huggingface-cli login`.
@@ -618,6 +634,10 @@ class TrainingArguments:
     hub_model_id: str = field(
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
+    hub_strategy: HubStrategy = field(
+        default="every_save",
+        metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
+    )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
 
     # Deprecated arguments
     push_to_hub_model_id: str = field(
@@ -668,6 +688,7 @@ class TrainingArguments:
         self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
         self.logging_strategy = IntervalStrategy(self.logging_strategy)
         self.save_strategy = IntervalStrategy(self.save_strategy)
+        self.hub_strategy = HubStrategy(self.hub_strategy)
         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
         if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
......
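As a small sketch of what the HubStrategy(self.hub_strategy) conversion above does (plain Python enum behavior; the exact error message wording is an assumption about ExplicitEnum, not shown in this diff):

from transformers.trainer_utils import HubStrategy

assert HubStrategy("every_save") is HubStrategy.EVERY_SAVE
assert HubStrategy("checkpoint") is HubStrategy.CHECKPOINT
# An unrecognized string (e.g. a typo in --hub_strategy) raises a ValueError
# at argument-parsing time instead of failing later during training.
HubStrategy("evry_save")  # -> ValueError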
@@ -18,13 +18,14 @@ import gc
 import os
 import random
 import re
+import subprocess
 import tempfile
 import unittest
 from pathlib import Path
 
 import numpy as np
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, Repository
 from requests.exceptions import HTTPError
 
 from transformers import (
     AutoTokenizer,
@@ -1284,10 +1285,11 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        try:
-            cls._api.delete_repo(token=cls._token, name="test-trainer")
-        except HTTPError:
-            pass
+        for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
+            try:
+                cls._api.delete_repo(token=cls._token, name=model)
+            except HTTPError:
+                pass
 
         try:
             cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
@@ -1336,6 +1338,55 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
         self.assertEqual(model.a.item(), trainer.model.a.item())
         self.assertEqual(model.b.item(), trainer.model.b.item())
 
+    def get_commit_history(self, repo):
+        # Returns the commit messages of `repo`, most recent first.
+        commit_logs = subprocess.run(
+            "git log".split(),
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            check=True,
+            encoding="utf-8",
+            cwd=repo,
+        ).stdout
+        # `git log` alternates header blocks and commit-message blocks separated by blank
+        # lines, so the messages are every other chunk starting at index 1.
+        commits = commit_logs.split("\n\n")[1::2]
+        return [commit.strip() for commit in commits]
+
+    def test_push_to_hub_with_saves_each_epoch(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=os.path.join(tmp_dir, "test-trainer-epoch"),
+                push_to_hub=True,
+                hub_token=self._token,
+                save_strategy="epoch",
+            )
+            trainer.train()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
+            commits = self.get_commit_history(tmp_dir)
+            # Three epochs of training, newest commit first, plus the initial commit of the repo.
+            expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)]
+            expected_commits.append("initial commit")
+            self.assertListEqual(commits, expected_commits)
+
+    def test_push_to_hub_with_saves_each_n_steps(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=os.path.join(tmp_dir, "test-trainer-step"),
+                push_to_hub=True,
+                hub_token=self._token,
+                save_strategy="steps",
+                save_steps=5,
+            )
+            trainer.train()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
+            commits = self.get_commit_history(tmp_dir)
+            # 20 training steps saved every 5 steps, newest commit first, plus the initial commit.
+            expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)]
+            expected_commits.append("initial commit")
+            self.assertListEqual(commits, expected_commits)
+
 
 @require_torch
 @require_optuna
......