chenpangpang/transformers, commit da5ef25d (unverified)
Authored Jan 27, 2022 by Sylvain Gugger; committed via GitHub on Jan 27, 2022.

Push to hub save (#15327)

* Adapt doc and push at every save
* style

Parent: 9f831bde
Showing 2 changed files with 21 additions and 7 deletions:

  src/transformers/trainer.py        +9  -4
  src/transformers/training_args.py  +12 -3
src/transformers/trainer.py
@@ -966,7 +966,7 @@ class Trainer:
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
             output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir)
+            self.save_model(output_dir, _internal_call=True)
             if self.args.should_save:
                 self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
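For context, `tune.checkpoint_dir` in the hunk above is Ray Tune's 1.x function-API checkpoint context manager: files written inside the yielded directory are registered as the checkpoint for that step. A minimal standalone sketch of the same pattern, assuming Ray 1.x (the trainable and its fake loss are illustrative, not from this commit):

import os

from ray import tune


def trainable(config):
    for step in range(5):
        loss = 1.0 / (step + 1)  # stand-in for a real training step
        # Everything written under `checkpoint_dir` is tracked by Tune as
        # the checkpoint for this step, mirroring Trainer's usage above.
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "state.txt"), "w") as f:
                f.write(str(step))
        tune.report(loss=loss)


tune.run(trainable, num_samples=1)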
@@ -1634,7 +1634,7 @@ class Trainer:
         self.store_flos()

         output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir)
+        self.save_model(output_dir, _internal_call=True)
         if self.deepspeed:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_fp16_weights_on_model_save` is True
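The comment in this hunk refers to DeepSpeed ZeRO-3: with stage 3, parameters are partitioned across ranks, so the saved model file holds placeholders unless DeepSpeed is told to gather the full fp16 weights at save time. A hedged sketch of just the relevant config fragment (this dict, or an equivalent JSON file, is what gets passed as the `deepspeed` training argument):

# Sketch of the relevant DeepSpeed config fragment only; a real config also
# needs optimizer/scheduler/batch-size settings.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        # Gather full fp16 weights on save so the checkpoint is a loadable
        # model file rather than ZeRO-3 placeholders.
        "stage3_gather_fp16_weights_on_model_save": True,
    },
}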
@@ -2002,7 +2002,7 @@ class Trainer:
         else:
             return self.args.process_index == 0

-    def save_model(self, output_dir: Optional[str] = None):
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
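As the docstring says, whatever `save_model` writes can be reloaded with `from_pretrained()`. A tiny round-trip sketch (the directory name is illustrative):

from transformers import AutoModel

# After trainer.save_model("my-model") has written config and weights:
model = AutoModel.from_pretrained("my-model")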
@@ -2051,6 +2051,10 @@ class Trainer:
         elif self.args.should_save:
             self._save(output_dir)

+        # Push to the Hub when `save_model` is called by the user.
+        if self.args.push_to_hub and not _internal_call:
+            self.push_to_hub(commit_message="Model save")
+
     def _save_tpu(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         logger.info(f"Saving model checkpoint to {output_dir}")
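Taken together with the signature change above, the effect is: saves made internally by the training loop pass `_internal_call=True` and do not push from here (checkpoint pushes remain governed by `save_strategy` and `hub_strategy`), while a direct user call to `save_model` now pushes with the commit message "Model save". A hedged usage sketch, assuming `model` and `train_dataset` are already defined:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="my-model",   # kept in sync with the Hub repo
    push_to_hub=True,
    save_strategy="epoch",   # internal saves; pushed per hub_strategy
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
trainer.save_model()  # user-initiated, so this call also pushes to the Hub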
@@ -2768,9 +2772,10 @@ class Trainer:
             model_name = Path(self.args.output_dir).name
         else:
             model_name = self.args.hub_model_id.split("/")[-1]

         # Needs to be executed on all processes for TPU training, but will only save on the processes determined by
         # self.args.should_save.
-        self.save_model()
+        self.save_model(_internal_call=True)

         # Only push from one node.
         if not self.is_world_process_zero():
src/transformers/training_args.py
@@ -365,9 +365,18 @@ class TrainingArguments:
             Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
             down the training and evaluation speed.
         push_to_hub (`bool`, *optional*, defaults to `False`):
-            Whether or not to upload the trained model to the hub after training. If this is activated, and
-            `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
-            pushed.
+            Whether or not to push the model to the Hub every time the model is saved. If this is activated,
+            `output_dir` will become a git directory synced with the repo (determined by `hub_model_id`) and the
+            content will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
+            [`~Trainer.save_model`] will also trigger a push.
+
+            <Tip warning={true}>
+
+            If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
+            pushed.
+
+            </Tip>
         resume_from_checkpoint (`str`, *optional*):
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
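One way to satisfy the warning above, when `output_dir` must already exist, is to make it a clone of the Hub repository up front, for example with `huggingface_hub.Repository`, which the Trainer also uses internally (the repo id below is a made-up placeholder):

from huggingface_hub import Repository

# Clone the target repo so that `output_dir` is a local clone of the
# repository the Trainer will push to. "my-username/my-model" is hypothetical.
repo = Repository(local_dir="my-model", clone_from="my-username/my-model")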
@@ -384,7 +393,7 @@ class TrainingArguments:
             Defines the scope of what is pushed to the Hub and when. Possible values are:

             - `"end"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a
-              draft of a model card at the end of training.
+              draft of a model card when the [`~Trainer.save_model`] method is called.
             - `"every_save"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and
               a draft of a model card each time there is a model save. The pushes are asynchronous to not block
               training, and in case the saves are very frequent, a new push is only attempted if the previous one is
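To illustrate the two documented values (argument names as in the docstring; the directory and step count are placeholders):

from transformers import TrainingArguments

# "end": push only when save_model is called, e.g. at the end of a script.
args_end = TrainingArguments(output_dir="my-model", push_to_hub=True, hub_strategy="end")

# "every_save": push asynchronously on each checkpoint save; if saves are very
# frequent, a new push is only attempted once the previous one has finished.
args_every = TrainingArguments(
    output_dir="my-model",
    push_to_hub=True,
    save_strategy="steps",
    save_steps=500,
    hub_strategy="every_save",
)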