Commit 42a0150c authored by Louis Martin, committed by Myle Ott

Replace unk with original string

* Add <eos> for unk replacement
* Add IndexedRawTextDataset to load raw text files
* Replace unk with original string
* Add load_raw_text_dataset() and --output-format
* Move has_binary_files to data.py
parent 7d44181d
@@ -8,6 +8,7 @@
import contextlib
import itertools
import glob
import numbers
import numpy as np
import os
@@ -15,7 +16,14 @@ import torch
import torch.utils.data
from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
def has_binary_files(data_dir, splits):
for split in splits:
if len(glob.glob(os.path.join(data_dir, f'{split}.*-*.*.bin'))) < 2:
return False
return True
def infer_language_pair(path, splits):
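A minimal usage sketch for the new has_binary_files() helper, mirroring how train.py uses it later in this commit; the data directory path below is hypothetical:

    from fairseq import data

    splits = ['train', 'valid']
    data_dir = 'data-bin/iwslt14.tokenized.de-en'  # hypothetical path
    if data.has_binary_files(data_dir, splits):
        # binarized .bin/.idx files are present
        dataset = data.load_dataset(data_dir, splits)
    else:
        # fall back to loading the raw text files
        dataset = data.load_raw_text_dataset(data_dir, splits)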
@@ -43,7 +51,12 @@ def load_dataset(path, load_splits, src=None, dst=None):
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from binary files
def all_splits_exist(src, dst):
for split in load_splits:
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
@@ -59,9 +72,6 @@ def load_dataset(path, load_splits, src=None, dst=None):
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args))
@@ -84,6 +94,30 @@ def load_dataset(path, load_splits, src=None, dst=None):
return dataset
def load_raw_text_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from raw text files
for split in load_splits:
src_path = os.path.join(path, f'{split}.{src}')
dst_path = os.path.join(path, f'{split}.{dst}')
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
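A short sketch of calling load_raw_text_dataset() directly; the folder name is hypothetical and is assumed to contain dict.de.txt / dict.en.txt plus raw test.de / test.en files matching the '{split}.{lang}' pattern used above:

    from fairseq import data

    dataset = data.load_raw_text_dataset('iwslt14-raw', ['test'], src='de', dst='en')
    pair = dataset.splits['test']          # a LanguagePairDataset backed by raw text
    print(pair.src.get_original_text(0))   # original (untokenized) source line
    print(pair.dst.get_original_text(0))   # original target line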
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
......
@@ -11,6 +11,8 @@ import os
import struct
import torch
from fairseq.tokenizer import Tokenizer
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
@@ -59,12 +61,15 @@ class IndexedDataset(object):
def read_data(self, path):
self.data_file = open(path + '.bin', 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
def __del__(self):
self.data_file.close()
def __getitem__(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
@@ -92,14 +97,49 @@ class IndexedInMemoryDataset(IndexedDataset):
pass
def __getitem__(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a)
class IndexedRawTextDataset(IndexedDataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, path, dictionary):
self.tokens_list = []
self.lines = []
self.sizes = []
self.read_data(path, dictionary)
self.size = len(self.tokens_list)
def read_data(self, path, dictionary):
with open(path, 'r') as f:
for line in f:
self.lines.append(line.strip('\n'))
# +1 for Lua compatibility
tokens = Tokenizer.tokenize(line, dictionary, add_if_not_exist=False) + 1
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
def get_original_text(self, i):
self.check_index(i)
return self.lines[i]
def __del__(self):
pass
def __len__(self):
return self.size
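A small standalone sketch of the new IndexedRawTextDataset; the dictionary and text file paths are hypothetical:

    from fairseq import dictionary
    from fairseq.indexed_dataset import IndexedRawTextDataset

    src_dict = dictionary.Dictionary.load('data-bin/dict.en.txt')  # hypothetical path
    ds = IndexedRawTextDataset('raw/test.en', src_dict)
    print(len(ds))                  # number of lines in the file
    print(ds[0])                    # token indices for line 0 (note the +1 Lua offset)
    print(ds.get_original_text(0))  # the raw line kept in memory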
class IndexedDatasetBuilder(object):
element_sizes = {
......
@@ -192,7 +192,8 @@ def load_align_dict(replace_unk):
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
# Tokens are strings here
hypo_tokens = tokenizer.tokenize_line(hypo_str)
src_tokens = tokenizer.tokenize_line(src_str)
# TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
for i, ht in enumerate(hypo_tokens):
if ht == unk:
src_token = src_tokens[alignment[i]]
......
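A hypothetical worked example of the '<eos>' change above: an alignment that points one position past the last source word now resolves to '<eos>' instead of indexing past the end of src_tokens (the TODO notes that this rare case still deserves nicer handling).

    hypo_str  = 'the <unk> sat <unk>'
    src_str   = 'le chat assis'
    alignment = [0, 1, 2, 3]   # hypothesis position -> source position
    # src_tokens becomes ['le', 'chat', 'assis', '<eos>'], so the second <unk>
    # (aligned to position 3) maps to '<eos>' rather than raising IndexError;
    # the first <unk> (aligned to position 1) is replaced using 'chat', or its
    # align_dict entry when the alignment dictionary provides one.
    replace_unk(hypo_str, src_str, alignment, align_dict={}, unk='<unk>')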
@@ -34,7 +34,10 @@ def main():
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
if args.replace_unk is None:
dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
else:
dataset = data.load_raw_text_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -80,8 +83,14 @@ def main():
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
target_tokens = target_tokens.int().cpu()
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
print('T-{}\t{}'.format(sample_id, target_str))
@@ -102,8 +111,8 @@ def main():
# Score only the top hypothesis
if i == 0:
if args.remove_bpe is not None:
# Convert the string without BPE back to tokens for evaluation
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(target_str,
dataset.dst_dict,
add_if_not_exist=True)
......
@@ -8,8 +8,9 @@
#
import argparse
import os
from itertools import zip_longest
import os
import shutil
from fairseq import dictionary, indexed_dataset
from fairseq.tokenizer import Tokenizer
@@ -33,10 +34,11 @@ def main():
parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
help='output format (optional)')
args = parser.parse_args()
print(args)
os.makedirs(args.destdir, exist_ok=True)
if args.srcdict:
@@ -53,7 +55,7 @@ def main():
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
def make_dataset(input_prefix, output_prefix, lang):
def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
@@ -74,16 +76,24 @@ def main():
args.destdir, output_prefix,
args.source_lang, args.target_lang, lang))
make_dataset(args.trainpref, 'train', args.source_lang)
make_dataset(args.trainpref, 'train', args.target_lang)
def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
if output_format == 'binary':
make_binary_dataset(input_prefix, output_prefix, lang)
elif output_format == 'raw':
# Copy original text file to destination folder
output_text_file = os.path.join(args.destdir, f'{output_prefix}.{lang}')
shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, args.source_lang)
make_dataset(validpref, outprefix, args.target_lang)
make_dataset(validpref, outprefix, args.source_lang, args.output_format)
make_dataset(validpref, outprefix, args.target_lang, args.output_format)
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, args.source_lang)
make_dataset(testpref, outprefix, args.target_lang)
make_dataset(testpref, outprefix, args.source_lang, args.output_format)
make_dataset(testpref, outprefix, args.target_lang, args.output_format)
print('| Wrote preprocessed data to {}'.format(args.destdir))
if args.alignfile:
......
@@ -46,7 +46,11 @@ def main():
torch.manual_seed(args.seed)
# Load dataset
dataset = data.load_dataset(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
splits = ['train', 'valid']
if data.has_binary_files(args.data, splits):
dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
else:
dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -54,7 +58,7 @@ def main():
print(args)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in ['train', 'valid']:
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
if not torch.cuda.is_available():
......