chenpangpang / transformers · Commits

Commit 7c10dd22 (unverified)
Authored Dec 01, 2020 by Sylvain Gugger, committed via GitHub on Dec 01, 2020
Parent: 21db560d

Better support for resuming training (#8878)

Showing 3 changed files with 65 additions and 11 deletions:

  src/transformers/trainer.py        +25  -11
  src/transformers/training_args.py  +10   -0
  tests/test_trainer.py              +30   -0
src/transformers/trainer.py

@@ -665,12 +665,12 @@ class Trainer:
         )
         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", num_examples)
-        logger.info("  Num Epochs = %d", num_train_epochs)
-        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
-        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
-        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
-        logger.info("  Total optimization steps = %d", max_steps)
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Num Epochs = {num_train_epochs}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
         epochs_trained = 0
@@ -680,13 +680,20 @@ class Trainer:
         if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
             self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
-            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-            steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            if not self.args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info("  Continuing training from epoch %d", epochs_trained)
-            logger.info("  Continuing training from global step %d", self.state.global_step)
-            logger.info("  Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {self.state.global_step}")
+            if not self.args.ignore_data_skip:
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+                    "batches in the first epoch."
+                )

         # Update the references
         self.callback_handler.model = self.model
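To make the resume arithmetic in this hunk concrete, here is a small sketch with hypothetical numbers (not taken from the commit): a checkpoint saved at optimizer step 15, 8 optimizer steps per epoch, and gradient accumulation over 2 batches.

    # Hypothetical values, only to illustrate the arithmetic above.
    global_step = 15                  # saved in trainer_state.json
    num_update_steps_per_epoch = 8
    gradient_accumulation_steps = 2

    epochs_trained = global_step // num_update_steps_per_epoch                 # 1 full epoch already done
    steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 7 optimizer steps into the second epoch
    steps_trained_in_current_epoch *= gradient_accumulation_steps              # 14 dataloader batches to skip on resume

    assert (epochs_trained, steps_trained_in_current_epoch) == (1, 14)

With ignore_data_skip=True the skip count is simply reset to 0, which is why the resumed run starts faster but sees the data in a different order.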
@@ -712,6 +719,13 @@ class Trainer:
         self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

+        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+        if not self.args.ignore_data_skip:
+            for epoch in range(epochs_trained):
+                # We just need to begin an iteration to create the randomization of the sampler.
+                for _ in train_dataloader:
+                    break
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
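Why does merely beginning an iteration reproduce the randomization of a skipped epoch? A shuffling sampler draws its permutation from the torch RNG at the moment the dataloader iterator is created, so replaying the iterator creation for each already-trained epoch leaves the RNG in the same state as an uninterrupted run. Below is a minimal sketch of that effect (not part of the commit; it uses a toy TensorDataset and a default shuffling DataLoader):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10))

    def epoch_order(loader):
        # Record the order in which samples come out during one epoch.
        return [int(batch[0]) for batch in loader]

    # Uninterrupted run: epoch 0 is consumed normally, then epoch 1.
    torch.manual_seed(0)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    epoch_order(loader)                  # epoch 0
    epoch1_full = epoch_order(loader)    # epoch 1

    # "Resumed" run: epoch 0 is only started, as in the hunk above, then epoch 1.
    torch.manual_seed(0)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    for _ in loader:                     # creating the iterator draws epoch 0's permutation
        break
    epoch1_resumed = epoch_order(loader)

    assert epoch1_full == epoch1_resumed

In the Trainer itself, this replay (together with skipping steps_trained_in_current_epoch batches inside the first resumed epoch) is what ignore_data_skip=True disables in order to start training faster.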
src/transformers/training_args.py

@@ -189,6 +189,10 @@ class TrainingArguments:
         model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If there are more than one devices, whether to use model parallelism to distribute the model's modules
             across devices or not.
+        ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
+            step can take a long time) but will not yield the same results as the interrupted training would have.
     """

     output_dir: str = field(
@@ -350,6 +354,12 @@ class TrainingArguments:
     greater_is_better: Optional[bool] = field(
         default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
     )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+        },
+    )

     def __post_init__(self):
         if self.disable_tqdm is None:
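A hedged usage sketch of the new flag (the model, dataset, output directory, and checkpoint path below are placeholders, not taken from the commit): with the default ignore_data_skip=False, Trainer.train(model_path=...) reproduces the data order of the interrupted run; setting the flag to True trades that reproducibility for a faster start.

    from transformers import Trainer, TrainingArguments

    # `model` and `train_dataset` are assumed to be defined elsewhere.
    args = TrainingArguments(
        output_dir="output",
        ignore_data_skip=True,  # start faster; the resumed run will not see the data in the original order
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    # The checkpoint directory is illustrative; it must contain trainer_state.json and the saved weights.
    trainer.train(model_path="output/checkpoint-15")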
tests/test_trainer.py

@@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionPreTrainedModel.from_pretrained(checkpoint)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
         # With a regular model that is not a PreTrainedModel
         with tempfile.TemporaryDirectory() as tmpdir:
             trainer = get_regression_trainer(
@@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+            # Now check with a later checkpoint that it also works when we span over one epoch
+            checkpoint = os.path.join(tmpdir, "checkpoint-15")
+
+            # Reinitialize trainer and load model
+            model = RegressionModel()
+            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            model.load_state_dict(state_dict)
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of