Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e59d4d01
Unverified
Commit
e59d4d01
authored
Sep 09, 2021
by
Sylvain Gugger
Committed by
GitHub
Sep 09, 2021
Browse files
Refactor internals for Trainer push_to_hub (#13486)
parent
3dd538c4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
21 deletions
+79
-21
src/transformers/file_utils.py
src/transformers/file_utils.py
+10
-0
src/transformers/trainer.py
src/transformers/trainer.py
+18
-10
src/transformers/training_args.py
src/transformers/training_args.py
+48
-8
tests/test_trainer.py
tests/test_trainer.py
+3
-3
No files found.
src/transformers/file_utils.py
View file @
e59d4d01
...
@@ -2238,3 +2238,13 @@ class PushToHubMixin:
...
@@ -2238,3 +2238,13 @@ class PushToHubMixin:
commit_message
=
"add model"
commit_message
=
"add model"
return
repo
.
push_to_hub
(
commit_message
=
commit_message
)
return
repo
.
push_to_hub
(
commit_message
=
commit_message
)
def
get_full_repo_name
(
model_id
:
str
,
organization
:
Optional
[
str
]
=
None
,
token
:
Optional
[
str
]
=
None
):
if
token
is
None
:
token
=
HfFolder
.
get_token
()
if
organization
is
None
:
username
=
HfApi
().
whoami
(
token
)[
"name"
]
return
f
"
{
username
}
/
{
model_id
}
"
else
:
return
f
"
{
organization
}
/
{
model_id
}
"
src/transformers/trainer.py
View file @
e59d4d01
...
@@ -51,6 +51,8 @@ from torch import nn
...
@@ -51,6 +51,8 @@ from torch import nn
from
torch.utils.data
import
DataLoader
,
Dataset
,
IterableDataset
,
RandomSampler
,
SequentialSampler
from
torch.utils.data
import
DataLoader
,
Dataset
,
IterableDataset
,
RandomSampler
,
SequentialSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
huggingface_hub
import
Repository
from
.
import
__version__
from
.
import
__version__
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
...
@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
...
@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
from
.file_utils
import
(
from
.file_utils
import
(
CONFIG_NAME
,
CONFIG_NAME
,
WEIGHTS_NAME
,
WEIGHTS_NAME
,
PushToHubMixin
,
get_full_repo_name
,
is_apex_available
,
is_apex_available
,
is_datasets_available
,
is_datasets_available
,
is_in_notebook
,
is_in_notebook
,
...
@@ -2478,15 +2480,17 @@ class Trainer:
...
@@ -2478,15 +2480,17 @@ class Trainer:
"""
"""
if
not
self
.
args
.
should_save
:
if
not
self
.
args
.
should_save
:
return
return
use_auth_token
=
True
if
self
.
args
.
push_to_hub_token
is
None
else
self
.
args
.
push_to_hub_token
use_auth_token
=
True
if
self
.
args
.
hub_token
is
None
else
self
.
args
.
hub_token
repo_url
=
PushToHubMixin
.
_get_repo_url_from_name
(
if
self
.
args
.
hub_model_id
is
None
:
self
.
args
.
push_to_hub_model_id
,
repo_name
=
get_full_repo_name
(
Path
(
self
.
args
.
output_dir
).
name
,
token
=
self
.
args
.
hub_token
)
organization
=
self
.
args
.
push_to_hub_organization
,
else
:
repo_name
=
self
.
args
.
hub_model_id
self
.
repo
=
Repository
(
self
.
args
.
output_dir
,
clone_from
=
repo_name
,
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
)
)
self
.
repo
=
PushToHubMixin
.
_create_or_get_repo
(
self
.
args
.
output_dir
,
repo_url
=
repo_url
,
use_auth_token
=
use_auth_token
)
# By default, ignore the checkpoint folders
# By default, ignore the checkpoint folders
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
args
.
output_dir
,
".gitignore"
)):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
args
.
output_dir
,
".gitignore"
)):
...
@@ -2523,7 +2527,7 @@ class Trainer:
...
@@ -2523,7 +2527,7 @@ class Trainer:
def
push_to_hub
(
self
,
commit_message
:
Optional
[
str
]
=
"add model"
,
**
kwargs
)
->
str
:
def
push_to_hub
(
self
,
commit_message
:
Optional
[
str
]
=
"add model"
,
**
kwargs
)
->
str
:
"""
"""
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.
push_to_
hub_model_id`.
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
Parameters:
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
...
@@ -2536,7 +2540,11 @@ class Trainer:
...
@@ -2536,7 +2540,11 @@ class Trainer:
"""
"""
if
self
.
args
.
should_save
:
if
self
.
args
.
should_save
:
self
.
create_model_card
(
model_name
=
self
.
args
.
push_to_hub_model_id
,
**
kwargs
)
if
self
.
args
.
hub_model_id
is
None
:
model_name
=
Path
(
self
.
args
.
output_dir
).
name
else
:
model_name
=
self
.
args
.
hub_model_id
.
split
(
"/"
)[
-
1
]
self
.
create_model_card
(
model_name
=
model_name
,
**
kwargs
)
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# self.args.should_save.
# self.args.should_save.
self
.
save_model
()
self
.
save_model
()
...
...
src/transformers/training_args.py
View file @
e59d4d01
...
@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
...
@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
from
.debug_utils
import
DebugOption
from
.debug_utils
import
DebugOption
from
.file_utils
import
(
from
.file_utils
import
(
cached_property
,
cached_property
,
get_full_repo_name
,
is_sagemaker_dp_enabled
,
is_sagemaker_dp_enabled
,
is_sagemaker_mp_enabled
,
is_sagemaker_mp_enabled
,
is_torch_available
,
is_torch_available
,
...
@@ -335,12 +336,14 @@ class TrainingArguments:
...
@@ -335,12 +336,14 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
details.
push_to_hub_model_id (:obj:`str`, `optional`):
hub_model_id (:obj:`str`, `optional`):
The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`.
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
Will default to the name of :obj:`output_dir`.
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
push_to_hub_organization (:obj:`str`, `optional`):
of with :obj:`"organization_name/model"`.
The name of the organization in with to which push the :class:`~transformers.Trainer`.
push_to_hub_token (:obj:`str`, `optional`):
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
:obj:`output_dir`.
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
:obj:`huggingface-cli login`.
"""
"""
...
@@ -612,6 +615,11 @@ class TrainingArguments:
...
@@ -612,6 +615,11 @@ class TrainingArguments:
default
=
None
,
default
=
None
,
metadata
=
{
"help"
:
"The path to a folder with a valid checkpoint for your model."
},
metadata
=
{
"help"
:
"The path to a folder with a valid checkpoint for your model."
},
)
)
hub_model_id
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to keep in sync with the local `output_dir`."
}
)
hub_token
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The token to use to push to the Model Hub."
})
# Deprecated arguments
push_to_hub_model_id
:
str
=
field
(
push_to_hub_model_id
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to which push the `Trainer`."
}
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to which push the `Trainer`."
}
)
)
...
@@ -761,8 +769,40 @@ class TrainingArguments:
...
@@ -761,8 +769,40 @@ class TrainingArguments:
self
.
hf_deepspeed_config
=
HfTrainerDeepSpeedConfig
(
self
.
deepspeed
)
self
.
hf_deepspeed_config
=
HfTrainerDeepSpeedConfig
(
self
.
deepspeed
)
self
.
hf_deepspeed_config
.
trainer_config_process
(
self
)
self
.
hf_deepspeed_config
.
trainer_config_process
(
self
)
if
self
.
push_to_hub_model_id
is
None
:
if
self
.
push_to_hub_token
is
not
None
:
self
.
push_to_hub_model_id
=
Path
(
self
.
output_dir
).
name
warnings
.
warn
(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_token` instead."
,
FutureWarning
,
)
self
.
hub_token
=
self
.
push_to_hub_token
if
self
.
push_to_hub_model_id
is
not
None
:
self
.
hub_model_id
=
get_full_repo_name
(
self
.
push_to_hub_model_id
,
organization
=
self
.
push_to_hub_organization
,
token
=
self
.
hub_token
)
if
self
.
push_to_hub_organization
is
not
None
:
warnings
.
warn
(
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
f
"argument (in this case
{
self
.
hub_model_id
}
)."
,
FutureWarning
,
)
else
:
warnings
.
warn
(
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f
"
{
self
.
hub_model_id
}
)."
,
FutureWarning
,
)
elif
self
.
push_to_hub_organization
is
not
None
:
self
.
hub_model_id
=
f
"
{
self
.
push_to_hub_organization
}
/
{
Path
(
self
.
output_dir
).
name
}
"
warnings
.
warn
(
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f
"
{
self
.
hub_model_id
}
)."
,
FutureWarning
,
)
def
__str__
(
self
):
def
__str__
(
self
):
self_as_dict
=
asdict
(
self
)
self_as_dict
=
asdict
(
self
)
...
...
tests/test_trainer.py
View file @
e59d4d01
...
@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
...
@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer
=
get_regression_trainer
(
trainer
=
get_regression_trainer
(
output_dir
=
os
.
path
.
join
(
tmp_dir
,
"test-trainer"
),
output_dir
=
os
.
path
.
join
(
tmp_dir
,
"test-trainer"
),
push_to_hub
=
True
,
push_to_hub
=
True
,
push_to_
hub_token
=
self
.
_token
,
hub_token
=
self
.
_token
,
)
)
url
=
trainer
.
push_to_hub
()
url
=
trainer
.
push_to_hub
()
...
@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
...
@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer
=
get_regression_trainer
(
trainer
=
get_regression_trainer
(
output_dir
=
os
.
path
.
join
(
tmp_dir
,
"test-trainer-org"
),
output_dir
=
os
.
path
.
join
(
tmp_dir
,
"test-trainer-org"
),
push_to_hub
=
True
,
push_to_hub
=
True
,
push_to_hub_organization
=
"valid_
org"
,
hub_model_id
=
"valid_org/test-trainer-
org"
,
push_to_
hub_token
=
self
.
_token
,
hub_token
=
self
.
_token
,
)
)
url
=
trainer
.
push_to_hub
()
url
=
trainer
.
push_to_hub
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment