OpenDAS / Fairseq · Commits

Commit a919570b, authored May 30, 2018 by Myle Ott
Parent: 6643d525

    Merge validate and val_loss functions (simplify train.py)

Showing 1 changed file with 53 additions and 57 deletions:

    train.py  (+53 / -57)
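In short, this change replaces the single-subset validate() and its val_loss() wrapper with one validate() that loops over the requested subsets and returns a list of losses. A condensed before/after sketch of the main call site (lines taken from the diff below, surrounding code omitted):

    # before: the wrapper returned a single scalar
    first_val_loss = val_loss(args, trainer, dataset, epoch)
    lr = trainer.lr_step(epoch, first_val_loss)

    # after: one loss per validation subset; only the first drives the LR schedule
    valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
    lr = trainer.lr_step(epoch, valid_losses[0])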
@@ -88,19 +88,20 @@ def main(args):
-    first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
+    valid_subsets = args.valid_subset.split(',')
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
         train(args, trainer, next_ds, epoch, dataset)

         if epoch % args.validate_interval == 0:
-            first_val_loss = val_loss(args, trainer, dataset, epoch)
+            valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)

-        lr = trainer.lr_step(epoch, first_val_loss)
+        # only use first validation loss to update the learning rate
+        lr = trainer.lr_step(epoch, valid_losses[0])

         # save checkpoint
         if epoch % args.save_interval == 0:
-            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=first_val_loss)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=True, val_loss=valid_losses[0])

         epoch += 1
         next_ds = next(train_dataloader)
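Here valid_subsets is just the comma-separated --valid-subset option split into a list; validate() returns losses in the same order, and only index 0 feeds the LR step and the saved val_loss, while the remaining subsets are evaluated purely for logging. A tiny illustration with hypothetical option values:

    # --valid-subset valid,valid1  ->  ['valid', 'valid1']
    valid_subsets = args.valid_subset.split(',')
    valid_losses = validate(args, trainer, dataset, valid_subsets, epoch)
    assert len(valid_losses) == len(valid_subsets)  # one loss per subset, in order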
@@ -135,6 +136,7 @@ def train(args, trainer, itr, epoch, dataset):
         update_freq = args.update_freq[-1]

     extra_meters = collections.defaultdict(lambda: AverageMeter())
+    first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
     num_batches = len(itr)
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
@@ -164,8 +166,8 @@ def train(args, trainer, itr, epoch, dataset):
         num_updates = trainer.get_num_updates()
         if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
-            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
-            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=first_val_loss)
+            valid_losses = validate(args, trainer, dataset, [first_valid], epoch)
+            save_checkpoint(args, trainer, epoch, end_of_epoch=False, val_loss=valid_losses[0])

         if num_updates >= max_update:
             break
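The mid-epoch path validates only the first subset, passed as the one-element list [first_valid], which keeps --save-interval-updates checkpointing comparatively cheap. A hypothetical invocation that exercises this branch (flag spellings inferred from the args attributes used above; the data path is a placeholder):

    python train.py <data-dir> \
        --valid-subset valid,valid1 \
        --save-interval-updates 1000   # validate 'valid' and save a checkpoint every 1000 updates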
@@ -201,52 +203,54 @@ def get_training_stats(trainer):
     return stats


-def validate(args, trainer, dataset, subset, epoch, num_updates):
-    """Evaluate the model on the validation set and return the average loss."""
-
-    # Initialize dataloader
-    max_positions_valid = (
-        trainer.get_model().max_encoder_positions(),
-        trainer.get_model().max_decoder_positions(),
-    )
-    itr = dataset.eval_dataloader(
-        subset,
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences_valid,
-        max_positions=max_positions_valid,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-        descending=True,  # largest batch first to warm the caching allocator
-        shard_id=args.distributed_rank,
-        num_shards=args.distributed_world_size,
-    )
-    progress = progress_bar.build_progress_bar(
-        args, itr, epoch,
-        prefix='valid on \'{}\' subset'.format(subset),
-        no_progress_bar='simple'
-    )
-
-    # reset validation loss meters
-    for k in ['valid_loss', 'valid_nll_loss']:
-        meter = trainer.get_meter(k)
-        if meter is not None:
-            meter.reset()
-
-    extra_meters = collections.defaultdict(lambda: AverageMeter())
-    for sample in progress:
-        log_output = trainer.valid_step(sample)
-
-        for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
-                continue
-            extra_meters[k].update(v)
-
-    # log validation stats
-    stats = get_valid_stats(trainer)
-    for k, meter in extra_meters.items():
-        stats[k] = meter.avg
-    progress.print(stats)
-
-    return stats['valid_loss']
+def validate(args, trainer, dataset, subsets, epoch):
+    """Evaluate the model on the validation set(s) and return the losses."""
+
+    valid_losses = []
+    for subset in subsets:
+        # Initialize dataloader
+        max_positions_valid = (
+            trainer.get_model().max_encoder_positions(),
+            trainer.get_model().max_decoder_positions(),
+        )
+        itr = dataset.eval_dataloader(
+            subset,
+            max_tokens=args.max_tokens,
+            max_sentences=args.max_sentences_valid,
+            max_positions=max_positions_valid,
+            skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+            descending=True,  # largest batch first to warm the caching allocator
+            shard_id=args.distributed_rank,
+            num_shards=args.distributed_world_size,
+        )
+        progress = progress_bar.build_progress_bar(
+            args, itr, epoch,
+            prefix='valid on \'{}\' subset'.format(subset),
+            no_progress_bar='simple'
+        )
+
+        # reset validation loss meters
+        for k in ['valid_loss', 'valid_nll_loss']:
+            meter = trainer.get_meter(k)
+            if meter is not None:
+                meter.reset()
+
+        extra_meters = collections.defaultdict(lambda: AverageMeter())
+        for sample in progress:
+            log_output = trainer.valid_step(sample)
+
+            for k, v in log_output.items():
+                if k in ['loss', 'nll_loss', 'sample_size']:
+                    continue
+                extra_meters[k].update(v)
+
+        # log validation stats
+        stats = get_valid_stats(trainer)
+        for k, meter in extra_meters.items():
+            stats[k] = meter.avg
+        progress.print(stats)
+
+        valid_losses.append(stats['valid_loss'])
+    return valid_losses


 def get_valid_stats(trainer):
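One detail of the merged function: the valid_loss and valid_nll_loss meters are reset at the top of every loop iteration, so the stats['valid_loss'] appended to valid_losses reflects only the subset just evaluated rather than a running aggregate. Condensed to its skeleton (names from the added lines above, evaluation details elided):

    valid_losses = []
    for subset in subsets:
        # reset meters, build the dataloader, run trainer.valid_step() over the subset
        stats = get_valid_stats(trainer)
        valid_losses.append(stats['valid_loss'])
    return valid_losses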
@@ -271,14 +275,6 @@ def get_perplexity(loss):
         return float('inf')


-def val_loss(args, trainer, dataset, epoch, num_updates=None):
-    # evaluate on validate set
-    subsets = args.valid_subset.split(',')
-    # we want to validate all subsets so the results get printed out, but return only the first
-    losses = [validate(args, trainer, dataset, subset, epoch, num_updates) for subset in subsets]
-    return losses[0] if len(losses) > 0 else None
-
-
 def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
     if args.no_save or args.distributed_rank > 0:
         return
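With that, val_loss() is redundant: its per-subset loop and return-only-the-first-loss behavior now live inside validate(), and the num_updates argument it threaded through is gone from the new signature. For any external script that imported the removed helper, the closest equivalent is a one-liner (a sketch, not code from this commit):

    loss = validate(args, trainer, dataset, args.valid_subset.split(','), epoch)[0]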