Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e59d4d01
"examples/legacy/seq2seq/run_eval.py" did not exist on "d9d0f1140b9ffbcffd327f5de918ae8a89961518"
Unverified
Commit
e59d4d01
authored
Sep 09, 2021
by
Sylvain Gugger
Committed by
GitHub
Sep 09, 2021
Browse files
Refactor internals for Trainer push_to_hub (#13486)
parent
3dd538c4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
21 deletions
+79
-21
src/transformers/file_utils.py
src/transformers/file_utils.py
+10
-0
src/transformers/trainer.py
src/transformers/trainer.py
+18
-10
src/transformers/training_args.py
src/transformers/training_args.py
+48
-8
tests/test_trainer.py
tests/test_trainer.py
+3
-3
No files found.
src/transformers/file_utils.py
View file @
e59d4d01
...
...
@@ -2238,3 +2238,13 @@ class PushToHubMixin:
commit_message
=
"add model"
return
repo
.
push_to_hub
(
commit_message
=
commit_message
)
def
get_full_repo_name
(
model_id
:
str
,
organization
:
Optional
[
str
]
=
None
,
token
:
Optional
[
str
]
=
None
):
if
token
is
None
:
token
=
HfFolder
.
get_token
()
if
organization
is
None
:
username
=
HfApi
().
whoami
(
token
)[
"name"
]
return
f
"
{
username
}
/
{
model_id
}
"
else
:
return
f
"
{
organization
}
/
{
model_id
}
"
src/transformers/trainer.py
View file @
e59d4d01
...
...
@@ -51,6 +51,8 @@ from torch import nn
from
torch.utils.data
import
DataLoader
,
Dataset
,
IterableDataset
,
RandomSampler
,
SequentialSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
huggingface_hub
import
Repository
from
.
import
__version__
from
.configuration_utils
import
PretrainedConfig
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
...
...
@@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
from
.file_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
PushToHubMixin
,
get_full_repo_name
,
is_apex_available
,
is_datasets_available
,
is_in_notebook
,
...
...
@@ -2478,15 +2480,17 @@ class Trainer:
"""
if
not
self
.
args
.
should_save
:
return
use_auth_token
=
True
if
self
.
args
.
push_to_hub_token
is
None
else
self
.
args
.
push_to_hub_token
repo_url
=
PushToHubMixin
.
_get_repo_url_from_name
(
self
.
args
.
push_to_hub_model_id
,
organization
=
self
.
args
.
push_to_hub_organization
,
use_auth_token
=
True
if
self
.
args
.
hub_token
is
None
else
self
.
args
.
hub_token
if
self
.
args
.
hub_model_id
is
None
:
repo_name
=
get_full_repo_name
(
Path
(
self
.
args
.
output_dir
).
name
,
token
=
self
.
args
.
hub_token
)
else
:
repo_name
=
self
.
args
.
hub_model_id
self
.
repo
=
Repository
(
self
.
args
.
output_dir
,
clone_from
=
repo_name
,
use_auth_token
=
use_auth_token
,
)
self
.
repo
=
PushToHubMixin
.
_create_or_get_repo
(
self
.
args
.
output_dir
,
repo_url
=
repo_url
,
use_auth_token
=
use_auth_token
)
# By default, ignore the checkpoint folders
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
args
.
output_dir
,
".gitignore"
)):
...
...
@@ -2523,7 +2527,7 @@ class Trainer:
def
push_to_hub
(
self
,
commit_message
:
Optional
[
str
]
=
"add model"
,
**
kwargs
)
->
str
:
"""
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.
push_to_
hub_model_id`.
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
...
...
@@ -2536,7 +2540,11 @@ class Trainer:
"""
if
self
.
args
.
should_save
:
self
.
create_model_card
(
model_name
=
self
.
args
.
push_to_hub_model_id
,
**
kwargs
)
if
self
.
args
.
hub_model_id
is
None
:
model_name
=
Path
(
self
.
args
.
output_dir
).
name
else
:
model_name
=
self
.
args
.
hub_model_id
.
split
(
"/"
)[
-
1
]
self
.
create_model_card
(
model_name
=
model_name
,
**
kwargs
)
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# self.args.should_save.
self
.
save_model
()
...
...
src/transformers/training_args.py
View file @
e59d4d01
...
...
@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
from
.debug_utils
import
DebugOption
from
.file_utils
import
(
cached_property
,
get_full_repo_name
,
is_sagemaker_dp_enabled
,
is_sagemaker_mp_enabled
,
is_torch_available
,
...
...
@@ -335,12 +336,14 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
push_to_hub_model_id (:obj:`str`, `optional`):
The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`.
Will default to the name of :obj:`output_dir`.
push_to_hub_organization (:obj:`str`, `optional`):
The name of the organization in with to which push the :class:`~transformers.Trainer`.
push_to_hub_token (:obj:`str`, `optional`):
hub_model_id (:obj:`str`, `optional`):
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
of with :obj:`"organization_name/model"`.
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
:obj:`output_dir`.
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
"""
...
...
@@ -612,6 +615,11 @@ class TrainingArguments:
default
=
None
,
metadata
=
{
"help"
:
"The path to a folder with a valid checkpoint for your model."
},
)
hub_model_id
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to keep in sync with the local `output_dir`."
}
)
hub_token
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The token to use to push to the Model Hub."
})
# Deprecated arguments
push_to_hub_model_id
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to which push the `Trainer`."
}
)
...
...
@@ -761,8 +769,40 @@ class TrainingArguments:
self
.
hf_deepspeed_config
=
HfTrainerDeepSpeedConfig
(
self
.
deepspeed
)
self
.
hf_deepspeed_config
.
trainer_config_process
(
self
)
if
self
.
push_to_hub_model_id
is
None
:
self
.
push_to_hub_model_id
=
Path
(
self
.
output_dir
).
name
if
self
.
push_to_hub_token
is
not
None
:
warnings
.
warn
(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_token` instead."
,
FutureWarning
,
)
self
.
hub_token
=
self
.
push_to_hub_token
if
self
.
push_to_hub_model_id
is
not
None
:
self
.
hub_model_id
=
get_full_repo_name
(
self
.
push_to_hub_model_id
,
organization
=
self
.
push_to_hub_organization
,
token
=
self
.
hub_token
)
if
self
.
push_to_hub_organization
is
not
None
:
warnings
.
warn
(
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
f
"argument (in this case
{
self
.
hub_model_id
}
)."
,
FutureWarning
,
)
else
:
warnings
.
warn
(
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f
"
{
self
.
hub_model_id
}
)."
,
FutureWarning
,
)
elif
self
.
push_to_hub_organization
is
not
None
:
self
.
hub_model_id
=
f
"
{
self
.
push_to_hub_organization
}
/
{
Path
(
self
.
output_dir
).
name
}
"
warnings
.
warn
(
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f
"
{
self
.
hub_model_id
}
)."
,
FutureWarning
,
)
def
__str__
(
self
):
self_as_dict
=
asdict
(
self
)
...
...
tests/test_trainer.py
View file @
e59d4d01
...
...
@@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer
=
get_regression_trainer
(
output_dir
=
os
.
path
.
join
(
tmp_dir
,
"test-trainer"
),
push_to_hub
=
True
,
push_to_
hub_token
=
self
.
_token
,
hub_token
=
self
.
_token
,
)
url
=
trainer
.
push_to_hub
()
...
...
@@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer
=
get_regression_trainer
(
output_dir
=
os
.
path
.
join
(
tmp_dir
,
"test-trainer-org"
),
push_to_hub
=
True
,
push_to_hub_organization
=
"valid_
org"
,
push_to_
hub_token
=
self
.
_token
,
hub_model_id
=
"valid_org/test-trainer-
org"
,
hub_token
=
self
.
_token
,
)
url
=
trainer
.
push_to_hub
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment