Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
a3e4c4c3
Commit
a3e4c4c3
authored
May 23, 2018
by
alexeib
Committed by
Myle Ott
Jun 15, 2018
Browse files
remove unused verbose option & make arguments to averaging script nicer
parent
7c07e87c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
21 deletions
+21
-21
scripts/average_checkpoints.py
scripts/average_checkpoints.py
+19
-9
train.py
train.py
+2
-12
No files found.
scripts/average_checkpoints.py
View file @
a3e4c4c3
...
...
@@ -100,22 +100,32 @@ def main():
help
=
'Write the new checkpoint containing the averaged weights to this '
'path.'
,
)
parser
.
add_argument
(
'--num'
,
num_group
=
parser
.
add_mutually_exclusive_group
()
num_group
.
add_argument
(
'--num-epoch-checkpoints'
,
type
=
int
,
help
=
'if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last
num
of th
ose
'
,
'and average last
this many
of th
em.
'
,
)
parser
.
add_argument
(
'--update-based-checkpoints'
,
action
=
'store_true'
,
help
=
'if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints'
num_group
.
add_argument
(
'--num-update-checkpoints'
,
type
=
int
,
help
=
'if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.'
,
)
args
=
parser
.
parse_args
()
print
(
args
)
if
args
.
num
is
not
None
:
args
.
inputs
=
last_n_checkpoints
(
args
.
inputs
,
args
.
num
,
args
.
update_based_checkpoints
)
num
=
None
is_update_based
=
False
if
args
.
num_update_checkpoints
is
not
None
:
num
=
args
.
num_update_checkpoints
is_update_based
=
True
elif
args
.
num_epoch_checkpoints
is
not
None
:
num
=
args
.
num_epoch_checkpoints
if
num
is
not
None
:
args
.
inputs
=
last_n_checkpoints
(
args
.
inputs
,
num
,
is_update_based
)
print
(
'averaging checkpoints: '
,
args
.
inputs
)
new_state
=
average_checkpoints
(
args
.
inputs
)
...
...
train.py
View file @
a3e4c4c3
...
...
@@ -203,7 +203,7 @@ def get_training_stats(trainer):
return
stats
def
validate
(
args
,
trainer
,
dataset
,
subset
,
epoch
,
num_updates
,
verbose
):
def
validate
(
args
,
trainer
,
dataset
,
subset
,
epoch
,
num_updates
):
"""Evaluate the model on the validation set and return the average loss."""
# Initialize dataloader
...
...
@@ -237,16 +237,6 @@ def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
for
sample
in
progress
:
log_output
=
trainer
.
valid_step
(
sample
)
if
verbose
:
# log mid-validation stats
stats
=
get_valid_stats
(
trainer
)
for
k
,
v
in
log_output
.
items
():
if
k
in
[
'loss'
,
'nll_loss'
,
'sample_size'
]:
continue
extra_meters
[
k
].
update
(
v
)
stats
[
k
]
=
extra_meters
[
k
].
avg
progress
.
log
(
stats
)
# log validation stats
stats
=
get_valid_stats
(
trainer
)
for
k
,
meter
in
extra_meters
.
items
():
...
...
@@ -283,7 +273,7 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None):
# evaluate on validate set
subsets
=
args
.
valid_subset
.
split
(
','
)
# we want to validate all subsets so the results get printed out, but return only the first
losses
=
[
validate
(
args
,
trainer
,
dataset
,
subset
,
epoch
,
num_updates
,
verbose
=
False
)
for
subset
in
subsets
]
losses
=
[
validate
(
args
,
trainer
,
dataset
,
subset
,
epoch
,
num_updates
)
for
subset
in
subsets
]
return
losses
[
0
]
if
len
(
losses
)
>
0
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment