OpenDAS / Fairseq · Commits

Commit eea50f38, authored Oct 12, 2017 by Myle Ott
Refactor model saving/loading to be more reusable
Parent: 3f970086

Showing 3 changed files with 66 additions and 46 deletions (+66 −46)
fairseq/multiprocessing_trainer.py   +11 −12
fairseq/utils.py                     +19 −29
train.py                             +36 −5
fairseq/multiprocessing_trainer.py

...
@@ -100,14 +100,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _async_get_model(self, rank, device_id):
         return self.model

-    def save_checkpoint(self, args, epoch, batch_offset, val_loss=None):
+    def save_checkpoint(self, filename, extra_state):
         """Save a checkpoint for the current model."""
-        self.call_async(0, '_async_save_checkpoint', args=args, epoch=epoch,
-                        batch_offset=batch_offset, val_loss=val_loss).gen()
+        self.call_async(0, '_async_save_checkpoint', filename=filename, extra_state=extra_state).gen()

-    def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
-                              self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
+    def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
+        utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
+                         self.lr_scheduler, self._optim_history, extra_state)

     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""
...
@@ -115,14 +114,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.call_async(rank, '_async_load_checkpoint', filename=filename)
             for rank in range(self.num_replicas)
         ])
-        epoch, batch_offset = results[0]
-        return epoch, batch_offset
+        extra_state = results[0]
+        return extra_state

     def _async_load_checkpoint(self, rank, device_id, filename):
-        epoch, batch_offset, self._optim_history = utils.load_checkpoint(
-            filename, self.model, self.criterion, self.optimizer, self.lr_scheduler,
-            cuda_device=device_id)
-        return epoch, batch_offset
+        extra_state, self._optim_history = utils.load_state(
+            filename, self.model, self.criterion, self.optimizer,
+            self.lr_scheduler, cuda_device=device_id)
+        return extra_state

     def train_step(self, samples):
         """Do forward, backward and gradient step in parallel."""
...
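After this change the trainer no longer decides checkpoint filenames or packs epoch, batch_offset, and val_loss itself; the caller passes a filename plus an opaque extra_state dict and gets the same dict back on load. A minimal sketch of the new trainer-facing API follows; the paths and values are placeholders, and the trainer is assumed to have been built as in train.py below, so this is illustrative rather than part of the commit:

    # Illustrative only: assumes a MultiprocessingTrainer constructed as in train.py.
    extra_state = {'epoch': 3, 'batch_offset': 0, 'val_loss': 4.21}

    # Save: the caller chooses the path; the trainer serializes model/optimizer
    # state via utils.save_state and attaches extra_state verbatim.
    trainer.save_checkpoint('checkpoints/checkpoint_last.pt', extra_state)

    # Load: returns whatever extra_state was stored, or None if utils.load_state
    # found no file, instead of the old (epoch, batch_offset) tuple.
    extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt')
    if extra_state is not None:
        epoch = extra_state['epoch']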
fairseq/utils.py

...
@@ -46,16 +46,14 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())


-def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_scheduler,
-                    val_loss=None, optim_history=None):
+def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_history=None, extra_state=None):
     if optim_history is None:
         optim_history = []
+    if extra_state is None:
+        extra_state = {}
     state_dict = {
         'args': args,
-        'epoch': epoch,
-        'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'val_loss': val_loss,
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
...
@@ -63,26 +61,14 @@ def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_s
                 'best_loss': lr_scheduler.best,
             }
         ],
+        'extra_state': extra_state,
     }
-
-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            torch_persistent_save(state_dict, epoch_filename)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            torch_persistent_save(state_dict, best_filename)
-
-    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
-    torch_persistent_save(state_dict, last_filename)
+    torch_persistent_save(state_dict, filename)


-def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
+def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0, []
+        return None, []
     if cuda_device is None:
         state = torch.load(filename)
     else:
...
@@ -92,23 +78,17 @@ def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_de
         )
     state = _upgrade_state_dict(state)

     # load model parameters
     model.load_state_dict(state['model'])
-    epoch = state['epoch'] + 1
-    batch_offset = state['batch_offset']

     # only load optimizer and lr_scheduler if they match with the checkpoint
-    opt_str = ''
     optim_history = state['optimizer_history']
     last_optim = optim_history[-1]
     if last_optim['criterion_name'] == criterion.__class__.__name__:
         optimizer.load_state_dict(last_optim['optimizer'])
         lr_scheduler.best = last_optim['best_loss']
-        opt_str = '; criterion: {}'.format(last_optim['criterion_name'])
-
-    gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))

-    return epoch, batch_offset, optim_history
+    return state['extra_state'], optim_history


 def _upgrade_state_dict(state):
...
@@ -124,6 +104,16 @@ def _upgrade_state_dict(state):
         ]
         del state['optimizer']
         del state['best_loss']
+    # move extra_state into sub-dictionary
+    if 'epoch' in state and 'extra_state' not in state:
+        state['extra_state'] = {
+            'epoch': state['epoch'],
+            'batch_offset': state['batch_offset'],
+            'val_loss': state['val_loss'],
+        }
+        del state['epoch']
+        del state['batch_offset']
+        del state['val_loss']
     return state
...
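With these changes, utils.save_state writes exactly one file per call and keeps all trainer-agnostic metadata (epoch, batch_offset, val_loss) under a single 'extra_state' key, while _upgrade_state_dict folds the old top-level fields into that key when an older checkpoint is read. A rough sketch of inspecting the new on-disk layout offline, outside of training; the checkpoint path is hypothetical and the commented values are examples, not output from this commit:

    # Sketch only: read a checkpoint written by the refactored save_state.
    import torch

    # map_location keeps tensors on CPU so no GPU is needed for inspection.
    state = torch.load('checkpoints/checkpoint_last.pt',
                       map_location=lambda storage, loc: storage)

    # Top-level keys after this commit: 'args', 'model', 'optimizer_history', 'extra_state'.
    print(sorted(state.keys()))
    print(state['extra_state'])                              # e.g. {'epoch': 3, 'batch_offset': 0, 'val_loss': 4.21}
    print(state['optimizer_history'][-1]['criterion_name'])  # criterion class used for the last optimizer state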
train.py

...
@@ -62,16 +62,25 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))

-    # Build model
-    print('| model {}'.format(args.arch))
+    # Build model and criterion
     model = utils.build_model(args, dataset)
     criterion = utils.build_criterion(args, dataset)
+    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)

     # Load the latest checkpoint if one is available
-    epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
+    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
+    extra_state = trainer.load_checkpoint(checkpoint_path)
+    if extra_state is not None:
+        epoch = extra_state['epoch']
+        batch_offset = extra_state['batch_offset']
+        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+        if batch_offset == 0:
+            epoch += 1
+    else:
+        epoch, batch_offset = 1, 0

     # Train until the learning rate gets too small
     val_loss = None
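The resume decision that used to live inside utils.load_checkpoint (start at epoch 1, offset 0 when no file exists, and bump the epoch when the saved offset shows it finished) now sits in main(), driven entirely by the returned extra_state. A compact restatement of that logic as a standalone helper; the function name is hypothetical and exists only for illustration:

    def resume_position(extra_state):
        """Map the loaded extra_state to the (epoch, batch_offset) to resume from.

        Mirrors the logic added to main(): start fresh when there is no checkpoint,
        and advance to the next epoch when the saved offset shows the epoch finished.
        """
        if extra_state is None:
            return 1, 0
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        if batch_offset == 0:
            epoch += 1
        return epoch, batch_offset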
...
@@ -89,7 +98,7 @@ def main():
         if k == 0:
             if not args.no_save:
                 # save checkpoint
-                trainer.save_checkpoint(args, epoch, 0, val_loss)
+                save_checkpoint(trainer, args, epoch, 0, val_loss)

             # only use first validation loss to update the learning schedule
             lr = trainer.lr_step(val_loss, epoch)
...
@@ -151,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
                 # ignore the first mini-batch in words-per-second calculation
                 wps_meter.reset()
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
-                trainer.save_checkpoint(args, epoch, i + 1)
+                save_checkpoint(trainer, args, epoch, i + 1)

     fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
         loss_meter.avg, math.pow(2, loss_meter.avg))
...
@@ -166,6 +175,28 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
                 t.write(fmt)


+def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
+    extra_state = {
+        'epoch': epoch,
+        'batch_offset': batch_offset,
+        'val_loss': val_loss,
+    }
+
+    if batch_offset == 0:
+        if not args.no_epoch_checkpoints:
+            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
+            trainer.save_checkpoint(epoch_filename, extra_state)
+
+        assert val_loss is not None
+        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+            save_checkpoint.best = val_loss
+            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+            trainer.save_checkpoint(best_filename, extra_state)
+
+    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
+    trainer.save_checkpoint(last_filename, extra_state)


 def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
...
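With this wrapper, the filename policy that previously lived in fairseq/utils.py (a per-epoch checkpointN.pt, checkpoint_best.pt when validation loss improves, and an always-updated checkpoint_last.pt) is now decided in train.py, while the trainer only serializes state to whatever path it is given. A rough sketch of what an end-of-epoch call produces, assuming args.save_dir is 'checkpoints'; the values are placeholders:

    # Placeholder values; mirrors the epoch-boundary call in main() above.
    save_checkpoint(trainer, args, epoch=3, batch_offset=0, val_loss=4.21)

    # Expected files under args.save_dir after the call:
    #   checkpoints/checkpoint3.pt      (epoch checkpoint, unless args.no_epoch_checkpoints is set)
    #   checkpoints/checkpoint_best.pt  (only if 4.21 beats the best val_loss seen so far)
    #   checkpoints/checkpoint_last.pt  (always rewritten)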