Unverified Commit e59d4d01 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Refactor internals for Trainer push_to_hub (#13486)

parent 3dd538c4
...@@ -2238,3 +2238,13 @@ class PushToHubMixin: ...@@ -2238,3 +2238,13 @@ class PushToHubMixin:
commit_message = "add model" commit_message = "add model"
return repo.push_to_hub(commit_message=commit_message) return repo.push_to_hub(commit_message=commit_message)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = HfApi().whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
...@@ -51,6 +51,8 @@ from torch import nn ...@@ -51,6 +51,8 @@ from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from huggingface_hub import Repository
from . import __version__ from . import __version__
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
...@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check ...@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
from .file_utils import ( from .file_utils import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
PushToHubMixin, get_full_repo_name,
is_apex_available, is_apex_available,
is_datasets_available, is_datasets_available,
is_in_notebook, is_in_notebook,
...@@ -2478,15 +2480,17 @@ class Trainer: ...@@ -2478,15 +2480,17 @@ class Trainer:
""" """
if not self.args.should_save: if not self.args.should_save:
return return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token use_auth_token = True if self.args.hub_token is None else self.args.hub_token
repo_url = PushToHubMixin._get_repo_url_from_name( if self.args.hub_model_id is None:
self.args.push_to_hub_model_id, repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token)
organization=self.args.push_to_hub_organization, else:
repo_name = self.args.hub_model_id
self.repo = Repository(
self.args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
self.repo = PushToHubMixin._create_or_get_repo(
self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
)
# By default, ignore the checkpoint folders # By default, ignore the checkpoint folders
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")): if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
...@@ -2523,7 +2527,7 @@ class Trainer: ...@@ -2523,7 +2527,7 @@ class Trainer:
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str: def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
""" """
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`. Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters: Parameters:
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`): commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
...@@ -2536,7 +2540,11 @@ class Trainer: ...@@ -2536,7 +2540,11 @@ class Trainer:
""" """
if self.args.should_save: if self.args.should_save:
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs) if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name, **kwargs)
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# self.args.should_save. # self.args.should_save.
self.save_model() self.save_model()
......
...@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional ...@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
from .debug_utils import DebugOption from .debug_utils import DebugOption
from .file_utils import ( from .file_utils import (
cached_property, cached_property,
get_full_repo_name,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_torch_available, is_torch_available,
...@@ -335,12 +336,14 @@ class TrainingArguments: ...@@ -335,12 +336,14 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details. details.
push_to_hub_model_id (:obj:`str`, `optional`): hub_model_id (:obj:`str`, `optional`):
The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`. The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
Will default to the name of :obj:`output_dir`. name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
push_to_hub_organization (:obj:`str`, `optional`): of with :obj:`"organization_name/model"`.
The name of the organization in with to which push the :class:`~transformers.Trainer`.
push_to_hub_token (:obj:`str`, `optional`): Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
:obj:`output_dir`.
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`. :obj:`huggingface-cli login`.
""" """
...@@ -612,6 +615,11 @@ class TrainingArguments: ...@@ -612,6 +615,11 @@ class TrainingArguments:
default=None, default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."}, metadata={"help": "The path to a folder with a valid checkpoint for your model."},
) )
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
# Deprecated arguments
push_to_hub_model_id: str = field( push_to_hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
) )
...@@ -761,8 +769,40 @@ class TrainingArguments: ...@@ -761,8 +769,40 @@ class TrainingArguments:
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
self.hf_deepspeed_config.trainer_config_process(self) self.hf_deepspeed_config.trainer_config_process(self)
if self.push_to_hub_model_id is None: if self.push_to_hub_token is not None:
self.push_to_hub_model_id = Path(self.output_dir).name warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_token` instead.",
FutureWarning,
)
self.hub_token = self.push_to_hub_token
if self.push_to_hub_model_id is not None:
self.hub_model_id = get_full_repo_name(
self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
)
if self.push_to_hub_organization is not None:
warnings.warn(
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
f"argument (in this case {self.hub_model_id}).",
FutureWarning,
)
else:
warnings.warn(
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f"{self.hub_model_id}).",
FutureWarning,
)
elif self.push_to_hub_organization is not None:
self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
warnings.warn(
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f"{self.hub_model_id}).",
FutureWarning,
)
def __str__(self): def __str__(self):
self_as_dict = asdict(self) self_as_dict = asdict(self)
......
...@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer = get_regression_trainer( trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer"), output_dir=os.path.join(tmp_dir, "test-trainer"),
push_to_hub=True, push_to_hub=True,
push_to_hub_token=self._token, hub_token=self._token,
) )
url = trainer.push_to_hub() url = trainer.push_to_hub()
...@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer = get_regression_trainer( trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-org"), output_dir=os.path.join(tmp_dir, "test-trainer-org"),
push_to_hub=True, push_to_hub=True,
push_to_hub_organization="valid_org", hub_model_id="valid_org/test-trainer-org",
push_to_hub_token=self._token, hub_token=self._token,
) )
url = trainer.push_to_hub() url = trainer.push_to_hub()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment