Commit 3f9b9838 authored by Sergey Edunov

Ignore invalid sentences in test and valid

parent 8f058ea0
@@ -91,7 +91,8 @@ class LanguageDatasets(object):
 
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=1024):
+                   sample_without_replacement=0, max_positions=1024,
+                   skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):
@@ -102,9 +103,11 @@ class LanguageDatasets(object):
                     max_positions=max_positions)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                                                 max_positions=max_positions))
+                                                 max_positions=max_positions,
+                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         else:
-            batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions))
+            batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions,
+                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
             dataset,
@@ -207,7 +210,8 @@ class LanguagePairDataset(object):
         return res


-def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
+def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
+                    max_positions=1024, ignore_invalid_inputs=False):
     """Returns batches of indices sorted by size. Sequences of different lengths
     are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset)
@@ -233,14 +237,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
         return False

     cur_max_size = 0
+    ignored = []
     for idx in indices:
         # - 2 here stems from make_positions() where we offset positions
         # by padding_value + 1
         if src.sizes[idx] < 2 or \
                 (False if dst is None else dst.sizes[idx] < 2) or \
                 sizes[idx] > max_positions - 2:
+            if ignore_invalid_inputs:
+                ignored.append(idx)
+                continue
             raise Exception("Unable to handle input id {} of "
-                            "size {} / {}.".format(idx, src.sizes[idx], dst.sizes[idx]))
+                            "size {} / {}.".format(idx, src.sizes[idx],
+                                                   "none" if dst is None else dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
@@ -249,6 +259,10 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
         batch.append(idx)
         cur_max_size = max(cur_max_size, sizes[idx])

+    if len(ignored) > 0:
+        print("Warning! {} samples are either too short or too long "
+              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+
     if len(batch) > 0:
         yield batch
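To make the batching change above easier to follow, here is a minimal, self-contained sketch of the new skip-or-raise validity check in isolation. It is not the fairseq `batches_by_size` implementation: batching is ignored entirely, and `filter_indices`, `src_sizes` and `dst_sizes` are illustrative stand-ins for the `IndexedDataset` size arrays.

```python
# Minimal sketch of the skip-or-raise check, separate from batch construction.
def filter_indices(src_sizes, dst_sizes=None, max_positions=1024,
                   ignore_invalid_inputs=False):
    """Yield indices whose source (and target) lengths the model can handle."""
    ignored = []
    for idx, src_len in enumerate(src_sizes):
        dst_len = None if dst_sizes is None else dst_sizes[idx]
        too_short = src_len < 2 or (dst_len is not None and dst_len < 2)
        too_long = src_len > max_positions - 2   # "- 2" mirrors the position offset
        if too_short or too_long:
            if ignore_invalid_inputs:
                ignored.append(idx)              # drop silently, report at the end
                continue
            raise Exception("Unable to handle input id {} of size {} / {}.".format(
                idx, src_len, "none" if dst_len is None else dst_len))
        yield idx
    if ignored:
        print("Warning! {} samples are either too short or too long "
              "and will be ignored, sample ids={}".format(len(ignored), ignored))


# With the flag set, ids 2 (too long) and 3 (too short) are skipped instead of
# raising; without it the first bad sample raises as before.
print(list(filter_indices([5, 30, 2000, 1], ignore_invalid_inputs=True)))  # [0, 1]
```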
...
@@ -243,6 +243,10 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context

+    def max_positions(self):
+        """Returns maximum size of positions embeddings supported by this decoder"""
+        return self.embed_positions.num_embeddings
+
     def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.
...
@@ -34,6 +34,8 @@ def add_dataset_args(parser):
                        help='number of data loading workers (default: 1)')
     group.add_argument('--max-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the sequence')
+    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
+                       help='Ignore too long or too short lines in valid and test set')
     return group
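A quick standalone check of how the new flag behaves once parsed; only argparse is used here, and the surrounding `add_dataset_args` group and the rest of the fairseq option set are omitted:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                    help='Ignore too long or too short lines in valid and test set')

# Default (flag absent): keep the old strict behaviour and raise on bad sizes.
print(parser.parse_args([]).skip_invalid_size_inputs_valid_test)             # False
# Flag present: oversized / undersized sentences are dropped with a warning.
print(parser.parse_args(['--skip-invalid-size-inputs-valid-test'])
      .skip_invalid_size_inputs_valid_test)                                  # True
```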
...
@@ -35,8 +35,8 @@ class SequenceGenerator(object):
         self.vocab_size = len(dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = maxlen
-        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + maxlen + 2))
+        self.maxlen = min(maxlen, *(m.decoder.max_positions() - self.pad - 2 for m in self.models))
+        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
         self.decoder_context = models[0].decoder.context_size()
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
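The clamping expression can be read with toy numbers; none of these values come from a real checkpoint, they only illustrate how the weakest decoder in an ensemble bounds `self.maxlen`:

```python
# Toy numbers only; nothing here comes from a real model.
pad = 1                              # padding symbol index
requested_maxlen = 600               # maxlen passed to SequenceGenerator
decoder_max_positions = [1024, 512]  # m.decoder.max_positions() per ensemble member

# Same expression as in the diff above: generation length is clamped so even the
# smallest decoder has position embeddings for every generated token
# (the "- pad - 2" mirrors the offset used by make_positions()).
maxlen = min(requested_maxlen, *(m - pad - 2 for m in decoder_max_positions))
print(maxlen)  # 509 = 512 - 1 - 2, not the requested 600
```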
...
@@ -45,6 +45,10 @@ def main():
     if not args.interactive:
         print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

+    # Max positions is the model property but it is needed in data reader to be able to
+    # ignore too long sentences
+    args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
+
     # Optimize model for generation
     for model in models:
         model.make_generation_fast_(not args.no_beamable_mm)
@@ -122,7 +126,9 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=args.max_positions)
+    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
+                             max_positions=args.max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
         wps_meter = TimeMeter()
...
@@ -107,7 +107,8 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement)
+                             sample_without_replacement=args.sample_without_replacement,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch
@@ -163,7 +164,8 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
     itr = dataset.dataloader(subset, batch_size=None,
                              max_tokens=args.max_tokens,
-                             max_positions=args.max_positions)
+                             max_positions=args.max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
...