"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "2290bdb2fe350985f7e2ef991140087c25a3514c"
Commit a533325c authored by Xin Pan's avatar Xin Pan Committed by GitHub
Browse files

Merge pull request #725 from pranay360/master

Fix for issue #621
parents 821d3da7 f1e8ff7c
...@@ -56,6 +56,11 @@ class Vocab(object): ...@@ -56,6 +56,11 @@ class Vocab(object):
if self._count > max_size: if self._count > max_size:
raise ValueError('Too many words: >%d.' % max_size) raise ValueError('Too many words: >%d.' % max_size)
def CheckVocab(self, word):
if word not in self._word_to_id:
return None
return self._word_to_id[word]
def WordToId(self, word): def WordToId(self, word):
if word not in self._word_to_id: if word not in self._word_to_id:
return self._word_to_id[UNKNOWN_TOKEN] return self._word_to_id[UNKNOWN_TOKEN]
......
...@@ -160,10 +160,10 @@ def _Eval(model, data_batcher, vocab=None): ...@@ -160,10 +160,10 @@ def _Eval(model, data_batcher, vocab=None):
def main(unused_argv): def main(unused_argv):
vocab = data.Vocab(FLAGS.vocab_path, 1000000) vocab = data.Vocab(FLAGS.vocab_path, 1000000)
# Check for presence of required special tokens. # Check for presence of required special tokens.
assert vocab.WordToId(data.PAD_TOKEN) > 0 assert vocab.CheckVocab(data.PAD_TOKEN) > 0
assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0 assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
assert vocab.WordToId(data.SENTENCE_START) > 0 assert vocab.CheckVocab(data.SENTENCE_START) > 0
assert vocab.WordToId(data.SENTENCE_END) > 0 assert vocab.CheckVocab(data.SENTENCE_END) > 0
batch_size = 4 batch_size = 4
if FLAGS.mode == 'decode': if FLAGS.mode == 'decode':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment