Commit b03b53b4 authored by Sergey Edunov, committed by Sergey Edunov

Allow more flexible pre-processing and generation (#227)

* Allow more flexible pre-processing and generation

* Addressing CR comments

* small fix
parent e73fddf4
@@ -57,17 +57,17 @@ def load_dataset(path, load_splits, src=None, dst=None):
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from binary files
def all_splits_exist(src, dst):
def all_splits_exist(src, dst, lang):
for split in load_splits:
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(path, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst):
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src):
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
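The split files are now probed per language, which is what lets a binarized source-only corpus be detected. A minimal sketch of the filenames checked for a hypothetical en-de setup with load_splits = ['test']:

# all_splits_exist(src, dst, src) probes the forward layout on disk:
'{0}.{1}-{2}.{3}.idx'.format('test', 'en', 'de', 'en')   # -> 'test.en-de.en.idx'
# all_splits_exist(dst, src, src) probes the reversed layout:
'{0}.{1}-{2}.{3}.idx'.format('test', 'de', 'en', 'en')   # -> 'test.de-en.en.idx'
# Only the source-language .idx files have to exist; the target side may be absent.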
@@ -84,9 +84,13 @@ def load_dataset(path, load_splits, src=None, dst=None):
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
IndexedInMemoryDataset(dst_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
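With the target side optional, a directory containing only the source-language binaries still yields a usable split. A minimal sketch, continuing the hypothetical en-de layout above (only test.en-de.en.idx/.bin on disk):

split = dataset.splits['test']
split.src           # IndexedInMemoryDataset over the source sentences
split.dst is None   # True: no target binaries were found, so target_dataset stayed None
len(split)          # number of source sentences (see __len__ below)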
@@ -194,21 +198,20 @@ class LanguagePairDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
target = self.dst[i].long() - 1
return {
'id': i,
'source': source,
'target': target,
}
res = { 'id': i, 'source': source }
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def collate(samples, pad_idx, eos_idx):
def collate(samples, pad_idx, eos_idx, has_target=True):
if len(samples) == 0:
return {}
@@ -220,6 +223,16 @@ class LanguagePairDataset(torch.utils.data.Dataset):
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
@@ -228,18 +241,13 @@ class LanguagePairDataset(torch.utils.data.Dataset):
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': sum(len(s['target']) for s in samples),
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
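Downstream code can collate source-only samples by passing has_target=False; the dataset's own collater does this automatically via self.dst is not None. A minimal sketch of the resulting mini-batch (samples, pad_idx and eos_idx as in collater above); the 'target' entry and the placement of prev_output_tokens in the returned dict are assumed from the surrounding code rather than shown in this hunk:

batch = LanguagePairDataset.collate(samples, pad_idx, eos_idx, has_target=False)
batch['net_input']['src_tokens']    # sources, padded and sorted by descending length
batch['net_input']['src_lengths']   # the matching descending lengths
batch['ntokens']                    # None: there are no target tokens to count
batch['target']                     # None, as is prev_output_tokens, so no teacher forcing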
@@ -301,21 +309,23 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
sample_len = 0
ignored = []
for idx in map(int, indices):
if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
src_size = src.sizes[idx]
dst_size = dst.sizes[idx] if dst else src_size
if not _valid_size(src_size, dst_size, max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
"Sample #{} has size (src={}, dst={}) but max size is {}."
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src.sizes[idx], dst.sizes[idx], max_positions))
).format(idx, src_size, dst_size, max_positions))
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
sample_len = max(sample_len, src_size, dst_size)
num_tokens = (len(batch) + 1) * sample_len
if yield_batch(idx, num_tokens):
yield batch
batch = []
sample_len = max(src.sizes[idx], dst.sizes[idx])
sample_len = max(src_size, dst_size)
batch.append(idx)
@@ -332,7 +342,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
descending=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
......
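Together with the relaxed assertion, batching a source-only dataset just passes dst=None; inside _make_batches the missing target size then falls back to the source size. A minimal sketch, continuing the source-only split from the earlier sketch and assuming the remaining keyword arguments keep their defaults:

for indices in batches_by_size(split.src, None, max_tokens=4000, max_sentences=64):
    pass  # each batch is a list of dataset indices, bounded by source length alone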
@@ -77,11 +77,11 @@ class SequenceGenerator(object):
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
)
if timer is not None:
timer.stop(s['ntokens'])
timer.stop(sum([len(h[0]['tokens']) for h in hypos]))
for i, id in enumerate(s['id'].data):
src = input['src_tokens'].data[i, :]
# remove padding from ref
ref = utils.strip_pad(s['target'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
......
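Callers iterating over batched generation results must now tolerate a missing reference. A hedged sketch using the names from generate.py below; the iterator method name generate_batched_itr and its cuda keyword are assumed by analogy with the score_batched_itr call there, and scorer.add(ref, hyp) is assumed to be the bleu.Scorer accumulation call:

for sample_id, src_tokens, ref_tokens, hypos in translator.generate_batched_itr(itr, cuda=use_cuda):
    best = hypos[0]                    # hypotheses as produced by generate()
    if ref_tokens is not None:         # score only when a reference side exists
        scorer.add(ref_tokens.int().cpu(), best['tokens'].int().cpu())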
@@ -84,6 +84,7 @@ def main(args):
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
if args.score_reference:
translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
@@ -94,17 +95,21 @@ def main(args):
wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
target_tokens = target_tokens.int().cpu()
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
target_str = dataset.dst_dict.string(target_tokens,
args.remove_bpe,
escape_unk=True) if has_target else ''
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
if has_target:
print('T-{}\t{}'.format(sample_id, target_str))
# Process top predictions
@@ -133,7 +138,7 @@ def main(args):
))
# Score only the top hypothesis
if i == 0:
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
@@ -146,6 +151,7 @@ def main(args):
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
......
@@ -21,9 +21,9 @@ def get_parser():
description='Data pre-processing: Create dictionary and store data in binary format')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
parser.add_argument('--trainpref', metavar='FP', default='train', help='train file prefix')
parser.add_argument('--validpref', metavar='FP', default='valid', help='comma separated, valid language prefixes')
parser.add_argument('--testpref', metavar='FP', default='test', help='comma separated, test language prefixes')
parser.add_argument('--trainpref', metavar='FP', default=None, help='train file prefix')
parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid language prefixes')
parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test language prefixes')
parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
@@ -37,12 +37,14 @@ def get_parser():
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
return parser
def main(args):
print(args)
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
@@ -60,14 +62,18 @@ def main(args):
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
else:
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
if target:
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
threshold=args.thresholdsrc, nwords=args.nwordssrc)
if target:
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
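Dictionary construction now requires --trainpref only when the corresponding --srcdict/--tgtdict is not supplied, so a source-only run can either build or reuse a vocabulary. A hedged sketch of the two configurations (paths are hypothetical):

# (a) build a new source dictionary from training text:
#     --only-source -s de -t en --trainpref corpus.train
#     -> Tokenizer.build_dictionary(filename='corpus.train.de')
# (b) reuse an existing dictionary; no training text is needed:
#     --only-source -s de -t en --srcdict data-bin/dict.de.txt --testpref mono.test
#     -> dictionary.Dictionary.load('data-bin/dict.de.txt')
# In both cases the target branch is skipped, so no target dictionary is built or saved.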
@@ -100,19 +106,26 @@ def main(args):
output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
def make_all(args, make_dataset, lang):
if args.trainpref:
make_dataset(args.trainpref, 'train', lang, args.output_format)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, args.source_lang, args.output_format)
make_dataset(validpref, outprefix, args.target_lang, args.output_format)
make_dataset(validpref, outprefix, lang, args.output_format)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, args.source_lang, args.output_format)
make_dataset(testpref, outprefix, args.target_lang, args.output_format)
make_dataset(testpref, outprefix, lang, args.output_format)
make_all(args, make_dataset, args.source_lang)
if target:
make_all(args, make_dataset, args.target_lang)
print('| Wrote preprocessed data to {}'.format(args.destdir))
if args.alignfile:
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
......
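Putting the preprocessing changes together: a hedged sketch of the files written by a source-only run such as preprocess.py -s de -t en --only-source --testpref mono.test --srcdict data-bin/dict.de.txt, with the default binary output format and default --destdir; the exact output filenames are assumed to follow the {split}.{src}-{dst}.{lang} scheme checked by load_dataset above:

# data-bin/dict.de.txt          # source dictionary (saved even when --srcdict is reused)
# data-bin/test.de-en.de.idx    # binarized source side of the test set
# data-bin/test.de-en.de.bin
#
# No target-side files or target dictionary are produced, which is exactly the
# layout the updated load_dataset()/all_splits_exist() and generate.py accept.
# Generation may still need a target dictionary (e.g. dict.en.txt copied from
# the original data directory) to map output ids back to words.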