OpenDAS / Fairseq, commit eea50f38

Authored Oct 12, 2017 by Myle Ott
Parent: 3f970086

    Refactor model saving/loading to be more reusable

Showing 3 changed files with 66 additions and 46 deletions:

    fairseq/multiprocessing_trainer.py   +11  -12
    fairseq/utils.py                     +19  -29
    train.py                             +36   -5
fairseq/multiprocessing_trainer.py

@@ -100,14 +100,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _async_get_model(self, rank, device_id):
         return self.model

-    def save_checkpoint(self, args, epoch, batch_offset, val_loss=None):
+    def save_checkpoint(self, filename, extra_state):
         """Save a checkpoint for the current model."""
-        self.call_async(0, '_async_save_checkpoint', args=args, epoch=epoch,
-                        batch_offset=batch_offset, val_loss=val_loss).gen()
+        self.call_async(0, '_async_save_checkpoint', filename=filename,
+                        extra_state=extra_state).gen()

-    def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
-                              self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
+    def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
+        utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
+                         self.lr_scheduler, self._optim_history, extra_state)

     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""

@@ -115,14 +114,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.call_async(rank, '_async_load_checkpoint', filename=filename)
             for rank in range(self.num_replicas)
         ])
-        epoch, batch_offset = results[0]
-        return epoch, batch_offset
+        extra_state = results[0]
+        return extra_state

     def _async_load_checkpoint(self, rank, device_id, filename):
-        epoch, batch_offset, self._optim_history = utils.load_checkpoint(
-            filename, self.model, self.criterion, self.optimizer,
-            self.lr_scheduler, cuda_device=device_id)
-        return epoch, batch_offset
+        extra_state, self._optim_history = utils.load_state(
+            filename, self.model, self.criterion, self.optimizer,
+            self.lr_scheduler, cuda_device=device_id)
+        return extra_state

     def train_step(self, samples):
         """Do forward, backward and gradient step in parallel."""
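After this change the trainer no longer decides checkpoint names or tracks epochs; it only serializes to a caller-supplied path and hands the opaque extra_state dict back on load. A minimal sketch of that caller-side contract, using a stand-in DummyTrainer rather than the real MultiprocessingTrainer (which needs GPUs and the rest of fairseq); everything below is illustrative, not fairseq code:

import os
import torch


class DummyTrainer:
    """Stand-in for MultiprocessingTrainer, reduced to the new two-method
    checkpoint contract (illustration only)."""

    def __init__(self):
        self.model = torch.nn.Linear(4, 4)

    def save_checkpoint(self, filename, extra_state):
        # The real trainer forwards to utils.save_state() on replica 0 via call_async.
        torch.save({'model': self.model.state_dict(),
                    'extra_state': extra_state}, filename)

    def load_checkpoint(self, filename):
        # Returns the extra_state dict, or None when there is nothing to resume.
        if not os.path.exists(filename):
            return None
        state = torch.load(filename)
        self.model.load_state_dict(state['model'])
        return state['extra_state']


trainer = DummyTrainer()
trainer.save_checkpoint('/tmp/checkpoint_last.pt', {'epoch': 3, 'batch_offset': 0})
print(trainer.load_checkpoint('/tmp/checkpoint_last.pt'))
# -> {'epoch': 3, 'batch_offset': 0}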
fairseq/utils.py

@@ -46,16 +46,14 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())

-def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer,
-                    lr_scheduler, val_loss=None, optim_history=None):
+def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
+               optim_history=None, extra_state=None):
     if optim_history is None:
         optim_history = []
+    if extra_state is None:
+        extra_state = {}
     state_dict = {
         'args': args,
-        'epoch': epoch,
-        'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'val_loss': val_loss,
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,

@@ -63,26 +61,14 @@ def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_s
                 'best_loss': lr_scheduler.best,
             }
         ],
+        'extra_state': extra_state,
     }
+    torch_persistent_save(state_dict, filename)
-
-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            torch_persistent_save(state_dict, epoch_filename)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            torch_persistent_save(state_dict, best_filename)
-
-    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
-    torch_persistent_save(state_dict, last_filename)

-def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
+def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0, []
+        return None, []

     if cuda_device is None:
         state = torch.load(filename)
     else:

@@ -92,23 +78,17 @@ def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_de
         )
     state = _upgrade_state_dict(state)

     # load model parameters
     model.load_state_dict(state['model'])
-    epoch = state['epoch'] + 1
-    batch_offset = state['batch_offset']

     # only load optimizer and lr_scheduler if they match with the checkpoint
-    opt_str = ''
     optim_history = state['optimizer_history']
     last_optim = optim_history[-1]
     if last_optim['criterion_name'] == criterion.__class__.__name__:
         optimizer.load_state_dict(last_optim['optimizer'])
         lr_scheduler.best = last_optim['best_loss']
-        opt_str = '; criterion: {}'.format(last_optim['criterion_name'])

-    gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))
-
-    return epoch, batch_offset, optim_history
+    return state['extra_state'], optim_history

 def _upgrade_state_dict(state):

@@ -124,6 +104,16 @@ def _upgrade_state_dict(state):
         ]
         del state['optimizer']
         del state['best_loss']
+    # move extra_state into sub-dictionary
+    if 'epoch' in state and 'extra_state' not in state:
+        state['extra_state'] = {
+            'epoch': state['epoch'],
+            'batch_offset': state['batch_offset'],
+            'val_loss': state['val_loss'],
+        }
+        del state['epoch']
+        del state['batch_offset']
+        del state['val_loss']
     return state
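For reference, save_state() now writes a single dict whose trainer-specific bookkeeping lives under 'extra_state', and the new branch in _upgrade_state_dict() migrates old checkpoints into that layout on load. A rough, self-contained illustration of the layout and the migration, using plain torch with made-up values (the criterion name and losses are placeholders, and upgrade() only mirrors the migration logic shown above rather than calling fairseq):

import torch

# New-style checkpoint layout written by save_state(): model weights and the
# optimizer history are fixed top-level keys; all trainer-agnostic bookkeeping
# rides along under 'extra_state'.
model = torch.nn.Linear(8, 2)
new_style = {
    'args': None,  # an argparse.Namespace in real fairseq checkpoints
    'model': model.state_dict(),
    'optimizer_history': [{'criterion_name': 'CrossEntropyCriterion', 'best_loss': 4.2}],
    'extra_state': {'epoch': 5, 'batch_offset': 0, 'val_loss': 4.2},
}
torch.save(new_style, '/tmp/example_checkpoint.pt')

# Old-style checkpoints kept epoch/batch_offset/val_loss at the top level; the
# migration folds them into 'extra_state' so load_state() can return one dict.
def upgrade(state):
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state.pop('epoch'),
            'batch_offset': state.pop('batch_offset'),
            'val_loss': state.pop('val_loss'),
        }
    return state

old_style = {'model': model.state_dict(), 'optimizer_history': [],
             'epoch': 5, 'batch_offset': 0, 'val_loss': 4.2}
print(upgrade(old_style)['extra_state'])
# -> {'epoch': 5, 'batch_offset': 0, 'val_loss': 4.2}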
train.py

@@ -62,16 +62,25 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))

-    # Build model
-    print('| model {}'.format(args.arch))
+    # Build model and criterion
     model = utils.build_model(args, dataset)
     criterion = utils.build_criterion(args, dataset)
+    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)

     # Load the latest checkpoint if one is available
-    epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
+    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
+    extra_state = trainer.load_checkpoint(checkpoint_path)
+    if extra_state is not None:
+        epoch = extra_state['epoch']
+        batch_offset = extra_state['batch_offset']
+        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+        if batch_offset == 0:
+            epoch += 1
+    else:
+        epoch, batch_offset = 1, 0

     # Train until the learning rate gets too small
     val_loss = None

@@ -89,7 +98,7 @@ def main():
         if k == 0:
             if not args.no_save:
                 # save checkpoint
-                trainer.save_checkpoint(args, epoch, 0, val_loss)
+                save_checkpoint(trainer, args, epoch, 0, val_loss)

         # only use first validation loss to update the learning schedule
         lr = trainer.lr_step(val_loss, epoch)

@@ -151,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             # ignore the first mini-batch in words-per-second calculation
             wps_meter.reset()
         if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
-            trainer.save_checkpoint(args, epoch, i + 1)
+            save_checkpoint(trainer, args, epoch, i + 1)

     fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(loss_meter.avg, math.pow(2, loss_meter.avg))

@@ -166,6 +175,28 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             t.write(fmt)

+def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
+    extra_state = {
+        'epoch': epoch,
+        'batch_offset': batch_offset,
+        'val_loss': val_loss,
+    }
+
+    if batch_offset == 0:
+        if not args.no_epoch_checkpoints:
+            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
+            trainer.save_checkpoint(epoch_filename, extra_state)
+
+        assert val_loss is not None
+        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+            save_checkpoint.best = val_loss
+            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+            trainer.save_checkpoint(best_filename, extra_state)
+
+    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
+    trainer.save_checkpoint(last_filename, extra_state)

 def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
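The checkpoint-naming policy that used to live in utils.save_checkpoint() (per-epoch files, a best file gated on validation loss, and an always-rewritten last file) now lives in this new train.py helper, which delegates the actual writing to the trainer. The "best so far" value is memoized as an attribute on the function object itself; below is a small sketch of just that pattern, with a hypothetical write function standing in for trainer.save_checkpoint():

import os

def save_best_and_last(write_fn, save_dir, val_loss, state):
    """Illustrative helper (not fairseq code): memoize the best validation loss
    on the function object, as train.py's save_checkpoint() does via
    save_checkpoint.best."""
    if not hasattr(save_best_and_last, 'best') or val_loss < save_best_and_last.best:
        save_best_and_last.best = val_loss
        write_fn(os.path.join(save_dir, 'checkpoint_best.pt'), state)
    # checkpoint_last.pt is rewritten on every call.
    write_fn(os.path.join(save_dir, 'checkpoint_last.pt'), state)

written = []
save_best_and_last(lambda path, s: written.append(path), '/tmp', 4.0, {})
save_best_and_last(lambda path, s: written.append(path), '/tmp', 5.0, {})  # worse loss: no new 'best'
print(written)
# -> ['/tmp/checkpoint_best.pt', '/tmp/checkpoint_last.pt', '/tmp/checkpoint_last.pt']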