chenpangpang / transformers · Commit 5ae935d2 (unverified)
Authored Oct 22, 2020 by Sylvain Gugger, committed via GitHub on Oct 22, 2020
Parent: 467573dd

Reload checkpoint (#7984)

* Fix checkpoint loading in Trainer
* Fix typo
Showing 3 changed files with 35 additions and 17 deletions:

  src/transformers/trainer.py           +29  -12
  src/transformers/trainer_callback.py   +3   -1
  src/transformers/trainer_pt_utils.py    +3   -4
src/transformers/trainer.py

@@ -628,18 +628,7 @@ class Trainer:
         self.state.is_hyper_param_search = trial is not None
 
         # Check if saved optimizer or scheduler states exist
-        if (
-            model_path is not None
-            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
-            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
-        ):
-            # Load in optimizer and scheduler states
-            self.optimizer.load_state_dict(
-                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
-            )
-            with warnings.catch_warnings(record=True) as caught_warnings:
-                self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
-            reissue_pt_warnings(caught_warnings)
+        self._load_optimizer_and_scheduler(model_path)
 
         # Mixed precision training with apex (torch < 1.6)
         model = self.model

@@ -919,6 +908,34 @@ class Trainer:
         if self.is_world_process_zero():
             self._rotate_checkpoints(use_mtime=True)
 
+    def _load_optimizer_and_scheduler(self, model_path):
+        """If optimizer and scheduler states exist, load them."""
+        if (
+            model_path is not None
+            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
+            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
+        ):
+            # Load in optimizer and scheduler states
+            if is_torch_tpu_available():
+                # On TPU we have to take some extra precautions to properly load the states on the right device.
+                optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
+                reissue_pt_warnings(caught_warnings)
+
+                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+                self.optimizer.load_state_dict(optimizer_state)
+                self.lr_scheduler.load_state_dict(lr_scheduler_state)
+            else:
+                self.optimizer.load_state_dict(
+                    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
+                )
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+                reissue_pt_warnings(caught_warnings)
+
     def hyperparameter_search(
         self,
         hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
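For context, a minimal sketch (not part of this commit) of how the new _load_optimizer_and_scheduler path gets exercised: in this era of the Trainer API, passing model_path to train() points it at a previously saved checkpoint directory, and optimizer.pt / scheduler.pt found there are restored before training continues. The model, dataset, and directory name below are placeholders.

    # Hedged sketch: assumes `model` and `train_dataset` are already built,
    # and that "output/checkpoint-500" was produced by an earlier run
    # (it must contain optimizer.pt and scheduler.pt for the reload to happen).
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="output")
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

    # With this commit, optimizer and scheduler state are reloaded from the
    # checkpoint directory via _load_optimizer_and_scheduler before training resumes.
    trainer.train(model_path="output/checkpoint-500")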
src/transformers/trainer_callback.py

@@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback):
     def on_train_begin(self, args, state, control, **kwargs):
         if state.is_local_process_zero:
             self.training_bar = tqdm(total=state.max_steps)
+        self.current_step = 0
 
     def on_step_end(self, args, state, control, **kwargs):
         if state.is_local_process_zero:
-            self.training_bar.update(1)
+            self.training_bar.update(state.global_step - self.current_step)
+            self.current_step = state.global_step
 
     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
         if state.is_local_process_zero:
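The on_step_end change matters when training does not start at step 0, for example after resuming from a checkpoint: updating the bar by state.global_step - self.current_step instead of a fixed 1 keeps it in sync with the real step counter. A standalone sketch of that delta-update pattern, with made-up step numbers and no Trainer involved:

    # Illustrative only: plain tqdm, mimicking the callback's bookkeeping.
    from tqdm.auto import tqdm

    max_steps = 1000
    bar = tqdm(total=max_steps)
    current_step = 0

    # Pretend training resumed and the first step the callback sees is 401.
    for global_step in (401, 402, 403):
        bar.update(global_step - current_step)  # first call advances the bar by 401
        current_step = global_step

    bar.close()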
src/transformers/trainer_pt_utils.py

@@ -23,6 +23,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch
+from torch.optim.lr_scheduler import SAVE_STATE_WARNING
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

@@ -33,8 +34,6 @@ from .utils import logging
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
 
-PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
-
 logger = logging.get_logger(__name__)

@@ -112,10 +111,10 @@ def distributed_broadcast_scalars(
 def reissue_pt_warnings(caught_warnings):
-    # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
+    # Reissue warnings that are not the SAVE_STATE_WARNING
     if len(caught_warnings) > 1:
         for w in caught_warnings:
-            if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
+            if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
                 warnings.warn(w.message, w.category)
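A minimal sketch of how reissue_pt_warnings is meant to be used around scheduler (de)serialization: the calls are wrapped in warnings.catch_warnings(record=True), and the helper swallows PyTorch's SAVE_STATE_WARNING while re-emitting anything else. This assumes a PyTorch version from this era (roughly 1.4 to 1.7) that still exports SAVE_STATE_WARNING from torch.optim.lr_scheduler; the model, optimizer, and scheduler below are placeholders, not taken from the commit.

    import warnings

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    from transformers.trainer_pt_utils import reissue_pt_warnings

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = LambdaLR(optimizer, lambda step: 1.0)

    # Saving/reloading the scheduler on its own can trigger SAVE_STATE_WARNING
    # on these PyTorch versions; record warnings instead of printing them ...
    with warnings.catch_warnings(record=True) as caught_warnings:
        state = scheduler.state_dict()
        scheduler.load_state_dict(state)
    # ... then re-emit everything except that expected warning.
    reissue_pt_warnings(caught_warnings)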