Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
10bf4074
Commit
10bf4074
authored
Dec 04, 2017
by
Myle Ott
Browse files
Rebuild optimizer when loading checkpoints
parent
9f3ccaa6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
13 deletions
+19
-13
fairseq/multiprocessing_trainer.py
fairseq/multiprocessing_trainer.py
+16
-3
fairseq/utils.py
fairseq/utils.py
+3
-10
No files found.
fairseq/multiprocessing_trainer.py
View file @
10bf4074
...
@@ -157,11 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
...
@@ -157,11 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
return
extra_state
return
extra_state
def
_async_load_checkpoint
(
self
,
rank
,
device_id
,
filename
):
def
_async_load_checkpoint
(
self
,
rank
,
device_id
,
filename
):
extra_state
,
self
.
_optim_history
=
utils
.
load_state
(
extra_state
,
self
.
_optim_history
,
last_optim_state
=
utils
.
load_model_state
(
filename
,
self
.
model
,
self
.
criterion
,
self
.
optimizer
,
filename
,
self
.
model
,
cuda_device
=
device_id
)
self
.
lr_scheduler
,
cuda_device
=
device_id
)
if
last_optim_state
is
not
None
:
# rebuild optimizer after loading model, since params may have changed
self
.
optimizer
=
self
.
_build_optimizer
()
self
.
lr_scheduler
=
self
.
_build_lr_scheduler
()
# only load optimizer and lr_scheduler if they match the checkpoint
last_optim
=
self
.
_optim_history
[
-
1
]
if
last_optim
[
'criterion_name'
]
==
self
.
criterion
.
__class__
.
__name__
:
self
.
optimizer
.
load_state_dict
(
last_optim_state
)
self
.
lr_scheduler
.
best
=
last_optim
[
'best_loss'
]
# override learning rate, momentum, etc. with latest values
for
group
in
self
.
optimizer
.
param_groups
:
for
group
in
self
.
optimizer
.
param_groups
:
group
.
update
(
self
.
_override_optim_state
)
group
.
update
(
self
.
_override_optim_state
)
return
extra_state
return
extra_state
def
set_seed
(
self
,
seed
):
def
set_seed
(
self
,
seed
):
...
...
fairseq/utils.py
View file @
10bf4074
...
@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
...
@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
torch_persistent_save
(
state_dict
,
filename
)
torch_persistent_save
(
state_dict
,
filename
)
def
load_state
(
filename
,
model
,
criterion
,
optimizer
,
lr_scheduler
,
cuda_device
=
None
):
def
load_
model_
state
(
filename
,
model
,
cuda_device
=
None
):
if
not
os
.
path
.
exists
(
filename
):
if
not
os
.
path
.
exists
(
filename
):
return
None
,
[]
return
None
,
[]
,
None
if
cuda_device
is
None
:
if
cuda_device
is
None
:
state
=
torch
.
load
(
filename
)
state
=
torch
.
load
(
filename
)
else
:
else
:
...
@@ -103,14 +103,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
...
@@ -103,14 +103,7 @@ def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=
raise
Exception
(
'Cannot load model parameters from checkpoint, '
raise
Exception
(
'Cannot load model parameters from checkpoint, '
'please ensure that the architectures match'
)
'please ensure that the architectures match'
)
# only load optimizer and lr_scheduler if they match with the checkpoint
return
state
[
'extra_state'
],
state
[
'optimizer_history'
],
state
[
'last_optimizer_state'
]
optim_history
=
state
[
'optimizer_history'
]
last_optim
=
optim_history
[
-
1
]
if
last_optim
[
'criterion_name'
]
==
criterion
.
__class__
.
__name__
:
optimizer
.
load_state_dict
(
state
[
'last_optimizer_state'
])
lr_scheduler
.
best
=
last_optim
[
'best_loss'
]
return
state
[
'extra_state'
],
optim_history
def
_upgrade_state_dict
(
state
):
def
_upgrade_state_dict
(
state
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment