Commit 862cad11 authored by Sergey Edunov, committed by Myle Ott

Parallel preprocessing

parent ee46c63b
fairseq/data/indexed_dataset.py

@@ -52,9 +52,15 @@ def data_file_path(prefix_path):
 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""
 
-    def __init__(self, path, fix_lua_indexing=False):
+    def __init__(self, path, fix_lua_indexing=False, read_data=True):
         super().__init__()
         self.fix_lua_indexing = fix_lua_indexing
+        self.read_index(path)
+        self.data_file = None
+        if read_data:
+            self.read_data(path)
+
+    def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
             assert magic == b'TNTIDX\x00\x00'
@@ -66,7 +72,6 @@ class IndexedDataset(torch.utils.data.Dataset):
             self.dim_offsets = read_longs(f, self.size + 1)
             self.data_offsets = read_longs(f, self.size + 1)
             self.sizes = read_longs(f, self.s)
-        self.read_data(path)
 
     def read_data(self, path):
         self.data_file = open(data_file_path(path), 'rb', buffering=0)
@@ -76,7 +81,8 @@ class IndexedDataset(torch.utils.data.Dataset):
             raise IndexError('index out of range')
 
     def __del__(self):
-        self.data_file.close()
+        if self.data_file:
+            self.data_file.close()
 
     def __getitem__(self, i):
         self.check_index(i)
@@ -193,6 +199,26 @@ class IndexedDatasetBuilder(object):
         self.sizes.append(s)
         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
 
+    def merge_file_(self, another_file):
+        index = IndexedDataset(another_file, read_data=False)
+        assert index.dtype == self.dtype
+
+        begin = self.data_offsets[-1]
+        for offset in index.data_offsets[1:]:
+            self.data_offsets.append(begin + offset)
+        self.sizes.extend(index.sizes)
+        begin = self.dim_offsets[-1]
+        for dim_offset in index.dim_offsets[1:]:
+            self.dim_offsets.append(begin + dim_offset)
+
+        with open(data_file_path(another_file), 'rb') as f:
+            while True:
+                data = f.read(1024)
+                if data:
+                    self.out_file.write(data)
+                else:
+                    break
+
     def finalize(self, index_file):
         self.out_file.close()
         index = open(index_file, 'wb')
......
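
The new merge_file_ hook is what lets each worker write its chunk into a temporary .bin/.idx pair that the main builder later appends. Below is a minimal sketch of that flow, assuming this fairseq revision is importable; the 'shard1' and 'merged' prefixes are invented for the example.

import torch
from fairseq.data import indexed_dataset

# Stand-in for the temporary dataset a worker process would write.
shard = indexed_dataset.IndexedDatasetBuilder('shard1.bin')
shard.add_item(torch.IntTensor([7, 8, 2]))
shard.finalize('shard1.idx')

# Main builder: write its own items, then fold the worker shard in and
# finalize a single combined index.
ds = indexed_dataset.IndexedDatasetBuilder('merged.bin')
ds.add_item(torch.IntTensor([4, 5, 6, 2]))
ds.merge_file_('shard1')   # shifts shard1's offsets/sizes and copies shard1.bin into merged.bin
ds.finalize('merged.idx')

# Read the merged prefix back like any other indexed dataset
# (fix_lua_indexing undoes the +1 the builder applies for Lua compatibility).
data = indexed_dataset.IndexedDataset('merged', fix_lua_indexing=True)
print(len(data), data[0], data[1])

This is the same sequence make_binary_dataset in preprocess.py runs on each worker's temp_file_path before deleting the temporary .bin and .idx files.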
fairseq/tokenizer.py

@@ -6,10 +6,10 @@
 # can be found in the PATENTS file in the same directory.
 
 from collections import Counter
-import re
+import os, re
 
 import torch
+from multiprocessing import Pool
 
 SPACE_NORMALIZER = re.compile("\s+")
@@ -20,28 +20,74 @@ def tokenize_line(line):
     return line.split()
 
 
+def safe_readline(f):
+    pos = f.tell()
+    while True:
+        try:
+            return f.readline()
+        except UnicodeDecodeError:
+            pos -= 1
+            f.seek(pos)  # search where this character begins
+
+
 class Tokenizer:
 
     @staticmethod
-    def add_file_to_dictionary(filename, dict, tokenize):
+    def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
+        counter = Counter()
         with open(filename, 'r') as f:
-            for line in f:
+            size = os.fstat(f.fileno()).st_size
+            chunk_size = size // num_workers
+            offset = worker_id * chunk_size
+            end = offset + chunk_size
+            f.seek(offset)
+            if offset > 0:
+                safe_readline(f)  # drop first incomplete line
+            line = f.readline()
+            while line:
                 for word in tokenize(line):
-                    dict.add_symbol(word)
-                dict.add_symbol(dict.eos_word)
+                    counter.update([word])
+                counter.update([eos_word])
+                if f.tell() > end:
+                    break
+                line = f.readline()
+        return counter
+
+    @staticmethod
+    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
+        def merge_result(counter):
+            for w, c in counter.items():
+                dict.add_symbol(w, c)
+
+        if num_workers > 1:
+            pool = Pool(processes=num_workers)
+            results = []
+            for worker_id in range(num_workers):
+                results.append(pool.apply_async(
+                    Tokenizer.add_file_to_dictionary_single_worker,
+                    (filename, tokenize, dict.eos_word, worker_id, num_workers)
+                ))
+            pool.close()
+            pool.join()
+            for r in results:
+                merge_result(r.get())
+        else:
+            merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
 
     @staticmethod
     def binarize(filename, dict, consumer, tokenize=tokenize_line,
-                 append_eos=True, reverse_order=False):
+                 append_eos=True, reverse_order=False,
+                 offset=0, end=-1):
         nseq, ntok = 0, 0
         replaced = Counter()
 
         def replaced_consumer(word, idx):
             if idx == dict.unk_index and word != dict.unk_word:
                 replaced.update([word])
 
         with open(filename, 'r') as f:
-            for line in f:
+            f.seek(offset)
+            # next(f) breaks f.tell(), hence readline() must be used
+            line = safe_readline(f)
+            while line:
+                if end > 0 and f.tell() > end:
+                    break
                 ids = Tokenizer.tokenize(
                     line=line,
                     dict=dict,
@@ -52,10 +98,22 @@ class Tokenizer:
                     reverse_order=reverse_order,
                 )
                 nseq += 1
-                consumer(ids)
                 ntok += len(ids)
-        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
+                consumer(ids)
+                line = f.readline()
+        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
+
+    @staticmethod
+    def find_offsets(filename, num_chunks):
+        with open(filename, 'r') as f:
+            size = os.fstat(f.fileno()).st_size
+            chunk_size = size // num_chunks
+            offsets = [0 for _ in range(num_chunks + 1)]
+            for i in range(1, num_chunks):
+                f.seek(chunk_size * i)
+                safe_readline(f)
+                offsets[i] = f.tell()
+            return offsets
 
     @staticmethod
     def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
......
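
find_offsets and safe_readline carry the core trick behind both the parallel dictionary build and the parallel binarization: cut the file into roughly equal byte ranges, then snap every cut forward to the next line boundary so each line lands in exactly one chunk. The sketch below is a standalone illustration of that scheme; it uses a plain readline() where the real code uses safe_readline, which only matters when a seek lands inside a multi-byte UTF-8 character.

import os
import tempfile

def find_offsets(filename, num_chunks):
    # Mirror of Tokenizer.find_offsets: seek to the raw byte boundary, skip the
    # partial line, and record where the next full line starts. offsets[num_chunks]
    # stays 0; a non-positive end means "read to end of file" below.
    with open(filename, 'r') as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_chunks
        offsets = [0] * (num_chunks + 1)
        for i in range(1, num_chunks):
            f.seek(chunk_size * i)
            f.readline()
            offsets[i] = f.tell()
        return offsets

def read_chunk(filename, offset, end):
    # Mirror of the seek/readline walk in Tokenizer.binarize.
    lines = []
    with open(filename, 'r') as f:
        f.seek(offset)
        line = f.readline()
        while line:
            if end > 0 and f.tell() > end:
                break
            lines.append(line)
            line = f.readline()
    return lines

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write(''.join('sentence number {}\n'.format(i) for i in range(1000)))

num_chunks = 4
offsets = find_offsets(tmp.name, num_chunks)
chunks = [read_chunk(tmp.name, offsets[i], offsets[i + 1]) for i in range(num_chunks)]
print([len(c) for c in chunks])
assert sum(len(c) for c in chunks) == 1000   # every line in exactly one chunk
os.remove(tmp.name)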
preprocess.py

@@ -10,12 +10,16 @@ Data pre-processing: build vocabularies and binarize training data.
 """
 
 import argparse
+from collections import Counter
 from itertools import zip_longest
 import os
 import shutil
 
 from fairseq.data import indexed_dataset, dictionary
 from fairseq.tokenizer import Tokenizer, tokenize_line
+from multiprocessing import Pool, Manager, Process
 
 
 def get_parser():
@@ -41,6 +45,7 @@ def get_parser():
     parser.add_argument('--only-source', action='store_true', help='Only process the source language')
     parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
                         help='Pad dictionary size to be multiple of N')
+    parser.add_argument('--workers', metavar='N', default=1, type=int, help='number of parallel workers')
     return parser
@@ -52,7 +57,7 @@ def main(args):
     def build_dictionary(filenames):
         d = dictionary.Dictionary()
         for filename in filenames:
-            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
+            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, args.workers)
         return d
 
     def train_path(lang):
@@ -70,11 +75,6 @@ def main(args):
     def dict_path(lang):
         return dest_path('dict', lang) + '.txt'
 
-    def dataset_dest_path(output_prefix, lang, extension):
-        base = f'{args.destdir}/{output_prefix}'
-        lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
-        return f'{base}{lang_part}.{extension}'
-
     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
         assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
@@ -111,25 +111,54 @@ def main(args):
         )
         tgt_dict.save(dict_path(args.target_lang))
 
-    def make_binary_dataset(input_prefix, output_prefix, lang):
+    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
         dict = dictionary.Dictionary.load(dict_path(lang))
         print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
+        n_seq_tok = [0, 0]
+        replaced = Counter()
 
-        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))
-
-        def consumer(tensor):
-            ds.add_item(tensor)
+        def merge_result(worker_result):
+            replaced.update(worker_result['replaced'])
+            n_seq_tok[0] += worker_result['nseq']
+            n_seq_tok[1] += worker_result['ntok']
 
         input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
-        res = Tokenizer.binarize(input_file, dict, consumer)
+        offsets = Tokenizer.find_offsets(input_file, num_workers)
+        pool = None
+        if num_workers > 1:
+            pool = Pool(processes=num_workers-1)
+            for worker_id in range(1, num_workers):
+                prefix = "{}{}".format(output_prefix, worker_id)
+                pool.apply_async(binarize, (args, input_file, dict, prefix, lang,
+                                            offsets[worker_id],
+                                            offsets[worker_id + 1]), callback=merge_result)
+            pool.close()
+
+        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
+        merge_result(Tokenizer.binarize(input_file, dict, lambda t: ds.add_item(t),
+                                        offset=0, end=offsets[1]))
+        if num_workers > 1:
+            pool.join()
+            for worker_id in range(1, num_workers):
+                prefix = "{}{}".format(output_prefix, worker_id)
+                temp_file_path = dataset_dest_prefix(args, prefix, lang)
+                ds.merge_file_(temp_file_path)
+                os.remove(indexed_dataset.data_file_path(temp_file_path))
+                os.remove(indexed_dataset.index_file_path(temp_file_path))
+
+        ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
+
         print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
-            lang, input_file, res['nseq'], res['ntok'],
-            100 * res['nunk'] / res['ntok'], dict.unk_word))
-        ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
+            lang, input_file, n_seq_tok[0], n_seq_tok[1],
+            100 * sum(replaced.values()) / n_seq_tok[1], dict.unk_word))
 
-    def make_dataset(input_prefix, output_prefix, lang):
+    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
         if args.output_format == 'binary':
-            make_binary_dataset(input_prefix, output_prefix, lang)
+            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
         elif args.output_format == 'raw':
             # Copy original text file to destination folder
             output_text_file = dest_path(
@@ -140,7 +169,7 @@ def main(args):
     def make_all(lang):
         if args.trainpref:
-            make_dataset(args.trainpref, 'train', lang)
+            make_dataset(args.trainpref, 'train', lang, num_workers=args.workers)
         if args.validpref:
             for k, validpref in enumerate(args.validpref.split(',')):
                 outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
@@ -196,6 +225,28 @@ def main(args):
                 print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)
 
 
+def binarize(args, filename, dict, output_prefix, lang, offset, end):
+    ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
+
+    def consumer(tensor):
+        ds.add_item(tensor)
+
+    res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end)
+    ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
+    return res
+
+
+def dataset_dest_prefix(args, output_prefix, lang):
+    base = f'{args.destdir}/{output_prefix}'
+    lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
+    return f'{base}{lang_part}'
+
+
+def dataset_dest_file(args, output_prefix, lang, extension):
+    base = dataset_dest_prefix(args, output_prefix, lang)
+    return f'{base}.{extension}'
+
+
 if __name__ == '__main__':
     parser = get_parser()
     args = parser.parse_args()
......
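
make_binary_dataset dispatches chunks 1..N-1 to a Pool of num_workers - 1 processes via apply_async, with merge_result as the callback that folds each worker's statistics back in, and binarizes chunk 0 itself in the parent between pool.close() and pool.join(). The sketch below reproduces that dispatch pattern in a self-contained way; count_tokens is a hypothetical worker standing in for the real binarize call, and the temporary corpus is generated on the fly.

from collections import Counter
from multiprocessing import Pool
import os
import tempfile

def count_tokens(filename, offset, end):
    # Worker body: the same chunked seek/readline walk as Tokenizer.binarize,
    # except it only counts whitespace tokens instead of writing a dataset shard.
    counter = Counter()
    with open(filename, 'r') as f:
        f.seek(offset)
        line = f.readline()
        while line:
            if end > 0 and f.tell() > end:
                break
            counter.update(line.split())
            line = f.readline()
    return counter

def find_offsets(filename, num_chunks):
    # Same line-aligned chunking as Tokenizer.find_offsets.
    with open(filename, 'r') as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_chunks
        offsets = [0] * (num_chunks + 1)
        for i in range(1, num_chunks):
            f.seek(chunk_size * i)
            f.readline()
            offsets[i] = f.tell()
        return offsets

if __name__ == '__main__':
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
        tmp.write(''.join('this is example sentence number {}\n'.format(i) for i in range(5000)))

    num_workers = 4
    offsets = find_offsets(tmp.name, num_workers)
    total = Counter()

    def merge_result(counter):
        # Runs in the parent (the Pool's result-handler thread), like
        # merge_result inside make_binary_dataset.
        total.update(counter)

    pool = Pool(processes=num_workers - 1)
    for worker_id in range(1, num_workers):
        pool.apply_async(count_tokens,
                         (tmp.name, offsets[worker_id], offsets[worker_id + 1]),
                         callback=merge_result)
    pool.close()                                         # no more tasks
    merge_result(count_tokens(tmp.name, 0, offsets[1]))  # chunk 0 in the parent
    pool.join()                                          # all callbacks have fired once this returns
    print(sum(total.values()), total.most_common(3))
    os.remove(tmp.name)

The function handed to apply_async has to be picklable, which is presumably why binarize, dataset_dest_prefix and dataset_dest_file are added at module level in preprocess.py rather than nested inside main() like the dataset_dest_path helper they replace.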