"router/vscode:/vscode.git/clone" did not exist on "33bc7212afd2bbf168c6735cd66c96a6cbf69c4e"
Commit fa7c575a authored by Myle Ott's avatar Myle Ott
Browse files

Fix preprocess.py

parent f607d9e8
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import math from collections import Counter
import os import os
import torch import torch
...@@ -81,26 +82,43 @@ class Dictionary(object): ...@@ -81,26 +82,43 @@ class Dictionary(object):
self.count.append(n) self.count.append(n)
return idx return idx
def update(self, new_dict): def finalize(self, threshold=1, nwords=-1, padding_factor=8):
"""Updates counts from new dictionary.""" """Sort symbols by frequency in descending order, ignoring special ones.
for word in new_dict.symbols:
idx2 = new_dict.indices[word] Args:
if word in self.indices: - threshold defines the minimum word count
idx = self.indices[word] - nwords defines the total number of words in the final dictionary,
self.count[idx] = self.count[idx] + new_dict.count[idx2] including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if padding_factor > 1:
if nwords == -1:
nwords = len(self)
i = 0
while nwords % padding_factor != 0:
if nwords >= len(self):
self.add_symbol('madeupword{:04d}'.format(i))
i += 1
nwords += 1
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_symbols.append(symbol)
new_count.append(count)
else: else:
idx = len(self.symbols) break
self.indices[word] = idx assert min(new_count[self.nspecial:]) >= threshold
self.symbols.append(word) assert len(new_symbols) <= nwords
self.count.append(new_dict.count[idx2]) assert len(new_symbols) % padding_factor == 0
def finalize(self): self.count = tuple(new_count)
"""Sort symbols by frequency in descending order, ignoring special ones.""" self.symbols = tuple(new_symbols)
self.count, self.symbols = zip(
*sorted(zip(self.count, self.symbols),
key=(lambda x: math.inf if self.indices[x[1]] < self.nspecial else x[0]),
reverse=True)
)
def pad(self): def pad(self):
"""Helper to get index of pad symbol""" """Helper to get index of pad symbol"""
...@@ -124,7 +142,6 @@ class Dictionary(object): ...@@ -124,7 +142,6 @@ class Dictionary(object):
... ...
``` ```
""" """
if isinstance(f, str): if isinstance(f, str):
try: try:
if not ignore_utf_errors: if not ignore_utf_errors:
...@@ -155,9 +172,5 @@ class Dictionary(object): ...@@ -155,9 +172,5 @@ class Dictionary(object):
os.makedirs(os.path.dirname(f), exist_ok=True) os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd: with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd, threshold, nwords) return self.save(fd, threshold, nwords)
cnt = 0 for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
for i, t in enumerate(zip(self.symbols, self.count)): print('{} {}'.format(symbol, count), file=f)
if i >= self.nspecial and t[1] >= threshold \
and (nwords <= 0 or cnt < nwords):
print('{} {}'.format(t[0], t[1]), file=f)
cnt += 1
...@@ -38,7 +38,8 @@ def get_parser(): ...@@ -38,7 +38,8 @@ def get_parser():
help='output format (optional)') help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary') parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language') parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, help='Pad dictionary size to be multiple of N') parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
help='Pad dictionary size to be multiple of N')
return parser return parser
...@@ -47,25 +48,10 @@ def main(args): ...@@ -47,25 +48,10 @@ def main(args):
os.makedirs(args.destdir, exist_ok=True) os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source target = not args.only_source
def pad_dictionary(d):
"""Pad dictionary to be a multiple of args.padding_factor.
Keeping the dictionary size a multiple of 8 improves performance on some
architectures, e.g., Nvidia Tensor Cores.
"""
if args.padding_factor > 1:
i = 0
while len(d) % args.padding_factor != 0:
d.add_symbol('madeupword{:04d}'.format(i))
i += 1
assert len(d) % args.padding_factor == 0
def build_dictionary(filenames): def build_dictionary(filenames):
d = dictionary.Dictionary() d = dictionary.Dictionary()
for filename in filenames: for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line) Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
pad_dictionary(d)
d.finalize()
return d return d
if args.joined_dictionary: if args.joined_dictionary:
...@@ -89,11 +75,20 @@ def main(args): ...@@ -89,11 +75,20 @@ def main(args):
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)]) tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)])
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)), src_dict.finalize(
threshold=args.thresholdsrc, nwords=args.nwordssrc) threshold=args.thresholdsrc,
nwords=args.nwordssrc,
padding_factor=args.padding_factor,
)
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
if target: if target:
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)), if not args.joined_dictionary:
threshold=args.thresholdtgt, nwords=args.nwordstgt) tgt_dict.finalize(
threshold=args.thresholdtgt,
nwords=args.nwordstgt,
padding_factor=args.padding_factor,
)
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
def make_binary_dataset(input_prefix, output_prefix, lang): def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang))) dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment