"src/vscode:/vscode.git/clone" did not exist on "a74f02fb40f5853175162852aac3f38f57b7d85c"
Commit 3f9b9838 authored by Sergey Edunov's avatar Sergey Edunov
Browse files

Ignore invalid sentences in the test and valid sets

parent 8f058ea0
@@ -91,7 +91,8 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=1024):
+                   sample_without_replacement=0, max_positions=1024,
+                   skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):
@@ -102,9 +103,11 @@ class LanguageDatasets(object):
                 max_positions=max_positions)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                                                 max_positions=max_positions))
+                                                 max_positions=max_positions,
+                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         else:
-            batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions))
+            batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions,
+                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
             dataset,
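
With both pieces in place, the valid and test branches forward the new flag to batches_by_size() as ignore_invalid_inputs, while the default of False keeps the old fail-fast behaviour for existing callers. A minimal usage sketch, assuming a LanguageDatasets instance named dataset; the split name and batch settings are illustrative:

# Hypothetical call against the new dataloader() signature.
loader = dataset.dataloader(
    'valid',
    batch_size=32,
    max_positions=1024,
    skip_invalid_size_inputs_valid_test=True,  # skip oversized sentences instead of raising
)
for batch in loader:
    pass  # batches never contain sentences that failed the size check
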
@@ -207,7 +210,8 @@ class LanguagePairDataset(object):
         return res


-def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
+def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
+                    max_positions=1024, ignore_invalid_inputs=False):
     """Returns batches of indices sorted by size. Sequences of different lengths
     are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset)
@@ -233,14 +237,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
         return False

     cur_max_size = 0
+    ignored = []
     for idx in indices:
         # - 2 here stems from make_positions() where we offset positions
         # by padding_value + 1
         if src.sizes[idx] < 2 or \
                 (False if dst is None else dst.sizes[idx] < 2) or \
                 sizes[idx] > max_positions - 2:
+            if ignore_invalid_inputs:
+                ignored.append(idx)
+                continue
             raise Exception("Unable to handle input id {} of "
-                            "size {} / {}.".format(idx, src.sizes[idx], dst.sizes[idx]))
+                            "size {} / {}.".format(idx, src.sizes[idx],
+                                                   "none" if dst is None else dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
@@ -249,6 +259,10 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
         batch.append(idx)
         cur_max_size = max(cur_max_size, sizes[idx])

+    if len(ignored) > 0:
+        print("Warning! {} samples are either too short or too long "
+              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+
     if len(batch) > 0:
         yield batch
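
The heart of the change is the skip-or-raise decision inside the batching loop. The following standalone sketch reproduces that logic over a plain list of lengths, without the IndexedDataset machinery, so the two behaviours are easy to compare; the function name and toy data are illustrative, not part of the commit:

def filter_by_size(sizes, max_positions=1024, ignore_invalid_inputs=False):
    """Yield indices whose size fits; skip or raise on the rest."""
    ignored = []
    for idx, size in enumerate(sizes):
        # Mirror the diff's bounds: at least 2 tokens, and room for the
        # positional offset of padding_value + 1 (hence max_positions - 2).
        if size < 2 or size > max_positions - 2:
            if ignore_invalid_inputs:
                ignored.append(idx)
                continue
            raise Exception("Unable to handle input id {} of size {}.".format(idx, size))
        yield idx
    if len(ignored) > 0:
        print("Warning! {} samples are either too short or too long "
              "and will be ignored, sample ids={}".format(len(ignored), ignored))

# With the flag set, index 1 (too short) and index 3 (too long) are dropped:
print(list(filter_by_size([10, 1, 500, 2000], ignore_invalid_inputs=True)))  # [0, 2]
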
@@ -243,6 +243,10 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context

+    def max_positions(self):
+        """Returns maximum size of positions embeddings supported by this decoder"""
+        return self.embed_positions.num_embeddings
+
     def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.
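The new max_positions() simply reads the size of the decoder's learned position-embedding table, which is the hard upper bound on how many positions the model can represent. A minimal sketch of that relationship, using a bare nn.Embedding in place of the real decoder; the sizes here are made up:

import torch
import torch.nn as nn

# A learned position-embedding table: one row per representable position.
embed_positions = nn.Embedding(num_embeddings=1024, embedding_dim=256)

# num_embeddings is exactly what Decoder.max_positions() returns above;
# looking up any position index >= 1024 would raise an out-of-range error.
print(embed_positions.num_embeddings)            # 1024
positions = torch.arange(embed_positions.num_embeddings)
print(embed_positions(positions).shape)          # torch.Size([1024, 256])
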
@@ -34,6 +34,8 @@ def add_dataset_args(parser):
                        help='number of data loading workers (default: 1)')
     group.add_argument('--max-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the sequence')
+    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
+                       help='Ignore too long or too short lines in valid and test set')
     return group
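Because argparse converts dashes to underscores, --skip-invalid-size-inputs-valid-test surfaces as args.skip_invalid_size_inputs_valid_test, the attribute the train and generate hunks below forward into dataloader(). A quick self-contained check of that mapping; the bare parser here is a stand-in for the real add_dataset_args() group:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                    help='Ignore too long or too short lines in valid and test set')

# store_true flags default to False and flip to True when the flag is present.
print(parser.parse_args([]).skip_invalid_size_inputs_valid_test)   # False
args = parser.parse_args(['--skip-invalid-size-inputs-valid-test'])
print(args.skip_invalid_size_inputs_valid_test)                    # True
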
@@ -35,8 +35,8 @@ class SequenceGenerator(object):
         self.vocab_size = len(dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = maxlen
-        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + maxlen + 2))
+        self.maxlen = min(maxlen, *(m.decoder.max_positions() - self.pad - 2 for m in self.models))
+        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
         self.decoder_context = models[0].decoder.context_size()
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
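This clamp keeps generation from requesting positions the decoder cannot embed: each model in the ensemble contributes its max_positions() minus the pad offset and the two reserved slots, and min() takes the tightest bound. A worked example with made-up numbers standing in for the models:

# Hypothetical ensemble: two decoders with different position-table sizes.
pad = 1
maxlen = 200                        # caller-requested maximum length
decoder_max_positions = [1024, 150]

# Same expression as the diff, with plain ints in place of m.decoder.max_positions():
clamped = min(maxlen, *(m - pad - 2 for m in decoder_max_positions))
print(clamped)  # 147 -- the smaller decoder (150 - 1 - 2) wins over the requested 200
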
@@ -45,6 +45,10 @@ def main():
     if not args.interactive:
         print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

+    # Max positions is the model property but it is needed in data reader to be able to
+    # ignore too long sentences
+    args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
+
     # Optimize model for generation
     for model in models:
         model.make_generation_fast_(not args.no_beamable_mm)
@@ -122,7 +126,9 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=args.max_positions)
+    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
+                             max_positions=args.max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
         wps_meter = TimeMeter()
@@ -107,7 +107,8 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement)
+                             sample_without_replacement=args.sample_without_replacement,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch
@@ -163,7 +164,8 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
     itr = dataset.dataloader(subset, batch_size=None,
                              max_tokens=args.max_tokens,
-                             max_positions=args.max_positions)
+                             max_positions=args.max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)