Unverified Commit 08c84433 authored by Nick Doiron's avatar Nick Doiron Committed by GitHub
Browse files

Accept token in trainer.push_to_hub() (#30093)



* pass token to trainer.push_to_hub

* fmt

* Update src/transformers/trainer.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pass token to create_repo, update_folder

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0201f642
...@@ -3909,7 +3909,7 @@ class Trainer: ...@@ -3909,7 +3909,7 @@ class Trainer:
else: else:
return 0 return 0
def init_hf_repo(self): def init_hf_repo(self, token: Optional[str] = None):
""" """
Initializes a git repo in `self.args.hub_model_id`. Initializes a git repo in `self.args.hub_model_id`.
""" """
...@@ -3922,7 +3922,8 @@ class Trainer: ...@@ -3922,7 +3922,8 @@ class Trainer:
else: else:
repo_name = self.args.hub_model_id repo_name = self.args.hub_model_id
repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) token = token if token is not None else self.args.hub_token
repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
self.hub_model_id = repo_url.repo_id self.hub_model_id = repo_url.repo_id
self.push_in_progress = None self.push_in_progress = None
...@@ -4067,7 +4068,13 @@ class Trainer: ...@@ -4067,7 +4068,13 @@ class Trainer:
logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.") logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
self.push_in_progress.wait_until_done() self.push_in_progress.wait_until_done()
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: def push_to_hub(
self,
commit_message: Optional[str] = "End of training",
blocking: bool = True,
token: Optional[str] = None,
**kwargs,
) -> str:
""" """
Upload `self.model` and `self.tokenizer` or `self.image_processor` to the 🤗 model hub on the repo `self.args.hub_model_id`. Upload `self.model` and `self.tokenizer` or `self.image_processor` to the 🤗 model hub on the repo `self.args.hub_model_id`.
...@@ -4076,6 +4083,8 @@ class Trainer: ...@@ -4076,6 +4083,8 @@ class Trainer:
Message to commit while pushing. Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`): blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished. Whether the function should return only when the `git push` has finished.
token (`str`, *optional*, defaults to `None`):
Token with write permission to overwrite Trainer's original args.
kwargs (`Dict[str, Any]`, *optional*): kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to [`~Trainer.create_model_card`]. Additional keyword arguments passed along to [`~Trainer.create_model_card`].
...@@ -4089,10 +4098,11 @@ class Trainer: ...@@ -4089,10 +4098,11 @@ class Trainer:
model_name = Path(self.args.output_dir).name model_name = Path(self.args.output_dir).name
else: else:
model_name = self.args.hub_model_id.split("/")[-1] model_name = self.args.hub_model_id.split("/")[-1]
token = token if token is not None else self.args.hub_token
# In case the user calls this method with args.push_to_hub = False # In case the user calls this method with args.push_to_hub = False
if self.hub_model_id is None: if self.hub_model_id is None:
self.init_hf_repo() self.init_hf_repo(token=token)
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# self.args.should_save. # self.args.should_save.
...@@ -4125,7 +4135,7 @@ class Trainer: ...@@ -4125,7 +4135,7 @@ class Trainer:
repo_id=self.hub_model_id, repo_id=self.hub_model_id,
folder_path=self.args.output_dir, folder_path=self.args.output_dir,
commit_message=commit_message, commit_message=commit_message,
token=self.args.hub_token, token=token,
run_as_future=not blocking, run_as_future=not blocking,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment