OpenDAS / Fairseq · Commits

Commit 67af40c9
Authored May 15, 2018 by Alexei Baevski; committed by Myle Ott, Jun 15, 2018
Parent: a5e49364

allow specifying max_tokens for generation

The commit drops the hard-coded default of 6000 for --max-tokens from the argument parser so that each entry point can pick its own fallback: train.py defaults to 6000 as before, while generate.py defaults to 12000 (only when --max-sentences is also unset) and now forwards the token budget to the evaluation dataloader.

Showing 3 changed files with 12 additions and 3 deletions:

    fairseq/options.py   +1  -1
    generate.py          +7  -2
    train.py             +4  -0
fairseq/options.py

The dataset arguments lose the parser-level default for --max-tokens:

@@ -106,7 +106,7 @@ def add_dataset_args(parser, train=False, gen=False):
                            help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
-    group.add_argument('--max-tokens', default=6000, type=int, metavar='N',
+    group.add_argument('--max-tokens', type=int, metavar='N',
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
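With the parser-level default removed, an unset --max-tokens now surfaces as None, which is what lets each entry point choose its own fallback below. A minimal standalone sketch of that behavior (plain argparse; not code from this commit):

```python
import argparse

# Same argument shape as in add_dataset_args, reduced to the two flags
# involved here.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('--max-tokens', type=int, metavar='N',
                   help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                   help='maximum number of sentences in a batch')

args = parser.parse_args([])            # user passed no flags
assert args.max_tokens is None          # "unset" is now observable as None
args = parser.parse_args(['--max-tokens', '4000'])
assert args.max_tokens == 4000          # explicit values pass through
```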
generate.py

Generation picks a fallback token budget of 12000, but only when the user set neither batching flag:

@@ -16,6 +16,10 @@ from fairseq.sequence_scorer import SequenceScorer
 def main(args):
     assert args.path is not None, '--path required for generation!'
+
+    if args.max_tokens is None and args.max_sentences is None:
+        args.max_tokens = 12000
+
     print(args)
     assert not args.sampling or args.nbest == args.beam, \
         '--sampling requires --nbest to be equal to --beam'
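The guard checks both flags so the 12000-token fallback applies only when the user expressed no batching preference at all; someone who passed only --max-sentences keeps pure sentence-based batching. A small sketch of that resolution logic in isolation (the helper name is illustrative, not from fairseq):

```python
def resolve_generation_batching(max_tokens, max_sentences):
    # Mirrors the lines added to generate.py's main(): fall back to a
    # 12000-token budget only if neither cap was set by the user.
    if max_tokens is None and max_sentences is None:
        max_tokens = 12000
    return max_tokens, max_sentences

assert resolve_generation_batching(None, None) == (12000, None)
assert resolve_generation_batching(None, 32) == (None, 32)    # sentence cap kept as-is
assert resolve_generation_batching(4000, None) == (4000, None)
```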
Still in generate.py, the evaluation dataloader now receives both batching caps instead of a fixed sentence cap:

@@ -58,12 +62,13 @@ def main(args):
     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)
     align_dict = utils.load_align_dict(args.replace_unk)

     # Load dataset (possibly sharded)
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
         args.gen_subset,
-        max_sentences=args.max_sentences or 128,
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
     )
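Forwarding max_tokens, instead of the old hard sentence cap of `args.max_sentences or 128`, lets the dataloader cut batches by token budget as well as by sentence count. A toy illustration of that idea; this is not fairseq's actual batching code:

```python
def batch_by_size(lengths, max_tokens=None, max_sentences=None):
    """Group sentence indices greedily, starting a new batch whenever the
    next sentence would exceed the token budget or the sentence cap."""
    batch, batch_tokens, batches = [], 0, []
    for idx, n_tokens in enumerate(lengths):
        over_tokens = max_tokens is not None and batch_tokens + n_tokens > max_tokens
        over_sents = max_sentences is not None and len(batch) == max_sentences
        if batch and (over_tokens or over_sents):
            batches.append(batch)
            batch, batch_tokens = [], 0
        batch.append(idx)
        batch_tokens += n_tokens
    if batch:
        batches.append(batch)
    return batches

# A 50-token budget splits these five sentences into three batches.
print(batch_by_size([20, 20, 20, 30, 10], max_tokens=50))  # [[0, 1], [2, 3], [4]]
```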
train.py

Training applies its fallback of 6000 tokens, matching the default previously baked into options.py:

@@ -18,6 +18,10 @@ from fairseq.meters import AverageMeter, StopwatchMeter
 def main(args):
+
+    if args.max_tokens is None:
+        args.max_tokens = 6000
+
     print(args)
     if not torch.cuda.is_available():
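Note that training behavior is effectively unchanged by the commit: the fallback fires only when --max-tokens was not given, and the fallback value equals the old parser default. A quick sketch (hypothetical helper name):

```python
from argparse import Namespace

def apply_train_fallback(args):
    # Mirrors the lines added to train.py's main(): only fill in the
    # default when the user left --max-tokens unset.
    if args.max_tokens is None:
        args.max_tokens = 6000
    return args

assert apply_train_fallback(Namespace(max_tokens=None)).max_tokens == 6000
assert apply_train_fallback(Namespace(max_tokens=2000)).max_tokens == 2000
```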