from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import Counter, OrderedDict

import numpy as np

import tensorflow as tf

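# TF 1.x gfile shadows the built-in open/exists so that paths on remote
# filesystems (e.g. GCS) work transparently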
from tensorflow.gfile import Open as open
from tensorflow.gfile import Exists as exists

class Vocab(object):
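  """Maps symbols to integer ids and back.

  Count symbols with count_file()/count_sents(), build the index with
  build_vocab() (or load it from a vocab file), then encode text into
  np.int64 arrays.
  """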
  def __init__(self, special=None, min_freq=0, max_size=None, lower_case=True,
               delimiter=None, vocab_file=None):
    self.counter = Counter()
    # avoid the mutable-default-argument pitfall for `special`
    self.special = special if special is not None else []
    self.min_freq = min_freq
    self.max_size = max_size
    self.lower_case = lower_case
    self.delimiter = delimiter
    self.vocab_file = vocab_file

  def tokenize(self, line, add_eos=False, add_double_eos=False):
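    """Split a line into symbols, optionally appending '<eos>' or wrapping
    the sentence in '<S>' markers (add_double_eos, used for lm1b)."""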
    line = line.strip()
    # convert to lower case
    if self.lower_case:
      line = line.lower()

    # an empty delimiter '' means character-level tokenization; compare with
    # == explicitly, since '' is falsy and would otherwise be confused with
    # the default None (whitespace splitting)
    if self.delimiter == '':
      symbols = list(line)  # list of characters, so eos symbols can be appended
    else:
      symbols = line.split(self.delimiter)

    if add_double_eos: # lm1b-style: mark both sentence boundaries with '<S>'
      return ['<S>'] + symbols + ['<S>']
    elif add_eos:
      return symbols + ['<eos>']
    else:
      return symbols

  def count_file(self, path, verbose=False, add_eos=False):
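    """Tokenize each line of the file at `path`, update the symbol counter,
    and return the list of tokenized sentences."""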
    if verbose: print('counting file {} ...'.format(path))
    assert exists(path)

    sents = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('  line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos)
        self.counter.update(symbols)
        sents.append(symbols)

    return sents

  def count_sents(self, sents, verbose=False):
    """
      sents : a list of sentences, each a list of tokenized symbols
    """
    if verbose: print('counting {} sents ...'.format(len(sents)))
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('  line {}'.format(idx))
      self.counter.update(symbols)

  def _build_from_file(self, vocab_file):
    self.idx2sym = []
    self.sym2idx = OrderedDict()

    with open(vocab_file, 'r') as f:
      for line in f:
        fields = line.strip().split()
        if not fields:  # skip blank lines rather than crash on them
          continue
        self.add_symbol(fields[0])
    self.unk_idx = self.sym2idx['<UNK>']

  def build_vocab(self):
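    """Build sym2idx/idx2sym, from `vocab_file` if given; otherwise add the
    special symbols first, then counted symbols by descending frequency,
    keeping at most `max_size` of them and dropping counts below `min_freq`."""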
    if self.vocab_file:
      print('building vocab from {}'.format(self.vocab_file))
      self._build_from_file(self.vocab_file)
      print('final vocab size {}'.format(len(self)))
    else:
      print('building vocab with min_freq={}, max_size={}'.format(
        self.min_freq, self.max_size))
      self.idx2sym = []
      self.sym2idx = OrderedDict()

      for sym in self.special:
        self.add_special(sym)

      for sym, cnt in self.counter.most_common(self.max_size):
        if cnt < self.min_freq: break
        self.add_symbol(sym)

      print('final vocab size {} from {} unique tokens'.format(
        len(self), len(self.counter)))

  def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
          add_double_eos=False):
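    """Tokenize and encode the file at `path` line by line; returns a list
    of np.int64 arrays, or one concatenated array when ordered=True."""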
    if verbose: print('encoding file {} ...'.format(path))
    assert exists(path)
    encoded = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('  line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos,
          add_double_eos=add_double_eos)
        encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def encode_sents(self, sents, ordered=False, verbose=False):
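    """Like encode_file, but for already-tokenized sentences."""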
    if verbose: print('encoding {} sents ...'.format(len(sents)))
    encoded = []
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('  line {}'.format(idx))
      encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def add_special(self, sym):
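    """Add `sym` and expose its id as an attribute, e.g. '<eos>' -> self.eos_idx."""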
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1
      setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

  def add_symbol(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1

  def get_sym(self, idx):
    assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
    return self.idx2sym[idx]

  def get_idx(self, sym):
    if sym in self.sym2idx:
      return self.sym2idx[sym]
    # out-of-vocabulary symbol: fall back to '<UNK>'; unk_idx is only set
    # when the vocab was built from a file containing '<UNK>'
    assert hasattr(self, 'unk_idx'), 'OOV symbol {} but no <UNK> in vocab'.format(sym)
    return self.unk_idx

  def get_symbols(self, indices):
    return [self.get_sym(idx) for idx in indices]

  def get_indices(self, symbols):
    return [self.get_idx(sym) for sym in symbols]

  def convert_to_nparray(self, symbols):
    nparray = np.array(self.get_indices(symbols), dtype=np.int64)
    return nparray

  def convert_to_sent(self, indices, exclude=None):
    if exclude is None:
      return ' '.join([self.get_sym(idx) for idx in indices])
    else:
      return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

  def __len__(self):
    return len(self.idx2sym)
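

# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of the in-memory path: count toy sentences, build the
# index, and round-trip ids back to text. The sentences are made up; a real
# pipeline would call count_file()/encode_file() on dataset paths instead.
if __name__ == '__main__':
  vocab = Vocab(special=['<eos>'], lower_case=True)
  sents = [vocab.tokenize('The quick brown fox', add_eos=True),
           vocab.tokenize('the lazy dog', add_eos=True)]
  vocab.count_sents(sents)
  vocab.build_vocab()  # '<eos>' first, then tokens by descending frequency
  ids = vocab.encode_sents(sents, ordered=True)  # one flat np.int64 array
  print(ids)
  print(vocab.convert_to_sent(ids))  # 'the quick brown fox <eos> the lazy dog <eos>'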