transformers · Commit 7c10dd22 (unverified)

Better support for resuming training (#8878)

Authored Dec 01, 2020 by Sylvain Gugger, committed by GitHub on Dec 01, 2020.
Parent: 21db560d

Showing 3 changed files with 65 additions and 11 deletions (+65, -11).
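Before the file-by-file diff, a quick orientation that is not part of the commit itself: a minimal sketch of how resuming is invoked with the Trainer API these changes build on. The model, dataset, paths, and checkpoint name below are placeholders, not values from the commit; only `Trainer(...)`, `TrainingArguments(...)`, and `train(model_path=...)` are taken from the code shown further down.

    # Hypothetical resume script; my_model / my_dataset / paths are placeholders.
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="./output", num_train_epochs=3)
    trainer = Trainer(model=my_model, args=args, train_dataset=my_dataset)

    # As in the tests below, training is resumed by pointing model_path at a saved checkpoint;
    # the trainer state is reloaded from trainer_state.json inside that directory.
    trainer.train(model_path="./output/checkpoint-15")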
src/transformers/trainer.py        +25  -11
src/transformers/training_args.py  +10   -0
tests/test_trainer.py              +30   -0
src/transformers/trainer.py
@@ -665,12 +665,12 @@ class Trainer:
         )

         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", num_examples)
-        logger.info("  Num Epochs = %d", num_train_epochs)
-        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
-        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
-        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
-        logger.info("  Total optimization steps = %d", max_steps)
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Num Epochs = {num_train_epochs}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
         epochs_trained = 0
@@ -680,13 +680,20 @@ class Trainer:
         if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
             self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
+            if not self.args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info("  Continuing training from epoch %d", epochs_trained)
-            logger.info("  Continuing training from global step %d", self.state.global_step)
-            logger.info("  Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {self.state.global_step}")
+            if not self.args.ignore_data_skip:
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+                    "batches in the first epoch."
+                )

         # Update the references
         self.callback_handler.model = self.model
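As a side note, not part of the diff: the bookkeeping above can be checked with a small worked example. The numbers below are illustrative only; the arithmetic mirrors the lines added in the hunk.

    # Suppose a run was interrupted at global_step = 23, with
    # num_update_steps_per_epoch = 10 and gradient_accumulation_steps = 2.
    global_step = 23
    num_update_steps_per_epoch = 10
    gradient_accumulation_steps = 2

    epochs_trained = global_step // num_update_steps_per_epoch                 # 2 full epochs already done
    steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 3 optimizer updates into epoch 3
    # Each optimizer update consumed gradient_accumulation_steps batches, hence the scaling:
    steps_trained_in_current_epoch *= gradient_accumulation_steps              # 6 batches to skip when resuming
    print(epochs_trained, steps_trained_in_current_epoch)                      # prints: 2 6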
@@ -712,6 +719,13 @@ class Trainer:

         self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

+        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+        if not self.args.ignore_data_skip:
+            for epoch in range(epochs_trained):
+                # We just need to begin an iteration to create the randomization of the sampler.
+                for _ in train_dataloader:
+                    break
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
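For context, again outside the diff: the trick of merely beginning an iteration to replay the sampler's randomization can be illustrated with plain PyTorch. The dataset, seed, and batch size below are arbitrary, and the sketch assumes a single-process DataLoader with no random augmentation, so the RNG is only consumed when an epoch's iterator is created.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(42)
    dataset = TensorDataset(torch.arange(10))
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    epochs_trained = 2
    for _ in range(epochs_trained):
        # Starting the iteration draws that epoch's shuffled permutation; we can stop immediately.
        for _ in loader:
            break

    # The next full pass over `loader` now sees the same order as epoch 3 of an uninterrupted run.
    print([batch[0].tolist() for batch in loader])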
src/transformers/training_args.py
@@ -189,6 +189,10 @@ class TrainingArguments:
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If there are more than one devices, whether to use model parallelism to distribute the model's modules
             across devices or not.
+        ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
+            step can take a long time) but will not yield the same results as the interrupted training would have.
     """

     output_dir: str = field(
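To illustrate the documented behaviour (not part of the diff), the flag would be set like any other TrainingArguments field; the output directory below is a placeholder.

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./output",
        ignore_data_skip=True,  # resume immediately; do not replay epochs/batches to realign the data order
    )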
@@ -350,6 +354,12 @@ class TrainingArguments:
     greater_is_better: Optional[bool] = field(
         default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
     )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        },
+    )

     def __post_init__(self):
         if self.disable_tqdm is None:
tests/test_trainer.py
@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionPreTrainedModel.from_pretrained(checkpoint)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
@@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionModel()
+            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            model.load_state_dict(state_dict)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of