Commit f296824f authored by Vladimir Karpukhin, committed by Facebook Github Bot

Move string line encoding logic from tokenizer to Dictionary (unified diff). (#541)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/541

Just a combination of the stacked pair D14057943 & D14176011.
Made this as a separate diff because there seems to be an issue with porting a stacked change into the GitHub repo.

Differential Revision: D14251048

fbshipit-source-id: 0a47f534a69d6ab2ebe035fba40fd51748cccfb8
parent bc919276
...@@ -209,7 +209,6 @@ following contents:: ...@@ -209,7 +209,6 @@ following contents::
from fairseq.data import Dictionary, LanguagePairDataset from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import FairseqTask, register_task from fairseq.tasks import FairseqTask, register_task
from fairseq.tokenizer import Tokenizer
@register_task('simple_classification') @register_task('simple_classification')
...@@ -253,8 +252,8 @@ following contents:: ...@@ -253,8 +252,8 @@ following contents::
sentence = line.strip() sentence = line.strip()
# Tokenize the sentence, splitting on spaces # Tokenize the sentence, splitting on spaces
tokens = Tokenizer.tokenize( tokens = self.input_vocab.encode_line(
sentence, self.input_vocab, add_if_not_exist=False, sentence, add_if_not_exist=False,
) )
sentences.append(tokens) sentences.append(tokens)
...@@ -356,7 +355,6 @@ Finally we can write a short script to evaluate our model on new inputs. Create ...@@ -356,7 +355,6 @@ Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents:: a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer
# Parse command-line arguments for generation # Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification') parser = options.get_generation_parser(default_task='simple_classification')
...@@ -375,8 +373,8 @@ a new file named :file:`eval_classifier.py` with the following contents:: ...@@ -375,8 +373,8 @@ a new file named :file:`eval_classifier.py` with the following contents::
# Tokenize into characters # Tokenize into characters
chars = ' '.join(list(sentence.strip())) chars = ' '.join(list(sentence.strip()))
tokens = Tokenizer.tokenize( tokens = task.source_dictionary.encode_line(
chars, task.source_dictionary, add_if_not_exist=False, chars, add_if_not_exist=False,
) )
# Build mini-batch to feed to the model # Build mini-batch to feed to the model
......
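For anyone updating their own task or script against this change, the tutorial edits above boil down to one substitution. A minimal before/after sketch, assuming fairseq at this revision is installed (the "vocab" and "sentence" names are illustrative, not part of the diff):

    from fairseq.data import Dictionary

    # Illustrative setup: a small dictionary and one input line.
    vocab = Dictionary()
    for word in 'hello world'.split():
        vocab.add_symbol(word)
    sentence = 'hello world'

    # Old API (removed by this commit):
    #   from fairseq.tokenizer import Tokenizer
    #   tokens = Tokenizer.tokenize(sentence, vocab, add_if_not_exist=False)

    # New API: the dictionary encodes the line itself.
    tokens = vocab.encode_line(sentence, add_if_not_exist=False)
    print(tokens)  # torch.IntTensor of symbol ids, with EOS appended by default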
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os
from fairseq.tokenizer import tokenize_line
def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
class Binarizer:
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False,
offset=0, end=-1):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = dict.encode_line(
line=line,
line_tokenizer=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
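The new Binarizer mirrors what Tokenizer.binarize and Tokenizer.find_offsets did, but routes encoding through Dictionary.encode_line. A hedged usage sketch, assuming a vocabulary already built for the file (the corpus path and the list-collecting consumer are illustrative):

    from fairseq.binarizer import Binarizer
    from fairseq.data import Dictionary
    from fairseq.tokenizer import tokenize_line

    filename = 'corpus.txt'  # hypothetical input file
    vocab = Dictionary()
    Dictionary.add_file_to_dictionary(filename, vocab, tokenize_line, num_workers=1)

    # find_offsets returns num_chunks + 1 byte offsets; the last entry stays 0,
    # which binarize() treats as "read to end of file".
    offsets = Binarizer.find_offsets(filename, num_chunks=4)

    tensors = []  # trivial consumer: collect every encoded line in memory
    for start, end in zip(offsets[:-1], offsets[1:]):
        stats = Binarizer.binarize(filename, vocab, tensors.append,
                                   offset=start, end=end)
        print(stats['nseq'], stats['ntok'], stats['nunk'])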
...@@ -213,3 +213,11 @@ def batch_by_size( ...@@ -213,3 +213,11 @@ def batch_by_size(
if len(batch) > 0: if len(batch) > 0:
yield batch yield batch
def process_bpe_symbol(sentence: str, bpe_symbol: str):
if bpe_symbol == 'sentencepiece':
sentence = sentence.replace('\u2581', ' ').strip()
elif bpe_symbol is not None:
sentence = (sentence + ' ').replace(bpe_symbol, '').rstrip()
return sentence
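For reference, a short hedged example of the two branches of process_bpe_symbol (the input strings are made up):

    from fairseq.data import data_utils

    # subword-nmt style BPE: the continuation marker is stripped out.
    print(data_utils.process_bpe_symbol('the new@@ est model', '@@ '))
    # -> 'the newest model'

    # SentencePiece output: '\u2581' word boundaries become plain spaces.
    print(data_utils.process_bpe_symbol('\u2581the\u2581newest\u2581model', 'sentencepiece'))
    # -> 'the newest model'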
...@@ -6,10 +6,15 @@ ...@@ -6,10 +6,15 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import Counter from collections import Counter
from multiprocessing import Pool
import os import os
import torch import torch
from fairseq.tokenizer import tokenize_line
from fairseq.binarizer import safe_readline
from fairseq.data import data_utils
class Dictionary(object): class Dictionary(object):
"""A mapping from symbols to consecutive integers""" """A mapping from symbols to consecutive integers"""
...@@ -57,14 +62,8 @@ class Dictionary(object): ...@@ -57,14 +62,8 @@ class Dictionary(object):
else: else:
return self[i] return self[i]
if bpe_symbol == 'sentencepiece': sent = ''.join(token_string(i) for i in tensor if i != self.eos())
sent = ''.join(token_string(i) for i in tensor if i != self.eos()) return data_utils.process_bpe_symbol(sent, bpe_symbol)
sent = sent.replace('\u2581', ' ').strip()
else:
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
def unk_string(self, escape=False): def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>""" """Return unknown string, optionally escaped as: <<unk>>"""
...@@ -181,31 +180,104 @@ class Dictionary(object): ...@@ -181,31 +180,104 @@ class Dictionary(object):
"rebuild the dataset".format(f)) "rebuild the dataset".format(f))
d = cls() d = cls()
for line in f.readlines(): lines = f.readlines()
indices_start_line = d._load_meta(lines)
for line in lines[indices_start_line:]:
idx = line.rfind(' ') idx = line.rfind(' ')
if idx == -1: if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'") raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx] word = line[:idx]
count = int(line[idx+1:]) count = int(line[idx + 1:])
d.indices[word] = len(d.symbols) d.indices[word] = len(d.symbols)
d.symbols.append(word) d.symbols.append(word)
d.count.append(count) d.count.append(count)
return d return d
def save(self, f): def _save(self, f, kv_iterator):
"""Stores dictionary into a text file"""
if isinstance(f, str): if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True) os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd: with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd) return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]): for k, v in kv_iterator:
print('{} {}'.format(symbol, count), file=f) print('{} {}'.format(k, v), file=f)
def _get_meta(self):
return [], []
def _load_meta(self, lines):
return 0
def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))
def dummy_sentence(self, length): def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos() t[-1] = self.eos()
return t return t
def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
@staticmethod
def _add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Dictionary._add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
class TruncatedDictionary(object): class TruncatedDictionary(object):
......
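Taken together, Dictionary now owns both vocabulary building and line encoding. A short hedged sketch of the new entry points (the temporary corpus and the OOV-recording consumer are illustrative):

    import tempfile

    from fairseq.data import Dictionary
    from fairseq.tokenizer import tokenize_line

    # Illustrative corpus written to a temporary file.
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
        tmp.write('the cat sat\nthe dog ran\n')
        corpus = tmp.name

    d = Dictionary()
    # Count words (in a multiprocessing.Pool when num_workers > 1) and add them.
    Dictionary.add_file_to_dictionary(corpus, d, tokenize_line, num_workers=1)
    d.finalize(threshold=-1, nwords=-1, padding_factor=8)  # same kwargs the tasks pass

    # encode_line replaces Tokenizer.tokenize; by default it appends EOS and
    # can grow the vocabulary (add_if_not_exist=True).
    ids = d.encode_line('the cat sat', add_if_not_exist=False)
    print(ids)  # IntTensor with one id per word plus the EOS id

    # Unknown words map to the unk index; the optional consumer sees every (word, idx) pair.
    oov = []
    d.encode_line('the fox sat', add_if_not_exist=False, append_eos=False,
                  consumer=lambda word, idx: oov.append(word) if idx == d.unk() else None)
    print(oov)  # ['fox']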
...@@ -11,8 +11,6 @@ import struct ...@@ -11,8 +11,6 @@ import struct
import numpy as np import numpy as np
import torch import torch
from fairseq.tokenizer import Tokenizer
def read_longs(f, n): def read_longs(f, n):
a = np.empty(n, dtype=np.int64) a = np.empty(n, dtype=np.int64)
...@@ -171,8 +169,8 @@ class IndexedRawTextDataset(torch.utils.data.Dataset): ...@@ -171,8 +169,8 @@ class IndexedRawTextDataset(torch.utils.data.Dataset):
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
for line in f: for line in f:
self.lines.append(line.strip('\n')) self.lines.append(line.strip('\n'))
tokens = Tokenizer.tokenize( tokens = dictionary.encode_line(
line, dictionary, add_if_not_exist=False, line, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order, append_eos=self.append_eos, reverse_order=self.reverse_order,
).long() ).long()
self.tokens_list.append(tokens) self.tokens_list.append(tokens)
......
...@@ -9,7 +9,6 @@ import torch ...@@ -9,7 +9,6 @@ import torch
from fairseq import tokenizer from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
from fairseq.tokenizer import Tokenizer
class FairseqTask(object): class FairseqTask(object):
...@@ -52,7 +51,7 @@ class FairseqTask(object): ...@@ -52,7 +51,7 @@ class FairseqTask(object):
""" """
d = Dictionary() d = Dictionary()
for filename in filenames: for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers) Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d return d
......
...@@ -5,13 +5,8 @@ ...@@ -5,13 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import Counter
from multiprocessing import Pool
import os
import re import re
import torch
SPACE_NORMALIZER = re.compile(r"\s+") SPACE_NORMALIZER = re.compile(r"\s+")
...@@ -19,124 +14,3 @@ def tokenize_line(line): ...@@ -19,124 +14,3 @@ def tokenize_line(line):
line = SPACE_NORMALIZER.sub(" ", line) line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip() line = line.strip()
return line.split() return line.split()
def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
class Tokenizer:
@staticmethod
def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Tokenizer.add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
@staticmethod
def binarize(
filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
reverse_order=False, offset=0, end=-1,
):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = Tokenizer.tokenize(
line=line,
dict=dict,
tokenize=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = tokenize(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = dict.add_symbol(word)
else:
idx = dict.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = dict.eos_index
return ids
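After this removal, all that is left in fairseq/tokenizer.py is tokenize_line, which is also the default splitter for Dictionary.encode_line (line_tokenizer=...) and Binarizer.binarize (tokenize=...). A small hedged example with a drop-in character-level splitter (the splitter itself is made up for illustration):

    from fairseq.data import Dictionary
    from fairseq.tokenizer import tokenize_line

    print(tokenize_line('  hello   world \n'))  # ['hello', 'world']

    # Any callable mapping a line to a list of symbols can be substituted.
    def char_tokenizer(line):
        return list(line.strip().replace(' ', ''))

    d = Dictionary()
    ids = d.encode_line('hi there', line_tokenizer=char_tokenizer)
    print(ids)  # one id per character, plus EOS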
...@@ -304,7 +304,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic ...@@ -304,7 +304,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic
if align_dict is not None or remove_bpe is not None: if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE # Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method. # Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True) hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment return hypo_tokens, hypo_str, alignment
......
...@@ -165,8 +165,7 @@ def main(args): ...@@ -165,8 +165,7 @@ def main(args):
if has_target and i == 0: if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None: if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE # Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize( target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
target_str, tgt_dict, add_if_not_exist=True)
if hasattr(scorer, 'add_string'): if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str) scorer.add_string(target_str, hypo_str)
else: else:
......
...@@ -38,7 +38,7 @@ def buffered_read(input, buffer_size): ...@@ -38,7 +38,7 @@ def buffered_read(input, buffer_size):
def make_batches(lines, args, task, max_positions): def make_batches(lines, args, task, max_positions):
tokens = [ tokens = [
tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long() task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
for src_str in lines for src_str in lines
] ]
lengths = torch.LongTensor([t.numel() for t in tokens]) lengths = torch.LongTensor([t.numel() for t in tokens])
......
...@@ -11,15 +11,15 @@ Data pre-processing: build vocabularies and binarize training data. ...@@ -11,15 +11,15 @@ Data pre-processing: build vocabularies and binarize training data.
from collections import Counter from collections import Counter
from itertools import zip_longest from itertools import zip_longest
import os
import shutil
from fairseq import options, tasks from fairseq import options, tasks
from fairseq.data import indexed_dataset from fairseq.data import indexed_dataset
from fairseq.tokenizer import Tokenizer from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
from multiprocessing import Pool from multiprocessing import Pool
from fairseq.utils import import_user_module import os
import shutil
def main(args): def main(args):
...@@ -95,9 +95,8 @@ def main(args): ...@@ -95,9 +95,8 @@ def main(args):
if target and tgt_dict is not None: if target and tgt_dict is not None:
tgt_dict.save(dict_path(args.target_lang)) tgt_dict.save(dict_path(args.target_lang))
def make_binary_dataset(input_prefix, output_prefix, lang, num_workers): def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
dict = task.load_dictionary(dict_path(lang)) print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
n_seq_tok = [0, 0] n_seq_tok = [0, 0]
replaced = Counter() replaced = Counter()
...@@ -109,7 +108,7 @@ def main(args): ...@@ -109,7 +108,7 @@ def main(args):
input_file = "{}{}".format( input_file = "{}{}".format(
input_prefix, ("." + lang) if lang is not None else "" input_prefix, ("." + lang) if lang is not None else ""
) )
offsets = Tokenizer.find_offsets(input_file, num_workers) offsets = Binarizer.find_offsets(input_file, num_workers)
pool = None pool = None
if num_workers > 1: if num_workers > 1:
pool = Pool(processes=num_workers - 1) pool = Pool(processes=num_workers - 1)
...@@ -120,13 +119,13 @@ def main(args): ...@@ -120,13 +119,13 @@ def main(args):
( (
args, args,
input_file, input_file,
dict, vocab,
prefix, prefix,
lang, lang,
offsets[worker_id], offsets[worker_id],
offsets[worker_id + 1], offsets[worker_id + 1]
), ),
callback=merge_result, callback=merge_result
) )
pool.close() pool.close()
...@@ -134,8 +133,9 @@ def main(args): ...@@ -134,8 +133,9 @@ def main(args):
dataset_dest_file(args, output_prefix, lang, "bin") dataset_dest_file(args, output_prefix, lang, "bin")
) )
merge_result( merge_result(
Tokenizer.binarize( Binarizer.binarize(
input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1] input_file, vocab, lambda t: ds.add_item(t),
offset=0, end=offsets[1]
) )
) )
if num_workers > 1: if num_workers > 1:
...@@ -156,13 +156,13 @@ def main(args): ...@@ -156,13 +156,13 @@ def main(args):
n_seq_tok[0], n_seq_tok[0],
n_seq_tok[1], n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1], 100 * sum(replaced.values()) / n_seq_tok[1],
dict.unk_word, vocab.unk_word,
) )
) )
def make_dataset(input_prefix, output_prefix, lang, num_workers=1): def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.output_format == "binary": if args.output_format == "binary":
make_binary_dataset(input_prefix, output_prefix, lang, num_workers) make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
elif args.output_format == "raw": elif args.output_format == "raw":
# Copy original text file to destination folder # Copy original text file to destination folder
output_text_file = dest_path( output_text_file = dest_path(
...@@ -171,21 +171,21 @@ def main(args): ...@@ -171,21 +171,21 @@ def main(args):
) )
shutil.copyfile(file_name(input_prefix, lang), output_text_file) shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(lang): def make_all(lang, vocab):
if args.trainpref: if args.trainpref:
make_dataset(args.trainpref, "train", lang, num_workers=args.workers) make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
if args.validpref: if args.validpref:
for k, validpref in enumerate(args.validpref.split(",")): for k, validpref in enumerate(args.validpref.split(",")):
outprefix = "valid{}".format(k) if k > 0 else "valid" outprefix = "valid{}".format(k) if k > 0 else "valid"
make_dataset(validpref, outprefix, lang) make_dataset(vocab, validpref, outprefix, lang)
if args.testpref: if args.testpref:
for k, testpref in enumerate(args.testpref.split(",")): for k, testpref in enumerate(args.testpref.split(",")):
outprefix = "test{}".format(k) if k > 0 else "test" outprefix = "test{}".format(k) if k > 0 else "test"
make_dataset(testpref, outprefix, lang) make_dataset(vocab, testpref, outprefix, lang)
make_all(args.source_lang) make_all(args.source_lang, src_dict)
if target: if target:
make_all(args.target_lang) make_all(args.target_lang, tgt_dict)
print("| Wrote preprocessed data to {}".format(args.destdir)) print("| Wrote preprocessed data to {}".format(args.destdir))
...@@ -198,8 +198,8 @@ def main(args): ...@@ -198,8 +198,8 @@ def main(args):
with open(src_file_name, "r", encoding='utf-8') as src_file: with open(src_file_name, "r", encoding='utf-8') as src_file:
with open(tgt_file_name, "r", encoding='utf-8') as tgt_file: with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file): for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False) si = src_dict.encode_line(s, add_if_not_exist=False)
ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False) ti = tgt_dict.encode_line(t, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split("-")), a.split())) ai = list(map(lambda x: tuple(x.split("-")), a.split()))
for sai, tai in ai: for sai, tai in ai:
srcidx = si[int(sai)] srcidx = si[int(sai)]
...@@ -232,7 +232,7 @@ def main(args): ...@@ -232,7 +232,7 @@ def main(args):
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f) print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=True): def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
ds = indexed_dataset.IndexedDatasetBuilder( ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin") dataset_dest_file(args, output_prefix, lang, "bin")
) )
...@@ -240,14 +240,8 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos= ...@@ -240,14 +240,8 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=
def consumer(tensor): def consumer(tensor):
ds.add_item(tensor) ds.add_item(tensor)
res = Tokenizer.binarize( res = Binarizer.binarize(filename, vocab, consumer, append_eos=append_eos,
filename, offset=offset, end=end)
dict,
consumer,
offset=offset,
end=end,
append_eos=append_eos
)
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
return res return res
...@@ -266,7 +260,7 @@ def dataset_dest_file(args, output_prefix, lang, extension): ...@@ -266,7 +260,7 @@ def dataset_dest_file(args, output_prefix, lang, extension):
def get_offsets(input_file, num_workers): def get_offsets(input_file, num_workers):
return Tokenizer.find_offsets(input_file, num_workers) return Binarizer.find_offsets(input_file, num_workers)
def merge_files(files, outpath): def merge_files(files, outpath):
......
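Putting the preprocess.py changes together: the main process still binarizes the first chunk itself and farms the remaining byte ranges out to a Pool, then merges the per-chunk statistics. A simplified, hedged sketch of that flow (the function names, the corpus path, and the in-memory consumer are illustrative; the real script writes into an IndexedDatasetBuilder and merges the resulting files):

    from multiprocessing import Pool

    from fairseq.binarizer import Binarizer
    from fairseq.data import Dictionary
    from fairseq.tokenizer import tokenize_line


    def binarize_chunk(filename, vocab, offset, end):
        # Worker: encode one byte range of the file, return tensors plus stats.
        tensors = []
        stats = Binarizer.binarize(filename, vocab, tensors.append,
                                   offset=offset, end=end)
        return tensors, stats


    def binarize_file(filename, vocab, num_workers):
        offsets = Binarizer.find_offsets(filename, num_workers)
        jobs, pool = [], None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for i in range(1, num_workers):
                jobs.append(pool.apply_async(
                    binarize_chunk, (filename, vocab, offsets[i], offsets[i + 1])))
            pool.close()

        # The main process handles the first chunk itself, as preprocess.py does.
        tensors, stats = binarize_chunk(filename, vocab, offsets[0], offsets[1])
        nseq, ntok, nunk = stats['nseq'], stats['ntok'], stats['nunk']
        for job in jobs:
            chunk_tensors, chunk_stats = job.get()
            tensors.extend(chunk_tensors)
            nseq += chunk_stats['nseq']
            ntok += chunk_stats['ntok']
            nunk += chunk_stats['nunk']
        if pool is not None:
            pool.join()
        return tensors, nseq, ntok, nunk


    if __name__ == '__main__':
        corpus = 'corpus.txt'  # hypothetical path
        vocab = Dictionary()
        Dictionary.add_file_to_dictionary(corpus, vocab, tokenize_line, 1)
        tensors, nseq, ntok, nunk = binarize_file(corpus, vocab, num_workers=4)
        print('| {} sentences, {} tokens, {:.3%} replaced by <unk>'.format(
            nseq, ntok, nunk / max(ntok, 1)))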
...@@ -13,7 +13,7 @@ import argparse ...@@ -13,7 +13,7 @@ import argparse
import os import os
import sys import sys
from fairseq import bleu, tokenizer from fairseq import bleu
from fairseq.data import dictionary from fairseq.data import dictionary
...@@ -62,8 +62,8 @@ def main(): ...@@ -62,8 +62,8 @@ def main():
with open(args.ref) as fdref: with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict) sys_tok = dict.encode_line(sys_tok)
ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict) ref_tok = dict.encode_line(ref_tok)
scorer.add(ref_tok, sys_tok) scorer.add(ref_tok, sys_tok)
print(scorer.result_string(args.order)) print(scorer.result_string(args.order))
......
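score.py now encodes hypothesis and reference directly with the dictionary. A hedged sketch of the same flow in isolation, assuming the compiled bleu module is available (the example strings are made up):

    from fairseq import bleu
    from fairseq.data import Dictionary

    d = Dictionary()
    scorer = bleu.Scorer(d.pad(), d.eos(), d.unk())

    # encode_line with the default add_if_not_exist=True grows the dictionary as
    # it goes, which is all the scorer needs: consistent ids for both sides.
    sys_tok = d.encode_line('the cat sat on the mat')
    ref_tok = d.encode_line('the cat sat on a mat')
    scorer.add(ref_tok, sys_tok)
    print(scorer.result_string(4))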
...@@ -11,7 +11,6 @@ import unittest ...@@ -11,7 +11,6 @@ import unittest
import torch import torch
from fairseq.data import Dictionary from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer
class TestDictionary(unittest.TestCase): class TestDictionary(unittest.TestCase):
...@@ -39,12 +38,12 @@ class TestDictionary(unittest.TestCase): ...@@ -39,12 +38,12 @@ class TestDictionary(unittest.TestCase):
# build dictionary # build dictionary
d = Dictionary() d = Dictionary()
for line in txt: for line in txt:
Tokenizer.tokenize(line, d, add_if_not_exist=True) d.encode_line(line, add_if_not_exist=True)
def get_ids(dictionary): def get_ids(dictionary):
ids = [] ids = []
for line in txt: for line in txt:
ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False)) ids.append(dictionary.encode_line(line, add_if_not_exist=False))
return ids return ids
def assertMatch(ids, ref_ids): def assertMatch(ids, ref_ids):
......