"vscode:/vscode.git/clone" did not exist on "ef67fd926200d86e7d98cc151fd4728a9869ba19"
Commit f296824f authored by Vladimir Karpukhin, committed by Facebook Github Bot

Move string line encoding logic from tokenizer to Dictionary (unified diff). (#541)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/541

Just a combination of the stacked pair D14057943 & D14176011.
Made this a separate diff because there seems to be some issue with porting a stacked change into the GitHub repo.

Differential Revision: D14251048

fbshipit-source-id: 0a47f534a69d6ab2ebe035fba40fd51748cccfb8
parent bc919276
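
For context, a minimal before/after sketch of the API move in this diff (the vocabulary contents and the example line are made up; the calls follow the code shown below):

# Before this change, line encoding lived on the Tokenizer class:
#   from fairseq.tokenizer import Tokenizer
#   ids = Tokenizer.tokenize(line, vocab, add_if_not_exist=False)
#
# After this change, the Dictionary owns it:
from fairseq.data import Dictionary

vocab = Dictionary()
vocab.add_symbol('hello')
vocab.add_symbol('world')
ids = vocab.encode_line('hello world', add_if_not_exist=False)
# ids is a torch.IntTensor; an EOS index is appended by default (append_eos=True)
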
@@ -209,7 +209,6 @@ following contents::
from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tokenizer import Tokenizer
@register_task('simple_classification')
@@ -253,8 +252,8 @@ following contents::
sentence = line.strip()
# Tokenize the sentence, splitting on spaces
tokens = Tokenizer.tokenize(
sentence, self.input_vocab, add_if_not_exist=False,
tokens = self.input_vocab.encode_line(
sentence, add_if_not_exist=False,
)
sentences.append(tokens)
@@ -356,7 +355,6 @@ Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer
# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
@@ -375,8 +373,8 @@ a new file named :file:`eval_classifier.py` with the following contents::
# Tokenize into characters
chars = ' '.join(list(sentence.strip()))
tokens = Tokenizer.tokenize(
chars, task.source_dictionary, add_if_not_exist=False,
tokens = task.source_dictionary.encode_line(
chars, add_if_not_exist=False,
)
# Build mini-batch to feed to the model
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os
from fairseq.tokenizer import tokenize_line
def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
class Binarizer:
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False,
offset=0, end=-1):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = dict.encode_line(
line=line,
line_tokenizer=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
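
A hedged usage sketch of the two static methods above, running the chunks sequentially in a single process (the paths are placeholders; the preprocessing script later in this diff fans the same calls out over a multiprocessing Pool):

from fairseq.binarizer import Binarizer
from fairseq.data import Dictionary

vocab = Dictionary.load('data-bin/dict.txt')       # placeholder path
offsets = Binarizer.find_offsets('corpus.txt', 4)  # chunk starts, snapped to line boundaries

tensors = []
for i in range(4):
    stats = Binarizer.binarize(
        'corpus.txt', vocab, lambda t: tensors.append(t),
        offset=offsets[i], end=offsets[i + 1],      # an end of 0 means "read to EOF"
    )
    print(stats['nseq'], stats['ntok'], stats['nunk'])
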
@@ -213,3 +213,11 @@ def batch_by_size(
if len(batch) > 0:
yield batch
def process_bpe_symbol(sentence: str, bpe_symbol: str):
if bpe_symbol == 'sentencepiece':
sentence = sentence.replace('\u2581', ' ').strip()
elif bpe_symbol is not None:
sentence = (sentence + ' ').replace(bpe_symbol, '').rstrip()
return sentence
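
A quick illustration of the helper above with the common subword-nmt '@@ ' continuation marker (the input strings are made up):

from fairseq.data import data_utils

data_utils.process_bpe_symbol('hel@@ lo wor@@ ld', '@@ ')   # -> 'hello world'
data_utils.process_bpe_symbol('hello world', None)          # -> returned unchanged
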
@@ -6,10 +6,15 @@
# can be found in the PATENTS file in the same directory.
from collections import Counter
from multiprocessing import Pool
import os
import torch
from fairseq.tokenizer import tokenize_line
from fairseq.binarizer import safe_readline
from fairseq.data import data_utils
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
@@ -57,14 +62,8 @@ class Dictionary(object):
else:
return self[i]
if bpe_symbol == 'sentencepiece':
sent = ''.join(token_string(i) for i in tensor if i != self.eos())
sent = sent.replace('\u2581', ' ').strip()
else:
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
sent = ''.join(token_string(i) for i in tensor if i != self.eos())
return data_utils.process_bpe_symbol(sent, bpe_symbol)
def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
@@ -181,31 +180,104 @@ class Dictionary(object):
"rebuild the dataset".format(f))
d = cls()
for line in f.readlines():
lines = f.readlines()
indices_start_line = d._load_meta(lines)
for line in lines[indices_start_line:]:
idx = line.rfind(' ')
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
count = int(line[idx+1:])
count = int(line[idx + 1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
def save(self, f):
"""Stores dictionary into a text file"""
def _save(self, f, kv_iterator):
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
for k, v in kv_iterator:
print('{} {}'.format(k, v), file=f)
def _get_meta(self):
return [], []
def _load_meta(self, lines):
return 0
def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))
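
The _get_meta/_load_meta hooks above are no-ops in the base class; a subclass can override them to persist extra header lines in the saved dictionary file. A hypothetical sketch (the MyDictionary name and its 'version' header are invented for illustration):

from fairseq.data import Dictionary

class MyDictionary(Dictionary):
    def _get_meta(self):
        # one extra "key value" line written before the symbol/count pairs
        return ['version'], [1]

    def _load_meta(self, lines):
        # skip that header line when loading
        assert lines[0].startswith('version')
        return 1
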
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
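
A small sketch exercising the optional arguments of encode_line above (the strings are made up):

from fairseq.data import Dictionary

d = Dictionary()
# add_if_not_exist=True grows the vocabulary as a side effect
d.encode_line('a b c', add_if_not_exist=True)

# the consumer callback sees every (word, index) pair, e.g. to tally unknowns
unk_words = []
ids = d.encode_line(
    'a b unseen', add_if_not_exist=False,
    consumer=lambda w, i: unk_words.append(w) if i == d.unk_index else None,
)
# 'unseen' maps to d.unk_index; with append_eos=True (the default) ids[-1] == d.eos_index
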
@staticmethod
def _add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Dictionary._add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
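
A minimal sketch of building a vocabulary with the multiprocessing helper above (the corpus and output paths are placeholders; FairseqTask.build_dictionary further down in this diff does essentially the same):

from fairseq.data import Dictionary
from fairseq.tokenizer import tokenize_line

d = Dictionary()
Dictionary.add_file_to_dictionary('corpus.txt', d, tokenize_line, num_workers=4)
d.finalize()   # sorts/prunes the vocab; threshold, nwords, padding_factor kwargs as used elsewhere in this diff
d.save('data-bin/dict.txt')
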
class TruncatedDictionary(object):
@@ -11,8 +11,6 @@ import struct
import numpy as np
import torch
from fairseq.tokenizer import Tokenizer
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
@@ -171,8 +169,8 @@ class IndexedRawTextDataset(torch.utils.data.Dataset):
with open(path, 'r', encoding='utf-8') as f:
for line in f:
self.lines.append(line.strip('\n'))
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
tokens = dictionary.encode_line(
line, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
@@ -9,7 +9,6 @@ import torch
from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
from fairseq.tokenizer import Tokenizer
class FairseqTask(object):
@@ -52,7 +51,7 @@ class FairseqTask(object):
"""
d = Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@@ -5,13 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
from multiprocessing import Pool
import os
import re
import torch
SPACE_NORMALIZER = re.compile(r"\s+")
@@ -19,124 +14,3 @@ def tokenize_line(line):
line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip()
return line.split()
def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
class Tokenizer:
@staticmethod
def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Tokenizer.add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
@staticmethod
def binarize(
filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
reverse_order=False, offset=0, end=-1,
):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = Tokenizer.tokenize(
line=line,
dict=dict,
tokenize=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = tokenize(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = dict.add_symbol(word)
else:
idx = dict.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = dict.eos_index
return ids
@@ -304,7 +304,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic
if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment
@@ -165,8 +165,7 @@ def main(args):
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, tgt_dict, add_if_not_exist=True)
target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str)
else:
@@ -38,7 +38,7 @@ def buffered_read(input, buffer_size):
def make_batches(lines, args, task, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
for src_str in lines
]
lengths = torch.LongTensor([t.numel() for t in tokens])
@@ -11,15 +11,15 @@ Data pre-processing: build vocabularies and binarize training data.
from collections import Counter
from itertools import zip_longest
import os
import shutil
from fairseq import options, tasks
from fairseq.data import indexed_dataset
from fairseq.tokenizer import Tokenizer
from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
from multiprocessing import Pool
from fairseq.utils import import_user_module
import os
import shutil
def main(args):
@@ -95,9 +95,8 @@ def main(args):
if target and tgt_dict is not None:
tgt_dict.save(dict_path(args.target_lang))
def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
dict = task.load_dictionary(dict_path(lang))
print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
n_seq_tok = [0, 0]
replaced = Counter()
@@ -109,7 +108,7 @@ def main(args):
input_file = "{}{}".format(
input_prefix, ("." + lang) if lang is not None else ""
)
offsets = Tokenizer.find_offsets(input_file, num_workers)
offsets = Binarizer.find_offsets(input_file, num_workers)
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers - 1)
@@ -120,13 +119,13 @@ def main(args):
(
args,
input_file,
dict,
vocab,
prefix,
lang,
offsets[worker_id],
offsets[worker_id + 1],
offsets[worker_id + 1]
),
callback=merge_result,
callback=merge_result
)
pool.close()
@@ -134,8 +133,9 @@ def main(args):
dataset_dest_file(args, output_prefix, lang, "bin")
)
merge_result(
Tokenizer.binarize(
input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
Binarizer.binarize(
input_file, vocab, lambda t: ds.add_item(t),
offset=0, end=offsets[1]
)
)
if num_workers > 1:
@@ -156,13 +156,13 @@ def main(args):
n_seq_tok[0],
n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1],
dict.unk_word,
vocab.unk_word,
)
)
def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.output_format == "binary":
make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
elif args.output_format == "raw":
# Copy original text file to destination folder
output_text_file = dest_path(
@@ -171,21 +171,21 @@ def main(args):
)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(lang):
def make_all(lang, vocab):
if args.trainpref:
make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(",")):
outprefix = "valid{}".format(k) if k > 0 else "valid"
make_dataset(validpref, outprefix, lang)
make_dataset(vocab, validpref, outprefix, lang)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(",")):
outprefix = "test{}".format(k) if k > 0 else "test"
make_dataset(testpref, outprefix, lang)
make_dataset(vocab, testpref, outprefix, lang)
make_all(args.source_lang)
make_all(args.source_lang, src_dict)
if target:
make_all(args.target_lang)
make_all(args.target_lang, tgt_dict)
print("| Wrote preprocessed data to {}".format(args.destdir))
@@ -198,8 +198,8 @@ def main(args):
with open(src_file_name, "r", encoding='utf-8') as src_file:
with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
si = src_dict.encode_line(s, add_if_not_exist=False)
ti = tgt_dict.encode_line(t, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split("-")), a.split()))
for sai, tai in ai:
srcidx = si[int(sai)]
@@ -232,7 +232,7 @@ def main(args):
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=True):
def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
)
@@ -240,14 +240,8 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=
def consumer(tensor):
ds.add_item(tensor)
res = Tokenizer.binarize(
filename,
dict,
consumer,
offset=offset,
end=end,
append_eos=append_eos
)
res = Binarizer.binarize(filename, vocab, consumer, append_eos=append_eos,
offset=offset, end=end)
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
return res
@@ -266,7 +260,7 @@ def dataset_dest_file(args, output_prefix, lang, extension):
def get_offsets(input_file, num_workers):
return Tokenizer.find_offsets(input_file, num_workers)
return Binarizer.find_offsets(input_file, num_workers)
def merge_files(files, outpath):
@@ -13,7 +13,7 @@ import argparse
import os
import sys
from fairseq import bleu, tokenizer
from fairseq import bleu
from fairseq.data import dictionary
@@ -62,8 +62,8 @@ def main():
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
sys_tok = dict.encode_line(sys_tok)
ref_tok = dict.encode_line(ref_tok)
scorer.add(ref_tok, sys_tok)
print(scorer.result_string(args.order))
@@ -11,7 +11,6 @@ import unittest
import torch
from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer
class TestDictionary(unittest.TestCase):
@@ -39,12 +38,12 @@ class TestDictionary(unittest.TestCase):
# build dictionary
d = Dictionary()
for line in txt:
Tokenizer.tokenize(line, d, add_if_not_exist=True)
d.encode_line(line, add_if_not_exist=True)
def get_ids(dictionary):
ids = []
for line in txt:
ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False))
ids.append(dictionary.encode_line(line, add_if_not_exist=False))
return ids
def assertMatch(ids, ref_ids):