Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
3f9b9838
"src/vscode:/vscode.git/clone" did not exist on "a74f02fb40f5853175162852aac3f38f57b7d85c"
Commit
3f9b9838
authored
Oct 11, 2017
by
Sergey Edunov
Browse files
Ignore invalid sentences in test and valid
parent
8f058ea0
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
38 additions
and
10 deletions
+38
-10
fairseq/data.py
fairseq/data.py
+19
-5
fairseq/models/fconv.py
fairseq/models/fconv.py
+4
-0
fairseq/options.py
fairseq/options.py
+2
-0
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+2
-2
generate.py
generate.py
+7
-1
train.py
train.py
+4
-2
No files found.
fairseq/data.py
View file @
3f9b9838
...
...
@@ -91,7 +91,8 @@ class LanguageDatasets(object):
def
dataloader
(
self
,
split
,
batch_size
=
1
,
num_workers
=
0
,
max_tokens
=
None
,
seed
=
None
,
epoch
=
1
,
sample_without_replacement
=
0
,
max_positions
=
1024
):
sample_without_replacement
=
0
,
max_positions
=
1024
,
skip_invalid_size_inputs_valid_test
=
False
):
dataset
=
self
.
splits
[
split
]
if
split
.
startswith
(
'train'
):
with
numpy_seed
(
seed
):
...
...
@@ -102,9 +103,11 @@ class LanguageDatasets(object):
max_positions
=
max_positions
)
elif
split
.
startswith
(
'valid'
):
batch_sampler
=
list
(
batches_by_size
(
dataset
.
src
,
batch_size
,
max_tokens
,
dst
=
dataset
.
dst
,
max_positions
=
max_positions
))
max_positions
=
max_positions
,
ignore_invalid_inputs
=
skip_invalid_size_inputs_valid_test
))
else
:
batch_sampler
=
list
(
batches_by_size
(
dataset
.
src
,
batch_size
,
max_tokens
,
max_positions
=
max_positions
))
batch_sampler
=
list
(
batches_by_size
(
dataset
.
src
,
batch_size
,
max_tokens
,
max_positions
=
max_positions
,
ignore_invalid_inputs
=
skip_invalid_size_inputs_valid_test
))
return
torch
.
utils
.
data
.
DataLoader
(
dataset
,
...
...
@@ -207,7 +210,8 @@ class LanguagePairDataset(object):
return
res
def
batches_by_size
(
src
,
batch_size
=
None
,
max_tokens
=
None
,
dst
=
None
,
max_positions
=
1024
):
def
batches_by_size
(
src
,
batch_size
=
None
,
max_tokens
=
None
,
dst
=
None
,
max_positions
=
1024
,
ignore_invalid_inputs
=
False
):
"""Returns batches of indices sorted by size. Sequences of different lengths
are not allowed in the same batch."""
assert
isinstance
(
src
,
IndexedDataset
)
...
...
@@ -233,14 +237,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positio
return
False
cur_max_size
=
0
ignored
=
[]
for
idx
in
indices
:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if
src
.
sizes
[
idx
]
<
2
or
\
(
False
if
dst
is
None
else
dst
.
sizes
[
idx
]
<
2
)
or
\
sizes
[
idx
]
>
max_positions
-
2
:
if
ignore_invalid_inputs
:
ignored
.
append
(
idx
)
continue
raise
Exception
(
"Unable to handle input id {} of "
"size {} / {}."
.
format
(
idx
,
src
.
sizes
[
idx
],
dst
.
sizes
[
idx
]))
"size {} / {}."
.
format
(
idx
,
src
.
sizes
[
idx
],
"none"
if
dst
is
None
else
dst
.
sizes
[
idx
]))
if
yield_batch
(
idx
,
cur_max_size
*
(
len
(
batch
)
+
1
)):
yield
batch
...
...
@@ -249,6 +259,10 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positio
batch
.
append
(
idx
)
cur_max_size
=
max
(
cur_max_size
,
sizes
[
idx
])
if
len
(
ignored
)
>
0
:
print
(
"Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}"
.
format
(
len
(
ignored
),
ignored
))
if
len
(
batch
)
>
0
:
yield
batch
...
...
fairseq/models/fconv.py
View file @
3f9b9838
...
...
@@ -243,6 +243,10 @@ class Decoder(nn.Module):
context
+=
conv
.
kernel_size
[
0
]
-
1
return
context
def
max_positions
(
self
):
"""Returns maximum size of positions embeddings supported by this decoder"""
return
self
.
embed_positions
.
num_embeddings
def
incremental_inference
(
self
,
beam_size
=
None
):
"""Context manager for incremental inference.
...
...
fairseq/options.py
View file @
3f9b9838
...
...
@@ -34,6 +34,8 @@ def add_dataset_args(parser):
help
=
'number of data loading workers (default: 1)'
)
group
.
add_argument
(
'--max-positions'
,
default
=
1024
,
type
=
int
,
metavar
=
'N'
,
help
=
'max number of tokens in the sequence'
)
group
.
add_argument
(
'--skip-invalid-size-inputs-valid-test'
,
action
=
'store_true'
,
help
=
'Ignore too long or too short lines in valid and test set'
)
return
group
...
...
fairseq/sequence_generator.py
View file @
3f9b9838
...
...
@@ -35,8 +35,8 @@ class SequenceGenerator(object):
self
.
vocab_size
=
len
(
dst_dict
)
self
.
beam_size
=
beam_size
self
.
minlen
=
minlen
self
.
maxlen
=
maxlen
self
.
positions
=
torch
.
LongTensor
(
range
(
self
.
pad
+
1
,
self
.
pad
+
maxlen
+
2
))
self
.
maxlen
=
min
(
maxlen
,
*
(
m
.
decoder
.
max_positions
()
-
self
.
pad
-
2
for
m
in
self
.
models
))
self
.
positions
=
torch
.
LongTensor
(
range
(
self
.
pad
+
1
,
self
.
pad
+
self
.
maxlen
+
2
))
self
.
decoder_context
=
models
[
0
].
decoder
.
context_size
()
self
.
stop_early
=
stop_early
self
.
normalize_scores
=
normalize_scores
...
...
generate.py
View file @
3f9b9838
...
...
@@ -45,6 +45,10 @@ def main():
if
not
args
.
interactive
:
print
(
'| {} {} {} examples'
.
format
(
args
.
data
,
args
.
gen_subset
,
len
(
dataset
.
splits
[
args
.
gen_subset
])))
# Max positions is the model property but it is needed in data reader to be able to
# ignore too long sentences
args
.
max_positions
=
min
(
args
.
max_positions
,
*
(
m
.
decoder
.
max_positions
()
for
m
in
models
))
# Optimize model for generation
for
model
in
models
:
model
.
make_generation_fast_
(
not
args
.
no_beamable_mm
)
...
...
@@ -122,7 +126,9 @@ def main():
# Generate and compute BLEU score
scorer
=
bleu
.
Scorer
(
dataset
.
dst_dict
.
pad
(),
dataset
.
dst_dict
.
eos
(),
dataset
.
dst_dict
.
unk
())
itr
=
dataset
.
dataloader
(
args
.
gen_subset
,
batch_size
=
args
.
batch_size
,
max_positions
=
args
.
max_positions
)
itr
=
dataset
.
dataloader
(
args
.
gen_subset
,
batch_size
=
args
.
batch_size
,
max_positions
=
args
.
max_positions
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
)
num_sentences
=
0
with
progress_bar
(
itr
,
smoothing
=
0
,
leave
=
False
)
as
t
:
wps_meter
=
TimeMeter
()
...
...
train.py
View file @
3f9b9838
...
...
@@ -107,7 +107,8 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
itr
=
dataset
.
dataloader
(
args
.
train_subset
,
num_workers
=
args
.
workers
,
max_tokens
=
args
.
max_tokens
,
seed
=
args
.
seed
,
epoch
=
epoch
,
max_positions
=
args
.
max_positions
,
sample_without_replacement
=
args
.
sample_without_replacement
)
sample_without_replacement
=
args
.
sample_without_replacement
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
)
loss_meter
=
AverageMeter
()
bsz_meter
=
AverageMeter
()
# sentences per batch
wpb_meter
=
AverageMeter
()
# words per batch
...
...
@@ -163,7 +164,8 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
itr
=
dataset
.
dataloader
(
subset
,
batch_size
=
None
,
max_tokens
=
args
.
max_tokens
,
max_positions
=
args
.
max_positions
)
max_positions
=
args
.
max_positions
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
)
loss_meter
=
AverageMeter
()
desc
=
'| epoch {:03d} | valid on
\'
{}
\'
subset'
.
format
(
epoch
,
subset
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment