chenpangpang / transformers · Commits

Commit c0328a6c (Unverified)
Load checkpoint without re-creating the model (#11318)

Authored Apr 19, 2021 by Sylvain Gugger, committed by GitHub on Apr 19, 2021
parent 95037a16

Showing 3 changed files with 61 additions and 12 deletions (+61 -12):

    src/transformers/configuration_utils.py    +1   -1
    src/transformers/trainer.py                +20  -11
    tests/test_trainer.py                      +40  -0
src/transformers/configuration_utils.py

@@ -271,7 +271,7 @@ class PretrainedConfig(object):
         self._name_or_path = str(kwargs.pop("name_or_path", ""))
 
         # Drop the transformers version info
-        kwargs.pop("transformers_version", None)
+        self.transformers_version = kwargs.pop("transformers_version", None)
 
         # Additional attributes without default values
         for key, value in kwargs.items():
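For context, a brief standalone sketch (not part of the commit) of what this one-line change enables: when a config is loaded from a checkpoint's config.json, the Transformers version recorded at save time is now kept on the config object instead of being silently dropped. The "checkpoint-5/config.json" path below is hypothetical.

import os
from transformers import PretrainedConfig

# Hypothetical checkpoint directory; any config.json written by save_pretrained works.
config_file = os.path.join("checkpoint-5", "config.json")
config = PretrainedConfig.from_json_file(config_file)

# Before this change the attribute did not exist; now it holds the version string
# found in the file, or None if the file has no "transformers_version" entry.
print(config.transformers_version)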
src/transformers/trainer.py

@@ -55,9 +55,12 @@ from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
+from . import __version__
+from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .dependency_versions_check import dep_version_check
 from .file_utils import (
+    CONFIG_NAME,
     WEIGHTS_NAME,
     is_apex_available,
     is_datasets_available,
@@ -999,14 +1002,23 @@ class Trainer:
             logger.info(f"Loading model from {resume_from_checkpoint}).")
 
+            if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
+                config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+                checkpoint_version = config.transformers_version
+                if checkpoint_version is not None and checkpoint_version != __version__:
+                    logger.warn(
+                        f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+                        f"Transformers but your current version is {__version__}. This is not recommended and could "
+                        "yield to errors or unwanted behaviors."
+                    )
+
             if self.deepspeed:
                 # will be resumed in deepspeed_init
                 pass
-            elif isinstance(self.model, PreTrainedModel):
-                self.model = self.model.from_pretrained(resume_from_checkpoint)
-                model_reloaded = True
             else:
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                # If the model is on the GPU, it still works!
                 self.model.load_state_dict(state_dict)
 
             # If model was re-initialized, put it on the right device and update self.model_wrapped
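As a standalone illustration of the version check this hunk adds, here is roughly the same logic outside of Trainer. This is a sketch, not the Trainer code itself: the checkpoint_dir path is an assumption, and CONFIG_NAME is simply "config.json" in this release.

import os
from transformers import PretrainedConfig, __version__
from transformers.file_utils import CONFIG_NAME  # "config.json"

checkpoint_dir = "checkpoint-5"  # assumed checkpoint directory
config_file = os.path.join(checkpoint_dir, CONFIG_NAME)

if os.path.isfile(config_file):
    # Read the version the checkpoint was saved with and compare it to the installed one.
    checkpoint_version = PretrainedConfig.from_json_file(config_file).transformers_version
    if checkpoint_version is not None and checkpoint_version != __version__:
        print(
            f"Checkpoint was saved with Transformers {checkpoint_version}, "
            f"but {__version__} is installed; behavior may differ."
        )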
@@ -1293,12 +1305,9 @@ class Trainer:
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
-            if isinstance(self.model, PreTrainedModel):
-                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if self.place_model_on_device:
-                    self.model = self.model.to(args.device)
-            else:
-                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
-                self.model.load_state_dict(state_dict)
+            # We load the model state dict on the CPU to avoid an OOM error.
+            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
+            # If the model is on the GPU, it still works!
+            self.model.load_state_dict(state_dict)
 
             if self.deepspeed:
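The loading pattern used in both trainer.py hunks can be reproduced in isolation. A minimal sketch with a toy model ("checkpoint.bin" is an assumed local file, not a Trainer artifact): torch.load with map_location="cpu" keeps the deserialized tensors in host memory, and load_state_dict then copies the values into the model's existing parameters on whatever device they already occupy, so a GPU-resident model never needs a second GPU copy of its weights.

import torch

model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "checkpoint.bin")  # assumed local checkpoint file

if torch.cuda.is_available():
    model = model.to("cuda")  # the live model may already sit on the GPU

# Deserialize onto the CPU so no extra GPU memory is allocated for the loaded copy.
state_dict = torch.load("checkpoint.bin", map_location="cpu")

# Values are copied in place into the existing (possibly CUDA) parameters.
model.load_state_dict(state_dict)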
tests/test_trainer.py

@@ -725,6 +725,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)
 
+    def test_resume_training_with_frozen_params(self):
+        if torch.cuda.device_count() > 2:
+            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
+            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
+            # won't be the same since the training dataloader is shuffled).
+            return
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                train_len=128,
+                per_device_train_batch_size=4,
+                save_steps=5,
+                learning_rate=0.1,
+            )
+            trainer.model.a.requires_grad_(False)
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            # Reinitialize trainer
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                train_len=128,
+                per_device_train_batch_size=4,
+                save_steps=5,
+                learning_rate=0.1,
+            )
+            trainer.model.a.requires_grad_(False)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            self.assertFalse(trainer.model.a.requires_grad)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
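The new test exercises the point of the commit: because resuming now copies weights into the existing model instead of re-creating it with from_pretrained, any requires_grad flags set on the live model survive the resume. A toy sketch of that distinction (plain torch.nn.Linear and an assumed "checkpoint.bin" file, not the regression model used by the test):

import torch

model = torch.nn.Linear(4, 2)
model.weight.requires_grad_(False)                 # freeze a parameter
torch.save(model.state_dict(), "checkpoint.bin")   # assumed local file

# Loading into the existing module only copies tensor values, so the flag is kept.
model.load_state_dict(torch.load("checkpoint.bin", map_location="cpu"))
print(model.weight.requires_grad)                  # False

# Re-creating the module from scratch resets requires_grad to its default.
fresh = torch.nn.Linear(4, 2)
fresh.load_state_dict(torch.load("checkpoint.bin", map_location="cpu"))
print(fresh.weight.requires_grad)                  # True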