Commit c6fe9fc5 authored by Myle Ott
Browse files

Fix for Dictionary.finalize

parent 7bcb487a
...@@ -95,7 +95,7 @@ class Dictionary(object): ...@@ -95,7 +95,7 @@ class Dictionary(object):
self.symbols.append(word) self.symbols.append(word)
self.count.append(new_dict.count[idx2]) self.count.append(new_dict.count[idx2])
def finalize(self, threshold=1, nwords=-1, padding_factor=8): def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones. """Sort symbols by frequency in descending order, ignoring special ones.
Args: Args:
...@@ -109,12 +109,14 @@ class Dictionary(object): ...@@ -109,12 +109,14 @@ class Dictionary(object):
if nwords == -1: if nwords == -1:
nwords = len(self) nwords = len(self)
new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[:self.nspecial] new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial] new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:]))) c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial): for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold: if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol) new_symbols.append(symbol)
new_count.append(count) new_count.append(count)
else: else:
...@@ -124,16 +126,20 @@ class Dictionary(object): ...@@ -124,16 +126,20 @@ class Dictionary(object):
if padding_factor > 1: if padding_factor > 1:
i = 0 i = 0
while threshold_nwords % padding_factor != 0: while threshold_nwords % padding_factor != 0:
new_symbols.append('madeupword{:04d}'.format(i)) symbol = 'madeupword{:04d}'.format(i)
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(0) new_count.append(0)
i += 1 i += 1
threshold_nwords += 1 threshold_nwords += 1
assert min(new_count[self.nspecial:]) >= threshold assert min(new_count[self.nspecial:]) >= threshold
assert len(new_symbols) % padding_factor == 0 assert len(new_symbols) % padding_factor == 0
assert len(new_symbols) == len(new_indices)
self.count = tuple(new_count) self.count = list(new_count)
self.symbols = tuple(new_symbols) self.symbols = list(new_symbols)
self.indices = new_indices
def pad(self): def pad(self):
"""Helper to get index of pad symbol""" """Helper to get index of pad symbol"""
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import tempfile
import unittest
import torch
from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
class TestDictionary(unittest.TestCase):
    """Tests for ``Dictionary`` id assignment, ``finalize()`` and save/load."""

    def test_finalize(self):
        """Check token ids before finalize, after finalize, and after reload.

        The corpus below uses four symbols with distinct frequencies
        (D: 4, C: 3, B: 2, A: 1), so the order of ids is unambiguous both
        before and after ``finalize()`` re-sorts the dictionary.
        """
        txt = [
            'A B C D',
            'B C D',
            'C D',
            'D',
        ]
        # Before finalize: ids follow first-seen order (A=4, B=5, C=6, D=7).
        # Every row ends with id 2 (presumably the end-of-sentence symbol
        # appended by the tokenizer -- confirm against Tokenizer.tokenize).
        ref_ids1 = list(map(torch.IntTensor, [
            [4, 5, 6, 7, 2],
            [5, 6, 7, 2],
            [6, 7, 2],
            [7, 2],
        ]))
        # After finalize: ids follow descending frequency (D=4, C=5, B=6, A=7).
        ref_ids2 = list(map(torch.IntTensor, [
            [7, 6, 5, 4, 2],
            [6, 5, 4, 2],
            [5, 4, 2],
            [4, 2],
        ]))

        # build dictionary
        d = Dictionary()
        for line in txt:
            Tokenizer.tokenize(line, d, add_if_not_exist=True)

        def get_ids(dictionary):
            """Tokenize every line of ``txt`` without adding new symbols."""
            return [
                Tokenizer.tokenize(line, dictionary, add_if_not_exist=False)
                for line in txt
            ]

        def assertMatch(ids, ref_ids):
            """Assert both sequences of tensors agree pairwise in size and value."""
            # zip() stops at the shorter input, so without this check a
            # missing (or extra) line would silently go uncompared.
            self.assertEqual(len(ids), len(ref_ids))
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        # NOTE(review): the file is reopened by name while still open here;
        # that is fine on POSIX but may fail on Windows -- confirm if needed.
        with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)
# Run the test suite only when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment