Unverified Commit 161c0a2e authored by Nicholas Broad's avatar Nicholas Broad Committed by GitHub
Browse files

Private repo TrainingArgument (#16707)



* private repo argument to trainer

* format
Co-authored-by: default avatarNicholas Broad <nicholas@nmbroad.com>
parent d4b3e359
...@@ -2770,6 +2770,7 @@ class Trainer: ...@@ -2770,6 +2770,7 @@ class Trainer:
self.args.output_dir, self.args.output_dir,
clone_from=repo_name, clone_from=repo_name,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
private=self.args.hub_private_repo,
) )
except EnvironmentError: except EnvironmentError:
if self.args.overwrite_output_dir and at_init: if self.args.overwrite_output_dir and at_init:
......
...@@ -416,6 +416,8 @@ class TrainingArguments: ...@@ -416,6 +416,8 @@ class TrainingArguments:
hub_token (`str`, *optional*): hub_token (`str`, *optional*):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
`huggingface-cli login`. `huggingface-cli login`.
hub_private_repo (`bool`, *optional*, defaults to `False`):
If True, the Hub repo will be set to private.
gradient_checkpointing (`bool`, *optional*, defaults to `False`): gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass. If True, use gradient checkpointing to save memory at the expense of slower backward pass.
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
...@@ -738,6 +740,7 @@ class TrainingArguments: ...@@ -738,6 +740,7 @@ class TrainingArguments:
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
) )
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
gradient_checkpointing: bool = field( gradient_checkpointing: bool = field(
default=False, default=False,
metadata={ metadata={
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment