OpenDAS / Fairseq

Commit 3f9b9838, authored Oct 11, 2017 by Sergey Edunov

    Ignore invalid sentences in test and valid

Parent: 8f058ea0

Showing 6 changed files with 38 additions and 10 deletions (+38, -10)
fairseq/data.py                +19  -5
fairseq/models/fconv.py         +4  -0
fairseq/options.py              +2  -0
fairseq/sequence_generator.py   +2  -2
generate.py                     +7  -1
train.py                        +4  -2
fairseq/data.py  (+19 -5)

@@ -91,7 +91,8 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=1024):
+                   sample_without_replacement=0, max_positions=1024,
+                   skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):

@@ -102,9 +103,11 @@ class LanguageDatasets(object):
                     max_positions=max_positions)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                                                 max_positions=max_positions))
+                                                 max_positions=max_positions,
+                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         else:
             batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens,
-                                                 max_positions=max_positions))
+                                                 max_positions=max_positions,
+                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
             dataset,

@@ -207,7 +210,8 @@ class LanguagePairDataset(object):
     return res


-def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
+def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024,
+                    ignore_invalid_inputs=False):
     """Returns batches of indices sorted by size. Sequences of different lengths
     are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset)

@@ -233,14 +237,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positio
         return False

     cur_max_size = 0
+    ignored = []
     for idx in indices:
         # - 2 here stems from make_positions() where we offset positions
         # by padding_value + 1
         if src.sizes[idx] < 2 or \
                 (False if dst is None else dst.sizes[idx] < 2) or \
                 sizes[idx] > max_positions - 2:
+            if ignore_invalid_inputs:
+                ignored.append(idx)
+                continue
             raise Exception("Unable to handle input id {} of "
-                            "size {} / {}.".format(idx, src.sizes[idx], dst.sizes[idx]))
+                            "size {} / {}.".format(idx, src.sizes[idx],
+                                                   "none" if dst is None else dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch

@@ -249,6 +259,10 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positio
         batch.append(idx)
         cur_max_size = max(cur_max_size, sizes[idx])

+    if len(ignored) > 0:
+        print("Warning! {} samples are either too short or too long "
+              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+
     if len(batch) > 0:
         yield batch
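For context, the following is a self-contained sketch (not fairseq code; filter_indices and the example sizes are hypothetical) of the behaviour that ignore_invalid_inputs enables in batches_by_size(): instead of raising on a too-short or too-long sentence, the offending index is collected and reported once at the end.

# Hypothetical, standalone illustration of the new skip-instead-of-raise behaviour.
def filter_indices(src_sizes, dst_sizes, max_positions=1024, ignore_invalid_inputs=False):
    kept, ignored = [], []
    for idx, size in enumerate(src_sizes):
        too_short = size < 2 or (dst_sizes is not None and dst_sizes[idx] < 2)
        too_long = size > max_positions - 2   # - 2 mirrors the make_positions() offset
        if too_short or too_long:
            if ignore_invalid_inputs:
                ignored.append(idx)
                continue
            raise Exception("Unable to handle input id {} of size {}".format(idx, size))
        kept.append(idx)
    if ignored:
        print("Warning! {} samples are either too short or too long "
              "and will be ignored, sample ids={}".format(len(ignored), ignored))
    return kept

print(filter_indices([5, 1, 2048, 30], [6, 4, 900, 28], ignore_invalid_inputs=True))
# -> [0, 3], after a warning naming ids 1 (too short) and 2 (too long)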
fairseq/models/fconv.py  (+4 -0)

@@ -243,6 +243,10 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context

+    def max_positions(self):
+        """Returns maximum size of positions embeddings supported by this decoder"""
+        return self.embed_positions.num_embeddings
+
     def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.
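The new Decoder.max_positions() simply exposes the size of the decoder's learned position-embedding table. A minimal sketch of that idea, assuming a hypothetical table of 1024 positions with 256-dimensional embeddings (not the actual fconv Decoder):

import torch.nn as nn

# Stand-in for the decoder's self.embed_positions table.
embed_positions = nn.Embedding(num_embeddings=1024, embedding_dim=256, padding_idx=0)

def max_positions():
    # Mirrors Decoder.max_positions(): one embedding slot per representable position.
    return embed_positions.num_embeddings

print(max_positions())  # 1024 -> inputs needing a larger position index must be skipped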
fairseq/options.py  (+2 -0)

@@ -34,6 +34,8 @@ def add_dataset_args(parser):
                        help='number of data loading workers (default: 1)')
     group.add_argument('--max-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the sequence')
+    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
+                       help='Ignore too long or too short lines in valid and test set')
     return group
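The new flag is a plain argparse store_true switch, so it defaults to False and argparse exposes it with underscores in the parsed namespace. A standalone sketch of just that behaviour (the surrounding parser setup is simplified; only the add_argument call matches the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                    help='Ignore too long or too short lines in valid and test set')

args = parser.parse_args(['--skip-invalid-size-inputs-valid-test'])
print(args.skip_invalid_size_inputs_valid_test)  # True; False when the flag is omitted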
fairseq/sequence_generator.py  (+2 -2)

@@ -35,8 +35,8 @@ class SequenceGenerator(object):
         self.vocab_size = len(dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = maxlen
-        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + maxlen + 2))
+        self.maxlen = min(maxlen, *(m.decoder.max_positions() - self.pad - 2 for m in self.models))
+        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
         self.decoder_context = models[0].decoder.context_size()
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
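The clamp keeps every generated position index inside each ensemble member's position-embedding table; previously the positions tensor was built from the unclamped maxlen. A small sketch with assumed numbers:

# Assumed values: why maxlen is clamped before building the positions range.
pad = 1
decoder_max_positions = [1024, 1024]   # stand-ins for m.decoder.max_positions() per model
maxlen = 200                           # requested generation length

maxlen = min(maxlen, *(mp - pad - 2 for mp in decoder_max_positions))
positions = list(range(pad + 1, pad + maxlen + 2))   # mirrors the torch.LongTensor range
assert max(positions) < min(decoder_max_positions)   # every index fits the embedding table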
generate.py  (+7 -1)

@@ -45,6 +45,10 @@ def main():
     if not args.interactive:
         print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

+    # Max positions is the model property but it is needed in data reader to be able to
+    # ignore too long sentences
+    args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
+
     # Optimize model for generation
     for model in models:
         model.make_generation_fast_(not args.no_beamable_mm)

@@ -122,7 +126,9 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=args.max_positions)
+    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
+                             max_positions=args.max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
         wps_meter = TimeMeter()
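Because the data reader, not the model, now decides which sentences to drop, generate.py first reduces --max-positions to the smallest decoder capacity in the loaded ensemble. A toy illustration of that reduction (all numbers made up):

# Toy numbers: the CLI --max-positions value is clamped by the smallest
# decoder position-embedding table across the loaded models.
cli_max_positions = 1024
model_max_positions = [1024, 768, 1024]   # stand-ins for m.decoder.max_positions()

effective = min(cli_max_positions, *model_max_positions)
print(effective)   # 768: longer sentences are skipped (with the new flag) or rejected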
train.py  (+4 -2)

@@ -107,7 +107,8 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement)
+                             sample_without_replacement=args.sample_without_replacement,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch

@@ -163,7 +164,8 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
     itr = dataset.dataloader(subset, batch_size=None,
                              max_tokens=args.max_tokens,
-                             max_positions=args.max_positions)
+                             max_positions=args.max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)