Unverified Commit 3081d386 authored by Sylvain Gugger, committed by GitHub

Push to hub when saving checkpoints (#13503)

* Push to hub when saving checkpoints

* Add model card

* Revert partial model card

* Small fix for checkpoint

* Add tests

* Add documentation

* Fix tests

* Bump huggingface_hub

* Fix test
parent 51e5eca6
......@@ -119,6 +119,29 @@ TFTrainingArguments
:members:
Checkpoints
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, :class:`~transformers.Trainer` will save all checkpoints in the :obj:`output_dir` you set in the
:class:`~transformers.TrainingArguments` you are using. Those will go in a subfolder named :obj:`checkpoint-xxx`, with
xxx being the training step at which the checkpoint was saved.
Resuming training from a checkpoint can be done when calling :meth:`~transformers.Trainer.train` with either:
- :obj:`resume_from_checkpoint=True` which will resume training from the latest checkpoint
- :obj:`resume_from_checkpoint=checkpoint_dir` which will resume training from the specific checkpoint in the directory
passed.
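For instance, a minimal sketch (assuming a :class:`~transformers.Trainer` instance named ``trainer`` whose
:obj:`output_dir` is ``"my_model"``; the ``checkpoint-500`` folder name is only an illustration):

.. code-block:: python

    # Resume from the most recent checkpoint found in output_dir
    trainer.train(resume_from_checkpoint=True)

    # Or resume from a specific checkpoint directory
    trainer.train(resume_from_checkpoint="my_model/checkpoint-500")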
In addition, you can easily save your checkpoints on the Model Hub when using :obj:`push_to_hub=True`. By default, the
model saved at each intermediate checkpoint is pushed in a separate commit, but not the optimizer state. You can adapt
the :obj:`hub_strategy` value of your :class:`~transformers.TrainingArguments` to either:
- :obj:`"checkpoint"`: the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to
resume training easily with :obj:`trainer.train(resume_from_checkpoint="output_dir/last-checkpoint")`.
- :obj:`"all_checkpoints"`: all checkpoints are pushed as they appear in the output folder (so you will end up with
one checkpoint folder per checkpoint in your final repository)
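As a rough illustration (the ``"my_model"`` output directory and repo name are placeholders), enabling pushes on every
save with the latest checkpoint mirrored under ``last-checkpoint`` could look like:

.. code-block:: python

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="my_model",      # also used as the default repo name on the Hub
        push_to_hub=True,           # push on every save (hub_strategy defaults to "every_save")
        hub_strategy="checkpoint",  # additionally mirror the latest checkpoint under last-checkpoint/
        save_strategy="epoch",
    )

    # After an interruption, training can then be resumed from the mirrored checkpoint:
    # trainer.train(resume_from_checkpoint="my_model/last-checkpoint")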
Logging
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -100,7 +100,7 @@ _deps = [
"flax>=0.3.4",
"fugashi>=1.0",
"GitPython<3.1.19",
"huggingface-hub>=0.0.12",
"huggingface-hub>=0.0.17",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
......
......@@ -18,7 +18,7 @@ deps = {
"flax": "flax>=0.3.4",
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"huggingface-hub": "huggingface-hub>=0.0.12",
"huggingface-hub": "huggingface-hub>=0.0.17",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
......
......@@ -110,6 +110,8 @@ from .trainer_utils import (
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
......@@ -180,6 +182,14 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
......@@ -389,6 +399,12 @@ class Trainer:
# Create clone of distant repo and output directory if needed
if self.args.push_to_hub:
self.init_git_repo()
# In case of pull, we need to make sure every process has the latest.
if is_torch_tpu_available():
xm.rendezvous("init git repo")
elif args.local_rank != -1:
dist.barrier()
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
......@@ -901,9 +917,9 @@ class Trainer:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
def call_model_init(self, trial=None):
model_init_argcount = number_of_arguments(self.model_init)
......@@ -1183,9 +1199,9 @@ class Trainer:
# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, "trainer_state.json")
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
......@@ -1520,9 +1536,9 @@ class Trainer:
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
if smp.dp_rank() == 0:
......@@ -1530,20 +1546,20 @@ class Trainer:
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
......@@ -1563,7 +1579,7 @@ class Trainer:
# Save the Trainer state
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
# Save RNG state in non-distributed training
rng_states = {
......@@ -1590,6 +1606,9 @@ class Trainer:
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
# Maybe delete some older checkpoints.
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
......@@ -1603,15 +1622,15 @@ class Trainer:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
os.path.join(checkpoint, "scheduler.pt")
if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
os.path.join(checkpoint, SCHEDULER_NAME)
):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
with warnings.catch_warnings(record=True) as caught_warnings:
lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
reissue_pt_warnings(caught_warnings)
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
......@@ -1622,13 +1641,13 @@ class Trainer:
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
if self.use_amp and os.path.isfile(os.path.join(checkpoint, "scaler.pt")):
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, "scaler.pt")))
if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
def hyperparameter_search(
self,
......@@ -1908,7 +1927,7 @@ class Trainer:
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
......@@ -1953,7 +1972,7 @@ class Trainer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
def store_flos(self):
# Storing the number of floating-point operations that went into the model
......@@ -2476,9 +2495,9 @@ class Trainer:
def init_git_repo(self):
"""
Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
Initializes a git repo in :obj:`self.args.hub_model_id`.
"""
if not self.args.should_save:
if not self.is_world_process_zero():
return
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
if self.args.hub_model_id is None:
......@@ -2486,17 +2505,36 @@ class Trainer:
else:
repo_name = self.args.hub_model_id
self.repo = Repository(
self.args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
)
try:
self.repo = Repository(
self.args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
)
except EnvironmentError:
if self.args.overwrite_output_dir:
# Try again after wiping output_dir
shutil.rmtree(self.args.output_dir)
self.repo = Repository(
self.args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
)
else:
raise
self.repo.git_pull()
# By default, ignore the checkpoint folders
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
if (
not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
):
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"])
self.push_in_progress = None
def create_model_card(
self,
language: Optional[str] = None,
......@@ -2525,18 +2563,61 @@ class Trainer:
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
f.write(model_card)
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
def _push_from_checkpoint(self, checkpoint_folder):
# Only push from one node.
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
return
# If we haven't finished the last push, we don't do this one.
if self.push_in_progress is not None and not self.push_in_progress.is_done:
return
output_dir = self.args.output_dir
# To avoid a new synchronization of all model weights, we just copy the files from the checkpoint folder
modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
for modeling_file in modeling_files:
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
# Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Same for the training arguments
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
try:
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
# Temporarily move the checkpoint just saved for the push
tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
# We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
# subfolder.
if os.path.isdir(tmp_checkpoint):
shutil.rmtree(tmp_checkpoint)
shutil.move(checkpoint_folder, tmp_checkpoint)
if self.args.save_strategy == IntervalStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
else:
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
_, self.push_in_progress = self.repo.push_to_hub(commit_message=commit_message, blocking=False)
finally:
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
# Move back the checkpoint to its place
shutil.move(tmp_checkpoint, checkpoint_folder)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
commit_message (:obj:`str`, `optional`, defaults to :obj:`"End of training"`):
Message to commit while pushing.
blocking (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether the function should return only when the :obj:`git push` has finished.
kwargs:
Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
Returns:
The url of the commit of your model in the given repository.
The url of the commit of your model in the given repository if :obj:`blocking=True`, or a tuple with the url
of the commit and an object to track the progress of the commit if :obj:`blocking=False`.
"""
if self.args.should_save:
......@@ -2553,7 +2634,7 @@ class Trainer:
if not self.is_world_process_zero():
return
return self.repo.push_to_hub(commit_message=commit_message)
return self.repo.push_to_hub(commit_message=commit_message, blocking=blocking)
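# Illustrative usage sketch (assumes an already configured Trainer instance named `trainer`):
#
#   # Blocking push: returns the url of the commit once the push has finished
#   url = trainer.push_to_hub(commit_message="End of training")
#
#   # Non-blocking push: returns a tuple (url, command) that can be used to track the push
#   url, push_command = trainer.push_to_hub(commit_message="End of training", blocking=False)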
#
# Deprecated code
......
......@@ -125,6 +125,13 @@ class EvaluationStrategy(ExplicitEnum):
EPOCH = "epoch"
class HubStrategy(ExplicitEnum):
END = "end"
EVERY_SAVE = "every_save"
CHECKPOINT = "checkpoint"
ALL_CHECKPOINTS = "all_checkpoints"
class BestRun(NamedTuple):
"""
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
......
......@@ -32,7 +32,7 @@ from .file_utils import (
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
from .utils import logging
......@@ -343,6 +343,22 @@ class TrainingArguments:
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
:obj:`output_dir`.
hub_strategy (:obj:`str` or :class:`~transformers.trainer_utils.HubStrategy`, `optional`, defaults to :obj:`"every_save"`):
Defines the scope of what is pushed to the Hub and when. Possible values are:
- :obj:`"end"`: push the model, its configuration, the tokenizer (if passed along to the
:class:`~transformers.Trainer`) and a draft of a model card at the end of training.
- :obj:`"every_save"`: push the model, its configuration, the tokenizer (if passed along to the
:class:`~transformers.Trainer`) and a draft of a model card each time there is a model save. The pushes
are asynchronous so as not to block training, and in case the saves are very frequent, a new push is only
attempted if the previous one is finished. A last push is made with the final model at the end of
training.
- :obj:`"checkpoint"`: like :obj:`"every_save"` but the latest checkpoint is also pushed in a subfolder
named last-checkpoint, allowing you to resume training easily with
:obj:`trainer.train(resume_from_checkpoint="last-checkpoint")`.
- :obj:`"all_checkpoints"`: like :obj:`"checkpoint"` but all checkpoints are pushed as they appear in the
output folder (so you will end up with one checkpoint folder per checkpoint in your final repository)
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
......@@ -618,6 +634,10 @@ class TrainingArguments:
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
# Deprecated arguments
push_to_hub_model_id: str = field(
......@@ -668,6 +688,7 @@ class TrainingArguments:
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
self.save_strategy = IntervalStrategy(self.save_strategy)
self.hub_strategy = HubStrategy(self.hub_strategy)
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
......
......@@ -18,13 +18,14 @@ import gc
import os
import random
import re
import subprocess
import tempfile
import unittest
from pathlib import Path
import numpy as np
from huggingface_hub import HfApi
from huggingface_hub import HfApi, Repository
from requests.exceptions import HTTPError
from transformers import (
AutoTokenizer,
......@@ -1284,10 +1285,11 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-trainer")
except HTTPError:
pass
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
try:
cls._api.delete_repo(token=cls._token, name=model)
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
......@@ -1336,6 +1338,55 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
self.assertEqual(model.a.item(), trainer.model.a.item())
self.assertEqual(model.b.item(), trainer.model.b.item())
def get_commit_history(self, repo):
commit_logs = subprocess.run(
"git log".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=repo,
).stdout
commits = commit_logs.split("\n\n")[1::2]
return [commit.strip() for commit in commits]
def test_push_to_hub_with_saves_each_epoch(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-epoch"),
push_to_hub=True,
hub_token=self._token,
save_strategy="epoch",
)
trainer.train()
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)]
expected_commits.append("initial commit")
self.assertListEqual(commits, expected_commits)
def test_push_to_hub_with_saves_each_n_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-step"),
push_to_hub=True,
hub_token=self._token,
save_strategy="steps",
save_steps=5,
)
trainer.train()
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)]
expected_commits.append("initial commit")
self.assertListEqual(commits, expected_commits)
@require_torch
@require_optuna
......