Commit bd46c5ec authored Nov 21, 2017 by Myle Ott

Prefer command-line configuration over checkpoint for optimizer state

parent 19fafae6
Showing 1 changed file with 31 additions and 11 deletions.
fairseq/multiprocessing_trainer.py (+31, -11)
```diff
@@ -77,21 +77,39 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._max_bsz_seen = 0
 
     def _build_optimizer(self):
+        # When resuming training from a checkpoint, we load the old optimizer
+        # state that includes things like learning rate, momentum factor, etc.
+        # We use this dictionary to override values stored in the checkpoint,
+        # e.g., we might prefer the values specified on the command line.
+        self._override_optim_state = {}
+
         if self.args.optimizer == 'adagrad':
-            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr[0],
-                                        weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.Adagrad(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'adam':
-            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr[0],
-                                    betas=eval(self.args.adam_betas),
-                                    weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'betas': eval(self.args.adam_betas),
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'nag':
-            return NAG(self.model.parameters(), lr=self.args.lr[0],
-                       momentum=self.args.momentum,
-                       weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'momentum': self.args.momentum,
+                'weight_decay': self.args.weight_decay,
+            }
+            return NAG(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'sgd':
-            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr[0],
-                                   momentum=self.args.momentum,
-                                   weight_decay=self.args.weight_decay)
+            self._override_optim_state = {
+                'lr': self.args.lr[0],
+                'momentum': self.args.momentum,
+                'weight_decay': self.args.weight_decay,
+            }
+            return torch.optim.SGD(self.model.parameters(), **self._override_optim_state)
         else:
             raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
```
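The change above replaces per-optimizer keyword arguments with a single override dict that is expanded into the constructor with `**`. Keeping that dict around is what lets the same command-line values be re-applied after a checkpoint is loaded, as the second hunk below does. A minimal, self-contained sketch of the pattern (the `Args` class and its values are illustrative stand-ins, not part of the commit):

```python
import torch

# Hypothetical stand-ins for the values fairseq parses from the command line.
class Args:
    lr = [0.25]
    momentum = 0.99
    weight_decay = 0.0

args = Args()
model = torch.nn.Linear(10, 10)

# Build the override dict once, then reuse it both to construct the optimizer
# and (later) to override any hyperparameters restored from a checkpoint.
override_optim_state = {
    'lr': args.lr[0],
    'momentum': args.momentum,
    'weight_decay': args.weight_decay,
}
optimizer = torch.optim.SGD(model.parameters(), **override_optim_state)
```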
```diff
@@ -142,6 +160,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         extra_state, self._optim_history = utils.load_state(
             filename, self.model, self.criterion, self.optimizer,
             self.lr_scheduler, cuda_device=device_id)
+        for group in self.optimizer.param_groups:
+            group.update(self._override_optim_state)
         return extra_state
 
     def set_seed(self, seed):
```
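The two added lines are the core of the change: after `utils.load_state` restores the optimizer state from the checkpoint, every `param_group` is updated in place so the command-line hyperparameters take precedence over the checkpointed ones. A self-contained sketch of the same idea in plain PyTorch (the in-memory checkpoint round-trip below stands in for fairseq's `utils.load_state`, which is not reproduced here):

```python
import torch

model = torch.nn.Linear(4, 4)

# An optimizer that was saved into a "checkpoint" with an old learning rate.
old_optimizer = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9)
checkpoint_state = old_optimizer.state_dict()

# New run: suppose the user passed --lr 0.1 on the command line.
override_optim_state = {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0}
optimizer = torch.optim.SGD(model.parameters(), **override_optim_state)

# Restoring the checkpoint clobbers the command-line hyperparameters...
optimizer.load_state_dict(checkpoint_state)
assert optimizer.param_groups[0]['lr'] == 0.5

# ...so re-apply the overrides to every param_group, as the commit does
# right after utils.load_state returns.
for group in optimizer.param_groups:
    group.update(override_optim_state)
assert optimizer.param_groups[0]['lr'] == 0.1
```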