OpenDAS / Fairseq / Commits

Commit fc312d28
Authored May 23, 2018 by Alexei Baevski; committed by Myle Ott, Jun 15, 2018

ability to checkpoint when reaching certain number of updates

Parent: 58e2c449

Showing 5 changed files with 89 additions and 31 deletions:

    fairseq/options.py              +6   -0
    fairseq/progress_bar.py         +1   -0
    fairseq/utils.py                +22  -4
    scripts/average_checkpoints.py  +13  -5
    train.py                        +47  -22
fairseq/options.py  (view file @ fc312d28)

@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
                        help='filename in save-dir from which to load checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
+    group.add_argument('--save-interval-updates', type=int, metavar='N',
+                       help='if specified, saves best/last checkpoint every this many updates. '
+                            'will also validate before saving to determine if val loss is better')
+    group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
+                       help='if --save-interval-updates is specified, keep the last this many checkpoints'
+                            ' created after specified number of updates (format is checkpoint_[epoch]_[numupd].pt')
     group.add_argument('--no-save', action='store_true',
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
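A note on the defaults: --save-interval-updates is declared without a default, so argparse leaves it as None until the user passes it, which is why train.py (below) guards with (args.save_interval_updates or 0) > 0. A minimal standalone sketch of the parsing behaviour, not part of the commit:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--save-interval-updates', type=int, metavar='N')
    parser.add_argument('--keep-interval-updates', type=int, default=0, metavar='N')

    print(parser.parse_args([]))
    # Namespace(keep_interval_updates=0, save_interval_updates=None)
    print(parser.parse_args(['--save-interval-updates', '1000', '--keep-interval-updates', '5']))
    # Namespace(keep_interval_updates=5, save_interval_updates=1000)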
fairseq/progress_bar.py  (view file @ fc312d28)

@@ -117,6 +117,7 @@ class json_progress_bar(progress_bar):
     def print(self, stats):
         """Print end-of-epoch stats."""
+        self.stats = stats
         stats = self._format_stats(self.stats, epoch=self.epoch)
         print(json.dumps(stats), flush=True)
fairseq/utils.py  (view file @ fc312d28)

@@ -9,6 +9,7 @@ from collections import defaultdict, OrderedDict
 import contextlib
 import logging
 import os
+import re
 import torch
 import traceback

@@ -351,10 +352,11 @@ def buffered_arange(max):
 def convert_padding_direction(
     src_tokens,
-    src_lengths,
     padding_idx,
     right_to_left=False,
     left_to_right=False,
 ):
     assert right_to_left ^ left_to_right
     pad_mask = src_tokens.eq(padding_idx)

@@ -396,3 +398,19 @@ def clip_grad_norm_(tensor, max_norm):
 def fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float('-inf')).type_as(t)
+
+
+def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
+    """ retrieves all checkpoints found in `path` directory. checkpoints are identified by matching filename to
+    the specified pattern. if the pattern contains groups, the result will be sorted by the first group in descending
+    order """
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = int(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
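A quick standalone check of the sort order that the pruning logic in train.py relies on; this copies the new helper rather than importing fairseq, and the filenames are made up for illustration:

    import os
    import re
    import tempfile

    def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
        # same logic as the helper added above
        pt_regexp = re.compile(pattern)
        entries = []
        for i, f in enumerate(os.listdir(path)):
            m = pt_regexp.fullmatch(f)
            if m is not None:
                idx = int(m.group(1)) if len(m.groups()) > 0 else i
                entries.append((idx, m.group(0)))
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]

    with tempfile.TemporaryDirectory() as d:
        for name in ('checkpoint_1_500.pt', 'checkpoint_1_1000.pt', 'checkpoint_2_1500.pt'):
            open(os.path.join(d, name), 'w').close()
        # sorted by the captured update count, newest first:
        # [..._2_1500.pt, ..._1_1000.pt, ..._1_500.pt]
        print(checkpoint_paths(d, pattern=r'checkpoint_\d+_(\d+)\.pt'))

Because the list comes back newest-first, save_checkpoint can keep the first keep_interval_updates entries and delete the rest.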
scripts/average_checkpoints.py  (view file @ fc312d28)

@@ -62,10 +62,13 @@ def average_checkpoints(inputs):
     return new_state


-def last_n_checkpoints(paths, n):
+def last_n_checkpoints(paths, n, update_based):
     assert len(paths) == 1
     path = paths[0]
-    pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
+    if update_based:
+        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
+    else:
+        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
     files = os.listdir(path)

     entries = []

@@ -81,7 +84,7 @@ def last_n_checkpoints(paths, n):
 def main():
     parser = argparse.ArgumentParser(
         description='Tool to average the params of input checkpoints to '
                     'produce a new checkpoint',
     )
     parser.add_argument(

@@ -95,7 +98,7 @@ def main():
         required=True,
         metavar='FILE',
         help='Write the new checkpoint containing the averaged weights to this '
             'path.',
     )
     parser.add_argument(
         '--num',

@@ -103,11 +106,16 @@ def main():
         help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
             'and average last num of those',
     )
+    parser.add_argument(
+        '--update-based-checkpoints',
+        action='store_true',
+        help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints')
     args = parser.parse_args()
     print(args)

     if args.num is not None:
-        args.inputs = last_n_checkpoints(args.inputs, args.num)
+        args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
     print('averaging checkpoints: ', args.inputs)
     new_state = average_checkpoints(args.inputs)
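The two modes differ only in which filename pattern is matched and which captured number orders the "last n" selection; a standalone illustration of the two regexes, with made-up filenames:

    import re

    epoch_re  = re.compile(r'checkpoint(\d+)\.pt')       # default: epoch-based names
    update_re = re.compile(r'checkpoint_\d+_(\d+)\.pt')  # --update-based-checkpoints

    print(epoch_re.fullmatch('checkpoint12.pt').group(1))         # '12'    (epoch)
    print(update_re.fullmatch('checkpoint_3_20000.pt').group(1))  # '20000' (update count)
    print(epoch_re.fullmatch('checkpoint_3_20000.pt'))            # None: the modes don't mix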
train.py  (view file @ fc312d28)

@@ -15,10 +15,10 @@ from fairseq import criterions, data, models, options, progress_bar
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
+from fairseq.utils import checkpoint_paths


 def main(args):
     if args.max_tokens is None:
         args.max_tokens = 6000
@@ -82,26 +82,22 @@ def main(args):
     max_epoch = args.max_epoch or math.inf
     max_update = args.max_update or math.inf
     lr = trainer.get_lr()
-    first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
-        train(args, trainer, next(train_dataloader), epoch)
+        train(args, trainer, next(train_dataloader), epoch, dataset)

-        # evaluate on validate set
+        first_val_loss = None
         if epoch % args.validate_interval == 0:
-            for k, subset in enumerate(args.valid_subset.split(',')):
-                val_loss = validate(args, trainer, dataset, subset, epoch)
-                if k == 0:
-                    first_val_loss = val_loss
+            first_val_loss = val_loss(args, trainer, dataset, epoch)

         # only use first validation loss to update the learning rate
         lr = trainer.lr_step(epoch, first_val_loss)

         # save checkpoint
         if not args.no_save and epoch % args.save_interval == 0:
-            save_checkpoint(trainer, args, epoch, first_val_loss)
+            save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)

         epoch += 1
     train_meter.stop()
@@ -120,7 +116,7 @@ def load_dataset(args, splits):
     return dataset


-def train(args, trainer, itr, epoch):
+def train(args, trainer, itr, epoch, dataset):
     """Train the model for one epoch."""
     # Set seed based on args.seed and the epoch number so that we get
@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch):
         if i == 0:
             trainer.get_meter('wps').reset()

-        if trainer.get_num_updates() >= max_update:
+        num_updates = trainer.get_num_updates()
+        if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
+            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
+            save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
+
+        if num_updates >= max_update:
            break

     # log end-of-epoch stats
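Since --save-interval-updates defaults to None, the (args.save_interval_updates or 0) > 0 guard disables the feature (and avoids taking a modulo of None) when the flag is unset. A minimal sketch of when the mid-epoch save fires; should_save is an illustrative name, not fairseq code:

    def should_save(save_interval_updates, num_updates, no_save=False):
        # mirrors the condition above; `or 0` turns the None default into "disabled"
        return (not no_save and (save_interval_updates or 0) > 0
                and num_updates % save_interval_updates == 0)

    print([n for n in range(1, 3001) if should_save(1000, n)])  # [1000, 2000, 3000]
    print(should_save(None, 1000))                              # False: flag unset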
@@ -202,7 +203,7 @@ def get_training_stats(trainer):
     return stats


-def validate(args, trainer, dataset, subset, epoch):
+def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
     """Evaluate the model on the validation set and return the average loss."""
     # Initialize dataloader
@@ -236,19 +237,24 @@ def validate(args, trainer, dataset, subset, epoch):
     for sample in progress:
         log_output = trainer.valid_step(sample)

-        # log mid-validation stats
-        stats = get_valid_stats(trainer)
-        for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
-                continue
-            extra_meters[k].update(v)
-            stats[k] = extra_meters[k].avg
-        progress.log(stats)
+        if verbose:
+            # log mid-validation stats
+            stats = get_valid_stats(trainer)
+            for k, v in log_output.items():
+                if k in ['loss', 'nll_loss', 'sample_size']:
+                    continue
+                extra_meters[k].update(v)
+                stats[k] = extra_meters[k].avg
+            progress.log(stats)

     # log validation stats
     stats = get_valid_stats(trainer)
     for k, meter in extra_meters.items():
         stats[k] = meter.avg
+    if num_updates is not None:
+        stats['num_updates'] = num_updates
     progress.print(stats)

     return stats['valid_loss']
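In this loop, extra_meters is a defaultdict of meters that averages any additional keys from log_output across batches, while loss, nll_loss and sample_size are skipped because the trainer tracks them in its own meters. A standalone sketch of that aggregation, using a stand-in meter class in place of fairseq's AverageMeter:

    from collections import defaultdict

    class Meter:
        # stand-in for fairseq.meters.AverageMeter
        def __init__(self):
            self.sum, self.count = 0.0, 0
        def update(self, val):
            self.sum += val
            self.count += 1
        @property
        def avg(self):
            return self.sum / self.count

    extra_meters = defaultdict(Meter)
    for log_output in ({'loss': 1.0, 'ntokens': 3000}, {'loss': 0.9, 'ntokens': 2800}):
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # the trainer already meters these
            extra_meters[k].update(v)

    print({k: m.avg for k, m in extra_meters.items()})  # {'ntokens': 2900.0}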
@@ -273,16 +279,33 @@ def get_perplexity(loss):
     return float('inf')


-def save_checkpoint(trainer, args, epoch, val_loss=None):
+def val_loss(args, trainer, dataset, epoch, num_updates=None):
+    # evaluate on validate set
+    subsets = args.valid_subset.split(',')
+    # we want to validate all subsets so the results get printed out, but return only the first
+    losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets]
+    return losses[0] if len(losses) > 0 else None
+
+
+def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
     extra_state = {
         'epoch': epoch,
         'val_loss': val_loss,
         'wall_time': trainer.get_meter('wall').elapsed_time,
     }

-    if not args.no_epoch_checkpoints:
+    if end_of_epoch and not args.no_epoch_checkpoints:
         epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
         trainer.save_checkpoint(epoch_filename, extra_state)
+    elif not end_of_epoch and args.keep_interval_updates > 0:
+        checkpoint_filename = os.path.join(
+            args.save_dir, 'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
+        trainer.save_checkpoint(checkpoint_filename, extra_state)
+        # remove old checkpoints
+        checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
+        # checkpoints are sorted in descending order
+        for old_chk in checkpoints[args.keep_interval_updates:]:
+            os.remove(old_chk)

     assert val_loss is not None
     if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
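The best-loss bookkeeping here relies on a function attribute as persistent state across calls; a minimal sketch of the pattern (track_best is an illustrative name, and the code that actually writes the best checkpoint is elided in this view):

    def track_best(val_loss):
        # function attributes persist across calls, acting as module-level state
        if not hasattr(track_best, 'best') or val_loss < track_best.best:
            track_best.best = val_loss
            return True   # new best; the elided code would save it here
        return False

    print([track_best(v) for v in (3.2, 2.9, 3.1)])  # [True, True, False]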
@@ -317,9 +340,11 @@ if __name__ == '__main__':
     if args.distributed_port > 0 or args.distributed_init_method is not None:
         from distributed_train import main as distributed_main

         distributed_main(args)
     elif args.distributed_world_size > 1:
         from multiprocessing_train import main as multiprocessing_main

         multiprocessing_main(args)
     else:
         main(args)