Unverified Commit 8e5d4e49 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

`Trainer.push_to_hub` always tries to push to the Hub (#15463)

parent 37800f13
...@@ -403,7 +403,7 @@ class Trainer: ...@@ -403,7 +403,7 @@ class Trainer:
# Create clone of distant repo and output directory if needed # Create clone of distant repo and output directory if needed
if self.args.push_to_hub: if self.args.push_to_hub:
self.init_git_repo() self.init_git_repo(at_init=True)
# In case of pull, we need to make sure every process has the latest. # In case of pull, we need to make sure every process has the latest.
if is_torch_tpu_available(): if is_torch_tpu_available():
xm.rendezvous("init git repo") xm.rendezvous("init git repo")
...@@ -2657,9 +2657,15 @@ class Trainer: ...@@ -2657,9 +2657,15 @@ class Trainer:
else: else:
return 0 return 0
def init_git_repo(self): def init_git_repo(self, at_init: bool = False):
""" """
Initializes a git repo in `self.args.hub_model_id`. Initializes a git repo in `self.args.hub_model_id`.
Args:
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
out.
""" """
if not self.is_world_process_zero(): if not self.is_world_process_zero():
return return
...@@ -2678,7 +2684,7 @@ class Trainer: ...@@ -2678,7 +2684,7 @@ class Trainer:
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
except EnvironmentError: except EnvironmentError:
if self.args.overwrite_output_dir: if self.args.overwrite_output_dir and at_init:
# Try again after wiping output_dir # Try again after wiping output_dir
shutil.rmtree(self.args.output_dir) shutil.rmtree(self.args.output_dir)
self.repo = Repository( self.repo = Repository(
...@@ -2790,6 +2796,10 @@ class Trainer: ...@@ -2790,6 +2796,10 @@ class Trainer:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
the commit and an object to track the progress of the commit if `blocking=True` the commit and an object to track the progress of the commit if `blocking=True`
""" """
# If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
# it might fail.
if not hasattr(self, "repo"):
self.init_git_repo()
if self.args.should_save: if self.args.should_save:
if self.args.hub_model_id is None: if self.args.hub_model_id is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment