Unverified Commit 32634bce authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make username optional in hub_model_id (#13940)

parent 708ffff6
...@@ -39,10 +39,12 @@ class PushToHubCallback(Callback): ...@@ -39,10 +39,12 @@ class PushToHubCallback(Callback):
tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`): tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights. The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
hub_model_id (:obj:`str`, `optional`): hub_model_id (:obj:`str`, `optional`):
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
of with :obj:`"organization_name/model"`. Will default to :obj:`user_name/output_dir_name` with for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member of with
`output_dir_name` being the name of :obj:`output_dir`. :obj:`"organization_name/model"`.
Will default to to the name of :obj:`output_dir`.
hub_token (:obj:`str`, `optional`): hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`. :obj:`huggingface-cli login`.
...@@ -56,11 +58,11 @@ class PushToHubCallback(Callback): ...@@ -56,11 +58,11 @@ class PushToHubCallback(Callback):
self.save_steps = save_steps self.save_steps = save_steps
output_dir = Path(output_dir) output_dir = Path(output_dir)
if hub_model_id is None: if hub_model_id is None:
repo_name = get_full_repo_name(output_dir.absolute().name, token=hub_token) hub_model_id = output_dir.absolute().name
else: if "/" not in hub_model_id:
repo_name = hub_model_id hub_model_id = get_full_repo_name(hub_model_id, token=hub_token)
self.output_dir = output_dir self.output_dir = output_dir
self.repo = Repository(str(output_dir), clone_from=repo_name) self.repo = Repository(str(output_dir), clone_from=hub_model_id)
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.last_job = None self.last_job = None
......
...@@ -2542,9 +2542,11 @@ class Trainer: ...@@ -2542,9 +2542,11 @@ class Trainer:
return return
use_auth_token = True if self.args.hub_token is None else self.args.hub_token use_auth_token = True if self.args.hub_token is None else self.args.hub_token
if self.args.hub_model_id is None: if self.args.hub_model_id is None:
repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token) repo_name = Path(self.args.output_dir).absolute().name
else: else:
repo_name = self.args.hub_model_id repo_name = self.args.hub_model_id
if "/" not in repo_name:
repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)
try: try:
self.repo = Repository( self.repo = Repository(
......
...@@ -349,12 +349,13 @@ class TrainingArguments: ...@@ -349,12 +349,13 @@ class TrainingArguments:
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details. details.
hub_model_id (:obj:`str`, `optional`): hub_model_id (:obj:`str`, `optional`):
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
of with :obj:`"organization_name/model"`. for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member of with
:obj:`"organization_name/model"`. Will default to :obj:`user_name/output_dir_name` with `output_dir_name`
being the name of :obj:`output_dir`.
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of Will default to to the name of :obj:`output_dir`.
:obj:`output_dir`.
hub_strategy (:obj:`str` or :class:`~transformers.trainer_utils.HubStrategy`, `optional`, defaults to :obj:`"every_save"`): hub_strategy (:obj:`str` or :class:`~transformers.trainer_utils.HubStrategy`, `optional`, defaults to :obj:`"every_save"`):
Defines the scope of what is pushed to the Hub and when. Possible values are: Defines the scope of what is pushed to the Hub and when. Possible values are:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment