Commit fa7c575a authored by Myle Ott
Browse files

Fix preprocess.py

parent f607d9e8
......@@ -5,8 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
from collections import Counter
import os
import torch
......@@ -81,26 +82,43 @@ class Dictionary(object):
self.count.append(n)
return idx
# NOTE(review): this span is a rendered diff hunk. Indentation and +/- markers
# are absent and two `def finalize` variants appear, so lines from the "before"
# and "after" versions are interleaved. The comments below describe each run
# of lines; which version a run belongs to should be confirmed against the
# real file.
def update(self, new_dict):
"""Updates counts from new dictionary."""
# Walk every symbol of new_dict; idx2 is that symbol's position in new_dict.
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
# Symbol already known here: add new_dict's count onto the existing count.
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
# --- finalize(threshold, nwords, padding_factor) variant starts here ---
def finalize(self, threshold=1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
# Padding step: nwords is incremented until it is a multiple of
# padding_factor; 'madeupwordNNNN' placeholder symbols are added along the
# way when nwords has reached the current dictionary size.
# NOTE(review): the nwords == -1 default appears to be resolved only inside
# the padding_factor > 1 branch (nesting is not visible here). If so, with
# padding_factor <= 1 and nwords == -1 the most_common() limit below is
# negative -- confirm the intended behavior.
if padding_factor > 1:
if nwords == -1:
nwords = len(self)
i = 0
while nwords % padding_factor != 0:
if nwords >= len(self):
self.add_symbol('madeupword{:04d}'.format(i))
i += 1
nwords += 1
# Seed the result with the first nspecial entries.
# NOTE(review): the slices have the same sequence type as self.symbols /
# self.count; the append() calls below need lists, while the end of this
# method stores tuples -- so a second call on the same object presumably
# fails. Confirm callers finalize each dictionary once.
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
# Rank the remaining symbols by count; most_common() yields them in
# descending count order, capped at nwords minus the nspecial leading entries.
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
# Keep symbols that meet the minimum count.
if count >= threshold:
new_symbols.append(symbol)
new_count.append(count)
# NOTE(review): this `else:` and the four lines after it read as the
# else-branch of the `if word in self.indices` check in update(): a word not
# yet present is appended with new_dict's count. Confirm placement.
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
# --- zero-argument finalize variant (presumably the pre-change version) ---
def finalize(self):
"""Sort symbols by frequency in descending order, ignoring special ones."""
# Sort (count, symbol) pairs in descending key order; entries whose index is
# below nspecial get key math.inf so they stay at the front.
# NOTE(review): zip(*...) leaves self.count and self.symbols as tuples, and
# self.indices is not updated in these lines -- confirm it is rebuilt
# elsewhere if the order changes.
self.count, self.symbols = zip(
*sorted(zip(self.count, self.symbols),
key=(lambda x: math.inf if self.indices[x[1]] < self.nspecial else x[0]),
reverse=True)
)
# NOTE(review): this `break` presumably belongs to the three-argument
# finalize, ending the most_common() loop at the first symbol whose count is
# below threshold -- confirm.
break
# Post-conditions of the three-argument finalize.
# NOTE(review): min() raises ValueError on an empty sequence, so the first
# assert errors out that way when no symbol beyond the nspecial entries was
# kept.
# NOTE(review): symbols dropped by the threshold are not replaced by padding
# in the lines shown, so the multiple-of-padding_factor assert looks like it
# can fail when the threshold filters entries -- confirm.
assert min(new_count[self.nspecial:]) >= threshold
assert len(new_symbols) <= nwords
assert len(new_symbols) % padding_factor == 0
# Store the filtered, ordered result as tuples.
self.count = tuple(new_count)
self.symbols = tuple(new_symbols)
def pad(self):
"""Helper to get index of pad symbol"""
......@@ -124,7 +142,6 @@ class Dictionary(object):
...
```
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
......@@ -155,9 +172,5 @@ class Dictionary(object):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd, threshold, nwords)
cnt = 0
for i, t in enumerate(zip(self.symbols, self.count)):
if i >= self.nspecial and t[1] >= threshold \
and (nwords <= 0 or cnt < nwords):
print('{} {}'.format(t[0], t[1]), file=f)
cnt += 1
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
......@@ -38,7 +38,8 @@ def get_parser():
help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, help='Pad dictionary size to be multiple of N')
parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
help='Pad dictionary size to be multiple of N')
return parser
......@@ -47,25 +48,10 @@ def main(args):
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
# NOTE(review): shown inside a diff hunk with indentation stripped; this helper
# is presumably nested in main() (it reads `args`) and looks like the
# pre-change code, since a finalize() variant in this diff takes
# padding_factor itself -- confirm.
def pad_dictionary(d):
"""Pad dictionary to be a multiple of args.padding_factor.
Keeping the dictionary size a multiple of 8 improves performance on some
architectures, e.g., Nvidia Tensor Cores.
"""
# A factor of 1 or less means no padding is requested.
# NOTE(review): the comparison and modulo need args.padding_factor to be an
# int; one --padding-factor definition in this diff has no type=int, so a
# value given on the command line would arrive as a string -- confirm.
if args.padding_factor > 1:
i = 0
# Add numbered placeholder symbols until the size divides evenly.
while len(d) % args.padding_factor != 0:
d.add_symbol('madeupword{:04d}'.format(i))
i += 1
# Sanity check of the post-condition (asserts are skipped under python -O).
assert len(d) % args.padding_factor == 0
# Build a Dictionary from the given files and return it.
# filenames: iterable of paths, each passed to Tokenizer.add_file_to_dictionary.
def build_dictionary(filenames):
d = dictionary.Dictionary()
# Feed each file into the dictionary; presumably this counts the tokens that
# tokenize_line produces for every line -- verify in Tokenizer.
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
# NOTE(review): the next two calls look like pre-change lines of the diff;
# elsewhere in this diff main() calls finalize(threshold=..., nwords=...,
# padding_factor=...) on the built dictionaries before saving -- confirm
# which lines survive in the committed file.
pad_dictionary(d)
d.finalize()
return d
if args.joined_dictionary:
......@@ -89,11 +75,20 @@ def main(args):
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)])
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
threshold=args.thresholdsrc, nwords=args.nwordssrc)
src_dict.finalize(
threshold=args.thresholdsrc,
nwords=args.nwordssrc,
padding_factor=args.padding_factor,
)
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
if target:
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
if not args.joined_dictionary:
tgt_dict.finalize(
threshold=args.thresholdtgt,
nwords=args.nwordstgt,
padding_factor=args.padding_factor,
)
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment