Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
c542884d
Commit
c542884d
authored
Jan 01, 2018
by
Myle Ott
Browse files
Add --max-sentences-valid to train.py
parent
eb005cdb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
1 deletion
+10
-1
train.py
train.py
+10
-1
No files found.
train.py
View file @
c542884d
...
...
@@ -30,6 +30,8 @@ def main():
dataset_args
.
add_argument
(
'--valid-subset'
,
default
=
'valid'
,
metavar
=
'SPLIT'
,
help
=
'comma separated list of data subsets '
' to use for validation (train, valid, valid1,test, test1)'
)
dataset_args
.
add_argument
(
'--max-sentences-valid'
,
type
=
int
,
metavar
=
'N'
,
help
=
'maximum number of sentences in a validation batch'
)
options
.
add_optimization_args
(
parser
)
options
.
add_checkpoint_args
(
parser
)
options
.
add_model_args
(
parser
)
...
...
@@ -39,6 +41,9 @@ def main():
if
args
.
no_progress_bar
and
args
.
log_format
is
None
:
args
.
log_format
=
'simple'
if
args
.
max_sentences_valid
is
None
:
args
.
max_sentences_valid
=
args
.
max_sentences
if
not
os
.
path
.
exists
(
args
.
save_dir
):
os
.
makedirs
(
args
.
save_dir
)
torch
.
manual_seed
(
args
.
seed
)
...
...
@@ -218,6 +223,10 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
save_checkpoint
.
best
=
val_loss
best_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint_best.pt'
)
trainer
.
save_checkpoint
(
best_filename
,
extra_state
)
elif
not
args
.
no_epoch_checkpoints
:
epoch_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint{}_{}.pt'
.
format
(
epoch
,
batch_offset
))
trainer
.
save_checkpoint
(
epoch_filename
,
extra_state
)
last_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint_last.pt'
)
trainer
.
save_checkpoint
(
last_filename
,
extra_state
)
...
...
@@ -227,7 +236,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
"""Evaluate the model on the validation set and return the average loss."""
itr
=
dataset
.
eval_dataloader
(
subset
,
max_tokens
=
args
.
max_tokens
,
max_sentences
=
args
.
max_sentences
,
subset
,
max_tokens
=
args
.
max_tokens
,
max_sentences
=
args
.
max_sentences
_valid
,
max_positions
=
max_positions
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
,
descending
=
True
,
# largest batch first to warm the caching allocator
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment