Unverified Commit da5ef25d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Push to hub save (#15327)

* Adapt doc and push at every save

* style
parent 9f831bde
......@@ -966,7 +966,7 @@ class Trainer:
return
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
......@@ -1634,7 +1634,7 @@ class Trainer:
self.store_flos()
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir)
self.save_model(output_dir, _internal_call=True)
if self.deepspeed:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
# config `stage3_gather_fp16_weights_on_model_save` is True
......@@ -2002,7 +2002,7 @@ class Trainer:
else:
return self.args.process_index == 0
def save_model(self, output_dir: Optional[str] = None):
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
"""
Will save the model, so you can reload it using `from_pretrained()`.
......@@ -2051,6 +2051,10 @@ class Trainer:
elif self.args.should_save:
self._save(output_dir)
# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
self.push_to_hub(commit_message="Model save")
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info(f"Saving model checkpoint to {output_dir}")
......@@ -2768,9 +2772,10 @@ class Trainer:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
# Needs to be executed on all processes for TPU training, but will only save on the process determined by
# self.args.should_save.
self.save_model()
self.save_model(_internal_call=True)
# Only push from one node.
if not self.is_world_process_zero():
......
......@@ -365,9 +365,18 @@ class TrainingArguments:
Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
down the training and evaluation speed.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to upload the trained model to the hub after training. If this is activated, and
`output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
Whether or not to push the model to the Hub every time the model is saved. If this is activated,
`output_dir` will become a git directory synced with the repo (determined by `hub_model_id`) and the
content will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
[`~Trainer.save_model`] will also trigger a push.
<Tip warning={true}>
If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
pushed.
</Tip>
resume_from_checkpoint (`str`, *optional*):
The path to a folder with a valid checkpoint for your model. This argument is not directly used by
[`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
......@@ -384,7 +393,7 @@ class TrainingArguments:
Defines the scope of what is pushed to the Hub and when. Possible values are:
- `"end"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a
draft of a model card at the end of training.
draft of a model card when the [`~Trainer.save_model`] method is called.
- `"every_save"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and
a draft of a model card each time there is a model save. The pushes are asynchronous to not block
training, and in case the saves are very frequent, a new push is only attempted if the previous one is
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment