Commit cab76554 authored by Louis Martin, committed by Myle Ott

Refactor code in Tokenizer

parent eea50f38
@@ -6,7 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 #

+from collections import Counter
+
 import re
 import torch

 from fairseq import dictionary
@@ -32,46 +34,41 @@ class Tokenizer:

     @staticmethod
     def add_file_to_dictionary(filename, dict, tokenize):
         with open(filename, 'r') as f:
-            for line in f.readlines():
+            for line in f:
                 for word in tokenize(line):
                     dict.add_symbol(word)
                 dict.add_symbol(dict.eos_word)

     @staticmethod
     def binarize(filename, dict, consumer, tokenize=tokenize_line):
-        nseq, ntok, nunk = 0, 0, 0
-        replaced = {}
+        nseq, ntok = 0, 0
+        replaced = Counter()
+
+        def replaced_consumer(word, idx):
+            if idx == dict.unk_index and word != dict.unk_word:
+                replaced.update([word])
+
         with open(filename, 'r') as f:
-            for line in f.readlines():
-                words = tokenize(line)
-                nwords = len(words)
-                ids = torch.IntTensor(nwords + 1)
-                nseq = nseq + 1
-                for i in range(0, len(words)):
-                    word = words[i]
-                    idx = dict.index(word)
-                    if idx == dict.unk_index and word != dict.unk_word:
-                        nunk = nunk + 1
-                        if word in replaced:
-                            replaced[word] = replaced[word] + 1
-                        else:
-                            replaced[word] = 1
-                    ids[i] = idx
-                ids[nwords] = dict.eos_index
+            for line in f:
+                ids = Tokenizer.tokenize(line, dict, tokenize, add_if_not_exist=False, consumer=replaced_consumer)
+                nseq += 1
                 consumer(ids)
-                ntok = ntok + len(ids)
-        return {'nseq': nseq, 'nunk': nunk, 'ntok': ntok, 'replaced': len(replaced)}
+                ntok += len(ids)
+        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}

     @staticmethod
-    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True):
+    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, consumer=None):
         words = tokenize(line)
         nwords = len(words)
         ids = torch.IntTensor(nwords + 1)
-        for i in range(0, len(words)):
+        for i, word in enumerate(words):
             if add_if_not_exist:
-                ids[i] = dict.add_symbol(words[i])
+                idx = dict.add_symbol(word)
             else:
-                ids[i] = dict.index(words[i])
+                idx = dict.index(word)
+            if consumer is not None:
+                consumer(word, idx)
+            ids[i] = idx
         ids[nwords] = dict.eos_index
         return ids
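For context, the sketch below shows the callback pattern this commit introduces: binarize no longer duplicates the ID-assignment loop but delegates to tokenize, observing unknown-word replacements through the new consumer argument. This is a minimal, self-contained illustration, not fairseq code: the Dict class and the free-standing tokenize function are hypothetical stand-ins for fairseq.dictionary and Tokenizer.tokenize, reduced to just the members used here.

from collections import Counter

import torch


class Dict:
    # Hypothetical stand-in for fairseq.dictionary: only the members
    # used by the tokenize/binarize flow above (add_symbol, index,
    # unk_index, unk_word, eos_index).
    def __init__(self):
        self.symbols = ['<unk>', '</s>']
        self.indices = {'<unk>': 0, '</s>': 1}
        self.unk_word, self.unk_index = '<unk>', 0
        self.eos_index = 1

    def add_symbol(self, word):
        if word not in self.indices:
            self.indices[word] = len(self.symbols)
            self.symbols.append(word)
        return self.indices[word]

    def index(self, word):
        return self.indices.get(word, self.unk_index)


def tokenize(line, dict, add_if_not_exist=True, consumer=None):
    # Same shape as the refactored Tokenizer.tokenize: one pass over the
    # words, optionally reporting each (word, idx) pair to a callback.
    words = line.split()
    ids = torch.IntTensor(len(words) + 1)
    for i, word in enumerate(words):
        idx = dict.add_symbol(word) if add_if_not_exist else dict.index(word)
        if consumer is not None:
            consumer(word, idx)
        ids[i] = idx
    ids[len(words)] = dict.eos_index
    return ids


d = Dict()
tokenize('the cat sat', d)  # build the vocabulary: the=2, cat=3, sat=4

# Mirrors binarize()'s bookkeeping: count words that mapped to <unk>.
replaced = Counter()

def replaced_consumer(word, idx):
    if idx == d.unk_index and word != d.unk_word:
        replaced.update([word])

ids = tokenize('the dog sat', d, add_if_not_exist=False, consumer=replaced_consumer)
print(ids.tolist())    # [2, 0, 4, 1] -- 'dog' mapped to <unk>, then eos
print(dict(replaced))  # {'dog': 1}

The upshot of the refactor is that the unknown-word statistics previously computed inline in binarize (nunk and the replaced dict) are now recovered from the Counter as sum(replaced.values()) and len(replaced), so the ID-assignment loop exists in exactly one place.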