chenpangpang / transformers · Commit 1e62e999

Unverified commit, authored Nov 18, 2020 by Sylvain Gugger, committed by GitHub on Nov 18, 2020

Fixes the training resuming with gradient accumulation (#8624)
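When resuming from a checkpoint, the Trainer skips the batches of the partially completed epoch, but the skip count was computed as self.state.global_step % num_update_steps_per_epoch, which counts optimizer updates. With gradient accumulation, each update consumes gradient_accumulation_steps batches, so the resume point landed too early in the epoch. The diff scales the skip count by gradient_accumulation_steps and changes the log message from "steps" to "batches". A minimal sketch of the arithmetic; the numeric values below are illustrative, not taken from the diff:

# Illustrative resume arithmetic: global_step counts optimizer updates, while the
# dataloader yields raw batches, so the in-epoch skip count must be scaled by
# gradient_accumulation_steps. All numbers here are hypothetical.
gradient_accumulation_steps = 2
batches_per_epoch = 32
num_update_steps_per_epoch = batches_per_epoch // gradient_accumulation_steps  # 16 updates per epoch

global_step = 20  # optimizer updates completed at the checkpoint

epochs_trained = global_step // num_update_steps_per_epoch                    # 1 full epoch
updates_in_current_epoch = global_step % num_update_steps_per_epoch           # 4 updates
batches_to_skip = updates_in_current_epoch * gradient_accumulation_steps      # 8 batches, not 4

print(epochs_trained, batches_to_skip)  # -> 1 8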
parent cdfa56af

Showing 2 changed files with 42 additions and 1 deletion (+42 -1)

src/transformers/trainer.py   +2  -1
tests/test_trainer.py         +40 -0
src/transformers/trainer.py  (view file @ 1e62e999)

@@ -676,11 +676,12 @@ class Trainer:
             self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+            steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
             logger.info("  Continuing training from epoch %d", epochs_trained)
             logger.info("  Continuing training from global step %d", self.state.global_step)
-            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+            logger.info("  Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)

         # Update the references
         self.callback_handler.model = self.model
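For context, here is a hypothetical, self-contained sketch of how such a batch-level skip count might be consumed in a resume loop. This is not the Trainer's actual training loop; only the names steps_trained_in_current_epoch, gradient_accumulation_steps and global_step mirror the diff above, and the toy model, data and optimizer are assumptions.

# Hypothetical resume loop: consume the first N raw batches of the current epoch
# without doing any forward/backward work, then continue training with
# gradient accumulation. Everything except the mirrored variable names is illustrative.
import torch
import torch.nn.functional as F

gradient_accumulation_steps = 2
steps_trained_in_current_epoch = 8  # batches already seen, per the scaled count above

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(16)]  # toy epoch of 16 batches

global_step = 0
for step, (xb, yb) in enumerate(data):
    if steps_trained_in_current_epoch > 0:
        steps_trained_in_current_epoch -= 1  # skip the batch without training on it
        continue

    loss = F.mse_loss(model(xb), yb)
    (loss / gradient_accumulation_steps).backward()

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

print(global_step)  # 4 optimizer updates performed on the remaining 8 batches

Because the fixed skip count is a product with gradient_accumulation_steps, it always covers a whole number of accumulation windows, so the (step + 1) % gradient_accumulation_steps boundary stays aligned after the skipped batches.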
tests/test_trainer.py  (view file @ 1e62e999)

@@ -465,6 +465,14 @@ class TrainerIntegrationTest(unittest.TestCase):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)

+    def test_gradient_accumulation(self):
+        # Training with half the batch size but accumulation steps as 2 should give the same results.
+        trainer = get_regression_trainer(
+            gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
+        )
+        trainer.train()
+        self.check_trained_model(trainer.model)
+
     def test_can_resume_training(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of

@@ -514,6 +522,38 @@ class TrainerIntegrationTest(unittest.TestCase):
             self.assertEqual(b, b1)
             self.assertEqual(state, state1)

+    def test_resume_training_with_gradient_accumulation(self):
+        if torch.cuda.device_count() > 2:
+            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
+            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
+            # won't be the same since the training dataloader is shuffled).
+            return
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                train_len=128,
+                gradient_accumulation_steps=2,
+                per_device_train_batch_size=4,
+                save_steps=5,
+                learning_rate=0.1,
+            )
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            # Reinitialize trainer and load model
+            model = RegressionPreTrainedModel.from_pretrained(checkpoint)
+
+            trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
+            trainer.train(model_path=checkpoint)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.assertEqual(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
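The new test_gradient_accumulation test asserts that halving the batch size while setting gradient_accumulation_steps=2 reproduces the same trained model. Below is a self-contained sketch of that equivalence, assuming a mean-reduction loss and plain SGD; the toy linear model and data are not from the repo.

# Sketch (not repo code): one SGD step on a batch of 8 equals two accumulated
# half-steps of 4 with the loss divided by the accumulation factor, because the
# mean-reduction MSE gradient over 8 samples is the average of the two half-batch
# gradients.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 1)
y = 3.0 * x + 1.0

def train_once(batches, accumulation_steps):
    model = torch.nn.Linear(1, 1)
    torch.nn.init.constant_(model.weight, 0.0)  # identical init for both runs
    torch.nn.init.constant_(model.bias, 0.0)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for step, (xb, yb) in enumerate(batches):
        loss = F.mse_loss(model(xb), yb) / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            opt.step()
            opt.zero_grad()
    return model.weight.item(), model.bias.item()

full = train_once([(x, y)], accumulation_steps=1)
accum = train_once([(x[:4], y[:4]), (x[4:], y[4:])], accumulation_steps=2)
print(full, accum)  # the two parameter pairs should match, up to floating-point error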