"src/vscode:/vscode.git/clone" did not exist on "07860f991639f35f4b5a152676bd4d590c3e589e"
Commit cab76554 authored by Louis Martin, committed by Myle Ott
Browse files

Refactor code in Tokenizer

parent eea50f38
......@@ -6,7 +6,9 @@
# can be found in the PATENTS file in the same directory.
#
from collections import Counter
import re
import torch
from fairseq import dictionary
......@@ -32,46 +34,41 @@ class Tokenizer:
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, 'r') as f:
for line in f.readlines():
for line in f:
for word in tokenize(line):
dict.add_symbol(word)
dict.add_symbol(dict.eos_word)
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line):
nseq, ntok, nunk = 0, 0, 0
replaced = {}
with open(filename, 'r') as f:
for line in f.readlines():
words = tokenize(line)
nwords = len(words)
ids = torch.IntTensor(nwords + 1)
nseq = nseq + 1
for i in range(0, len(words)):
word = words[i]
idx = dict.index(word)
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
nunk = nunk + 1
if word in replaced:
replaced[word] = replaced[word] + 1
else:
replaced[word] = 1
ids[i] = idx
replaced.update([word])
with open(filename, 'r') as f:
for line in f:
ids = Tokenizer.tokenize(line, dict, tokenize, add_if_not_exist=False, consumer=replaced_consumer)
nseq += 1
ids[nwords] = dict.eos_index
consumer(ids)
ntok = ntok + len(ids)
return {'nseq': nseq, 'nunk': nunk, 'ntok': ntok, 'replaced': len(replaced)}
ntok += len(ids)
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True):
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, consumer=None):
words = tokenize(line)
nwords = len(words)
ids = torch.IntTensor(nwords + 1)
for i in range(0, len(words)):
for i, word in enumerate(words):
if add_if_not_exist:
ids[i] = dict.add_symbol(words[i])
idx = dict.add_symbol(word)
else:
ids[i] = dict.index(words[i])
idx = dict.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
ids[nwords] = dict.eos_index
return ids
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment