Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5ae935d2
Unverified
Commit
5ae935d2
authored
Oct 22, 2020
by
Sylvain Gugger
Committed by
GitHub
Oct 22, 2020
Browse files
Reload checkpoint (#7984)
* Fix checkpoint loading in Trainer * Fix typo
parent
467573dd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
17 deletions
+35
-17
src/transformers/trainer.py
src/transformers/trainer.py
+29
-12
src/transformers/trainer_callback.py
src/transformers/trainer_callback.py
+3
-1
src/transformers/trainer_pt_utils.py
src/transformers/trainer_pt_utils.py
+3
-4
No files found.
src/transformers/trainer.py
View file @
5ae935d2
...
...
@@ -628,18 +628,7 @@ class Trainer:
self
.
state
.
is_hyper_param_search
=
trial
is
not
None
# Check if saved optimizer or scheduler states exist
if
(
model_path
is
not
None
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
))
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
))
):
# Load in optimizer and scheduler states
self
.
optimizer
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
),
map_location
=
self
.
args
.
device
)
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
self
.
lr_scheduler
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
)))
reissue_pt_warnings
(
caught_warnings
)
self
.
_load_optimizer_and_scheduler
(
model_path
)
# Mixed precision training with apex (torch < 1.6)
model
=
self
.
model
...
...
@@ -919,6 +908,34 @@ class Trainer:
if
self
.
is_world_process_zero
():
self
.
_rotate_checkpoints
(
use_mtime
=
True
)
def
_load_optimizer_and_scheduler
(
self
,
model_path
):
"""If optimizer and scheduler states exist, load them."""
if
(
model_path
is
not
None
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
))
and
os
.
path
.
isfile
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
))
):
# Load in optimizer and scheduler states
if
is_torch_tpu_available
():
# On TPU we have to take some extra precautions to properly load the states on the right device.
optimizer_state
=
torch
.
load
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
),
map_location
=
"cpu"
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
lr_scheduler_state
=
torch
.
load
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
),
map_location
=
"cpu"
)
reissue_pt_warnings
(
caught_warnings
)
xm
.
send_cpu_data_to_device
(
optimizer_state
,
self
.
args
.
device
)
xm
.
send_cpu_data_to_device
(
lr_scheduler_state
,
self
.
args
.
device
)
self
.
optimizer
.
load_state_dict
(
optimizer_state
)
self
.
lr_scheduler
.
load_state_dict
(
lr_scheduler_state
)
else
:
self
.
optimizer
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
model_path
,
"optimizer.pt"
),
map_location
=
self
.
args
.
device
)
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
caught_warnings
:
self
.
lr_scheduler
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
model_path
,
"scheduler.pt"
)))
reissue_pt_warnings
(
caught_warnings
)
def
hyperparameter_search
(
self
,
hp_space
:
Optional
[
Callable
[[
"optuna.Trial"
],
Dict
[
str
,
float
]]]
=
None
,
...
...
src/transformers/trainer_callback.py
View file @
5ae935d2
...
...
@@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback):
def
on_train_begin
(
self
,
args
,
state
,
control
,
**
kwargs
):
if
state
.
is_local_process_zero
:
self
.
training_bar
=
tqdm
(
total
=
state
.
max_steps
)
self
.
current_step
=
0
def
on_step_end
(
self
,
args
,
state
,
control
,
**
kwargs
):
if
state
.
is_local_process_zero
:
self
.
training_bar
.
update
(
1
)
self
.
training_bar
.
update
(
state
.
global_step
-
self
.
current_step
)
self
.
current_step
=
state
.
global_step
def
on_prediction_step
(
self
,
args
,
state
,
control
,
eval_dataloader
=
None
,
**
kwargs
):
if
state
.
is_local_process_zero
:
...
...
src/transformers/trainer_pt_utils.py
View file @
5ae935d2
...
...
@@ -23,6 +23,7 @@ from typing import List, Optional, Union
import
numpy
as
np
import
torch
from
torch.optim.lr_scheduler
import
SAVE_STATE_WARNING
from
torch.utils.data.distributed
import
DistributedSampler
from
torch.utils.data.sampler
import
RandomSampler
,
Sampler
...
...
@@ -33,8 +34,6 @@ from .utils import logging
if
is_torch_tpu_available
():
import
torch_xla.core.xla_model
as
xm
PT_LR_SCHEDULER_WARNING
=
"Please also save or load the state of the optimzer when saving or loading the scheduler."
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -112,10 +111,10 @@ def distributed_broadcast_scalars(
def
reissue_pt_warnings
(
caught_warnings
):
# Reissue warnings that are not the
PT_LR_SCHEDULER
_WARNING
# Reissue warnings that are not the
SAVE_STATE
_WARNING
if
len
(
caught_warnings
)
>
1
:
for
w
in
caught_warnings
:
if
w
.
category
!=
UserWarning
or
w
.
message
!=
PT_LR_SCHEDULER
_WARNING
:
if
w
.
category
!=
UserWarning
or
w
.
message
!=
SAVE_STATE
_WARNING
:
warnings
.
warn
(
w
.
message
,
w
.
category
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment