Commit 537b7eb6 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #409 from panyx0718/models/lm_1b

Add pre-trained lm_1b model
parents 76739168 2a5a5596
package(default_visibility = [":internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//lm_1b/...",
],
)
py_library(
name = "data_utils",
srcs = ["data_utils.py"],
)
py_binary(
name = "lm_1b_eval",
srcs = [
"lm_1b_eval.py",
],
deps = [
":data_utils",
],
)
<font size=4><b>Language Model on One Billion Word Benchmark</b></font>
<b>Authors:</b>
Oriol Vinyals (vinyals@google.com, github: OriolVinyals),
Xin Pan (xpan@google.com, github: panyx0718)
<b>Paper Authors:</b>
Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, Yonghui Wu
<b>TL;DR</b>
This is a pretrained model on One Billion Word Benchmark.
If you use this model in your publication, please cite the original paper:
@article{jozefowicz2016exploring,
title={Exploring the Limits of Language Modeling},
author={Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike
and Shazeer, Noam and Wu, Yonghui},
journal={arXiv preprint arXiv:1602.02410},
year={2016}
}
<b>Introduction</b>
In this release, we open source a model trained on the One Billion Word
Benchmark (http://arxiv.org/abs/1312.3005), a large language corpus in English
which was released in 2013. This dataset contains about one billion words, and
has a vocabulary size of about 800K words. It contains mostly news data. Since
sentences in the training set are shuffled, models can ignore the context and
focus on sentence level language modeling.
In the original release and subsequent work, people have used the same test set
to train models on this dataset as a standard benchmark for language modeling.
Recently, we wrote an article (http://arxiv.org/abs/1602.02410) describing a
model hybrid between character CNN, a large and deep LSTM, and a specific
Softmax architecture which allowed us to train the best model on this dataset
thus far, almost halving the best perplexity previously obtained by others.
<b>Code Release</b>
The open-sourced components include:
* TensorFlow GraphDef proto buffer text file.
* TensorFlow pre-trained checkpoint shards.
* Code used to evaluate the pre-trained model.
* Vocabulary file.
* Test set from LM-1B evaluation.
The code supports 4 evaluation modes:
* Given provided dataset, calculate the model's perplexity.
* Given a prefix sentence, predict the next words.
* Dump the softmax embedding, character-level CNN word embeddings.
* Give a sentence, dump the embedding from the LSTM state.
<b>Results</b>
Model | Test Perplexity | Number of Params [billions]
------|-----------------|----------------------------
Sigmoid-RNN-2048 [Blackout] | 68.3 | 4.1
Interpolated KN 5-gram, 1.1B n-grams [chelba2013one] | 67.6 | 1.76
Sparse Non-Negative Matrix LM [shazeer2015sparse] | 52.9 | 33
RNN-1024 + MaxEnt 9-gram features [chelba2013one] | 51.3 | 20
LSTM-512-512 | 54.1 | 0.82
LSTM-1024-512 | 48.2 | 0.82
LSTM-2048-512 | 43.7 | 0.83
LSTM-8192-2048 (No Dropout) | 37.9 | 3.3
LSTM-8192-2048 (50\% Dropout) | 32.2 | 3.3
2-Layer LSTM-8192-1024 (BIG LSTM) | 30.6 | 1.8
(THIS RELEASE) BIG LSTM+CNN Inputs | <b>30.0</b> | <b>1.04</b>
<b>How To Run</b>
Pre-requesite:
* Install TensorFlow.
* Install Bazel.
* Download the data files:
* Model GraphDef file:
[link](download.tensorflow.org/models/LM_LSTM_CNN/graph-2016-09-10.pbtxt)
* Model Checkpoint sharded file:
[1](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-base)
[2](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-char-embedding)
[3](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-lstm)
[4](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax0)
[5](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax1)
[6](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax2)
[7](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax3)
[8](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax4)
[9](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax5)
[10](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax6)
[11](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax7)
[12](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax8)
* Vocabulary file:
[link](download.tensorflow.org/models/LM_LSTM_CNN/vocab-2016-09-10.txt)
* test dataset: link
[link](download.tensorflow.org/models/LM_LSTM_CNN/test/news.en.heldout-00000-of-00050)
* It is recommended to run on modern desktop PC instead of laptop.
```shell
# 1. Clone the code to your workspace.
# 2. Download the data to your workspace.
# 3. Create an empty WORKSPACE file in your workspace.
# 4. Create an empty output directory in your workspace.
# Example directory structure below:
ls -R
.:
data lm_1b output WORKSPACE
./data:
ckpt eval_2_8k_1k_1_1_char.pbtxt news.en.heldout-00000-of-00050 vocab.txt
./lm_1b:
BUILD data_utils.py data_utils.pyc lm_1b_eval.py README.md
./output:
# Build the codes.
bazel build -c opt lm_1b/...
# Run sample mode:
bazel-bin/lm_1b/lm_1b_eval --mode sample \
--prefix "I love that I" \
--pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
--vocab_file data/vocab.txt \
--ckpt data/ckpt
...(omitted some TensorFlow output)
I love
I love that
I love that I
I love that I find
I love that I find that
I love that I find that amazing
...(omitted)
# Run eval mode:
bazel-bin/lm_1b/lm_1b_eval --mode eval \
--pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
--vocab_file data/vocab.txt \
--input_data data/news.en.heldout-00000-of-00050 \
--ckpt data/ckpt
...(omitted some TensorFlow output)
Loaded step 14108582.
# perplexity is high initially because words without context are harder to
# predict.
Eval Step: 0, Average Perplexity: 2045.512297.
Eval Step: 1, Average Perplexity: 229.478699.
Eval Step: 2, Average Perplexity: 208.116787.
Eval Step: 3, Average Perplexity: 338.870601.
Eval Step: 4, Average Perplexity: 228.950107.
Eval Step: 5, Average Perplexity: 197.685857.
Eval Step: 6, Average Perplexity: 156.287063.
Eval Step: 7, Average Perplexity: 124.866189.
Eval Step: 8, Average Perplexity: 147.204975.
Eval Step: 9, Average Perplexity: 90.124864.
Eval Step: 10, Average Perplexity: 59.897914.
Eval Step: 11, Average Perplexity: 42.591137.
...(omitted)
Eval Step: 4529, Average Perplexity: 29.243668.
Eval Step: 4530, Average Perplexity: 29.302362.
Eval Step: 4531, Average Perplexity: 29.285674.
...(omitted. At convergence, it should be around 30.)
# Run dump_emb mode:
bazel-bin/lm_1b/lm_1b_eval --mode dump_emb \
--pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
--vocab_file data/vocab.txt \
--ckpt data/ckpt \
--save_dir output
...(omitted some TensorFlow output)
Finished softmax weights
Finished word embedding 0/793471
Finished word embedding 1/793471
Finished word embedding 2/793471
...(omitted)
ls output/
embeddings_softmax.npy ...
# Run dump_lstm_emb mode:
bazel-bin/lm_1b/lm_1b_eval --mode dump_lstm_emb \
--pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
--vocab_file data/vocab.txt \
--ckpt data/ckpt \
--sentence "I love who I am ." \
--save_dir output
ls output/
lstm_emb_step_0.npy lstm_emb_step_2.npy lstm_emb_step_4.npy
lstm_emb_step_6.npy lstm_emb_step_1.npy lstm_emb_step_3.npy
lstm_emb_step_5.npy
```
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for loading 1B word benchmark dataset."""
import random
import numpy as np
import tensorflow as tf
class Vocabulary(object):
"""Class that holds a vocabulary for the dataset."""
def __init__(self, filename):
"""Initialize vocabulary.
Args:
filename: Vocabulary file name.
"""
self._id_to_word = []
self._word_to_id = {}
self._unk = -1
self._bos = -1
self._eos = -1
with tf.gfile.Open(filename) as f:
idx = 0
for line in f:
word_name = line.strip()
if word_name == '<S>':
self._bos = idx
elif word_name == '</S>':
self._eos = idx
elif word_name == '<UNK>':
self._unk = idx
if word_name == '!!!MAXTERMID':
continue
self._id_to_word.append(word_name)
self._word_to_id[word_name] = idx
idx += 1
@property
def bos(self):
return self._bos
@property
def eos(self):
return self._eos
@property
def unk(self):
return self._unk
@property
def size(self):
return len(self._id_to_word)
def word_to_id(self, word):
if word in self._word_to_id:
return self._word_to_id[word]
return self.unk
def id_to_word(self, cur_id):
if cur_id < self.size:
return self._id_to_word[cur_id]
return 'ERROR'
def decode(self, cur_ids):
"""Convert a list of ids to a sentence, with space inserted."""
return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
def encode(self, sentence):
"""Convert a sentence to a list of ids, with special tokens added."""
word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
class CharsVocabulary(Vocabulary):
"""Vocabulary containing character-level information."""
def __init__(self, filename, max_word_length):
super(CharsVocabulary, self).__init__(filename)
self._max_word_length = max_word_length
chars_set = set()
for word in self._id_to_word:
chars_set |= set(word)
free_ids = []
for i in range(256):
if chr(i) in chars_set:
continue
free_ids.append(chr(i))
if len(free_ids) < 5:
raise ValueError('Not enough free char ids: %d' % len(free_ids))
self.bos_char = free_ids[0] # <begin sentence>
self.eos_char = free_ids[1] # <end sentence>
self.bow_char = free_ids[2] # <begin word>
self.eow_char = free_ids[3] # <end word>
self.pad_char = free_ids[4] # <padding>
chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
self.pad_char}
self._char_set = chars_set
num_words = len(self._id_to_word)
self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)
self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
self.eos_chars = self._convert_word_to_char_ids(self.eos_char)
for i, word in enumerate(self._id_to_word):
self._word_char_ids[i] = self._convert_word_to_char_ids(word)
@property
def word_char_ids(self):
return self._word_char_ids
@property
def max_word_length(self):
return self._max_word_length
def _convert_word_to_char_ids(self, word):
code = np.zeros([self.max_word_length], dtype=np.int32)
code[:] = ord(self.pad_char)
if len(word) > self.max_word_length - 2:
word = word[:self.max_word_length-2]
cur_word = self.bow_char + word + self.eow_char
for j in range(len(cur_word)):
code[j] = ord(cur_word[j])
return code
def word_to_char_ids(self, word):
if word in self._word_to_id:
return self._word_char_ids[self._word_to_id[word]]
else:
return self._convert_word_to_char_ids(word)
def encode_chars(self, sentence):
chars_ids = [self.word_to_char_ids(cur_word)
for cur_word in sentence.split()]
return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
def get_batch(generator, batch_size, num_steps, max_word_length, pad=False):
"""Read batches of input."""
cur_stream = [None] * batch_size
inputs = np.zeros([batch_size, num_steps], np.int32)
char_inputs = np.zeros([batch_size, num_steps, max_word_length], np.int32)
global_word_ids = np.zeros([batch_size, num_steps], np.int32)
targets = np.zeros([batch_size, num_steps], np.int32)
weights = np.ones([batch_size, num_steps], np.float32)
no_more_data = False
while True:
inputs[:] = 0
char_inputs[:] = 0
global_word_ids[:] = 0
targets[:] = 0
weights[:] = 0.0
for i in range(batch_size):
cur_pos = 0
while cur_pos < num_steps:
if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
try:
cur_stream[i] = list(generator.next())
except StopIteration:
# No more data, exhaust current streams and quit
no_more_data = True
break
how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
next_pos = cur_pos + how_many
inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][:how_many]
global_word_ids[i, cur_pos:next_pos] = cur_stream[i][2][:how_many]
targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many+1]
weights[i, cur_pos:next_pos] = 1.0
cur_pos = next_pos
cur_stream[i][0] = cur_stream[i][0][how_many:]
cur_stream[i][1] = cur_stream[i][1][how_many:]
cur_stream[i][2] = cur_stream[i][2][how_many:]
if pad:
break
if no_more_data and np.sum(weights) == 0:
# There is no more data and this is an empty batch. Done!
break
yield inputs, char_inputs, global_word_ids, targets, weights
class LM1BDataset(object):
"""Utility class for 1B word benchmark dataset.
The current implementation reads the data from the tokenized text files.
"""
def __init__(self, filepattern, vocab):
"""Initialize LM1BDataset reader.
Args:
filepattern: Dataset file pattern.
vocab: Vocabulary.
"""
self._vocab = vocab
self._all_shards = tf.gfile.Glob(filepattern)
tf.logging.info('Found %d shards at %s', len(self._all_shards), filepattern)
def _load_random_shard(self):
"""Randomly select a file and read it."""
return self._load_shard(random.choice(self._all_shards))
def _load_shard(self, shard_name):
"""Read one file and convert to ids.
Args:
shard_name: file path.
Returns:
list of (id, char_id, global_word_id) tuples.
"""
tf.logging.info('Loading data from: %s', shard_name)
with tf.gfile.Open(shard_name) as f:
sentences = f.readlines()
chars_ids = [self.vocab.encode_chars(sentence) for sentence in sentences]
ids = [self.vocab.encode(sentence) for sentence in sentences]
global_word_ids = []
current_idx = 0
for word_ids in ids:
current_size = len(word_ids) - 1 # without <BOS> symbol
cur_ids = np.arange(current_idx, current_idx + current_size)
global_word_ids.append(cur_ids)
current_idx += current_size
tf.logging.info('Loaded %d words.', current_idx)
tf.logging.info('Finished loading')
return zip(ids, chars_ids, global_word_ids)
def _get_sentence(self, forever=True):
while True:
ids = self._load_random_shard()
for current_ids in ids:
yield current_ids
if not forever:
break
def get_batch(self, batch_size, num_steps, pad=False, forever=True):
return get_batch(self._get_sentence(forever), batch_size, num_steps,
self.vocab.max_word_length, pad=pad)
@property
def vocab(self):
return self._vocab
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Eval pre-trained 1 billion word language model.
"""
import os
import sys
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
import data_utils
FLAGS = tf.flags.FLAGS
# General flags.
tf.flags.DEFINE_string('mode', 'eval',
'One of [sample, eval, dump_emb, dump_lstm_emb]. '
'"sample" mode samples future word predictions, using '
'FLAGS.prefix as prefix (prefix could be left empty). '
'"eval" mode calculates perplexity of the '
'FLAGS.input_data. '
'"dump_emb" mode dumps word and softmax embeddings to '
'FLAGS.save_dir. embeddings are dumped in the same '
'order as words in vocabulary. All words in vocabulary '
'are dumped.'
'dump_lstm_emb dumps lstm embeddings of FLAGS.sentence '
'to FLAGS.save_dir.')
tf.flags.DEFINE_string('pbtxt', '',
'GraphDef proto text file used to construct model '
'structure.')
tf.flags.DEFINE_string('ckpt', '',
'Checkpoint directory used to fill model values.')
tf.flags.DEFINE_string('vocab_file', '', 'Vocabulary file.')
tf.flags.DEFINE_string('save_dir', '',
'Used for "dump_emb" mode to save word embeddings.')
# sample mode flags.
tf.flags.DEFINE_string('prefix', '',
'Used for "sample" mode to predict next words.')
tf.flags.DEFINE_integer('max_sample_words', 100,
'Sampling stops either when </S> is met or this number '
'of steps has passed.')
tf.flags.DEFINE_integer('num_samples', 3,
'Number of samples to generate for the prefix.')
# dump_lstm_emb mode flags.
tf.flags.DEFINE_string('sentence', '',
'Used as input for "dump_lstm_emb" mode.')
# eval mode flags.
tf.flags.DEFINE_string('input_data', '',
'Input data files for eval model.')
tf.flags.DEFINE_integer('max_eval_steps', 1000000,
'Maximum mumber of steps to run "eval" mode.')
# For saving demo resources, use batch size 1 and step 1.
BATCH_SIZE = 1
NUM_TIMESTEPS = 1
MAX_WORD_LEN = 50
def _LoadModel(gd_file, ckpt_file):
"""Load the model from GraphDef and Checkpoint.
Args:
gd_file: GraphDef proto text file.
ckpt_file: TensorFlow Checkpoint file.
Returns:
TensorFlow session and tensors dict.
"""
with tf.Graph().as_default():
sys.stderr.write('Recovering graph.\n')
with tf.gfile.FastGFile(gd_file, 'r') as f:
s = f.read()
gd = tf.GraphDef()
text_format.Merge(s, gd)
tf.logging.info('Recovering Graph %s', gd_file)
t = {}
[t['states_init'], t['lstm/lstm_0/control_dependency'],
t['lstm/lstm_1/control_dependency'], t['softmax_out'], t['class_ids_out'],
t['class_weights_out'], t['log_perplexity_out'], t['inputs_in'],
t['targets_in'], t['target_weights_in'], t['char_inputs_in'],
t['all_embs'], t['softmax_weights'], t['global_step']
] = tf.import_graph_def(gd, {}, ['states_init',
'lstm/lstm_0/control_dependency:0',
'lstm/lstm_1/control_dependency:0',
'softmax_out:0',
'class_ids_out:0',
'class_weights_out:0',
'log_perplexity_out:0',
'inputs_in:0',
'targets_in:0',
'target_weights_in:0',
'char_inputs_in:0',
'all_embs_out:0',
'Reshape_3:0',
'global_step:0'], name='')
sys.stderr.write('Recovering checkpoint %s\n' % ckpt_file)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run('save/restore_all', {'save/Const:0': ckpt_file})
sess.run(t['states_init'])
return sess, t
def _EvalModel(dataset):
"""Evaluate model perplexity using provided dataset.
Args:
dataset: LM1BDataset object.
"""
sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
current_step = t['global_step'].eval(session=sess)
sys.stderr.write('Loaded step %d.\n' % current_step)
data_gen = dataset.get_batch(BATCH_SIZE, NUM_TIMESTEPS, forever=False)
sum_num = 0.0
sum_den = 0.0
perplexity = 0.0
for i, (inputs, char_inputs, _, targets, weights) in enumerate(data_gen):
input_dict = {t['inputs_in']: inputs,
t['targets_in']: targets,
t['target_weights_in']: weights}
if 'char_inputs_in' in t:
input_dict[t['char_inputs_in']] = char_inputs
log_perp = sess.run(t['log_perplexity_out'], feed_dict=input_dict)
if np.isnan(log_perp):
sys.stderr.error('log_perplexity is Nan.\n')
else:
sum_num += log_perp * weights.mean()
sum_den += weights.mean()
if sum_den > 0:
perplexity = np.exp(sum_num / sum_den)
sys.stderr.write('Eval Step: %d, Average Perplexity: %f.\n' %
(i, perplexity))
if i > FLAGS.max_eval_steps:
break
def _SampleSoftmax(softmax):
return min(np.sum(np.cumsum(softmax) < np.random.rand()), len(softmax) - 1)
def _SampleModel(prefix_words, vocab):
"""Predict next words using the given prefix words.
Args:
prefix_words: Prefix words.
vocab: Vocabulary. Contains max word chard id length and converts between
words and ids.
"""
targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
if prefix_words.find('<S>') != 0:
prefix_words = '<S> ' + prefix_words
prefix = [vocab.word_to_id(w) for w in prefix_words.split()]
prefix_char_ids = [vocab.word_to_char_ids(w) for w in prefix_words.split()]
for _ in xrange(FLAGS.num_samples):
inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
char_ids_inputs = np.zeros(
[BATCH_SIZE, NUM_TIMESTEPS, vocab.max_word_length], np.int32)
samples = prefix[:]
char_ids_samples = prefix_char_ids[:]
sent = ''
while True:
inputs[0, 0] = samples[0]
char_ids_inputs[0, 0, :] = char_ids_samples[0]
samples = samples[1:]
char_ids_samples = char_ids_samples[1:]
softmax = sess.run(t['softmax_out'],
feed_dict={t['char_inputs_in']: char_ids_inputs,
t['inputs_in']: inputs,
t['targets_in']: targets,
t['target_weights_in']: weights})
sample = _SampleSoftmax(softmax[0])
sample_char_ids = vocab.word_to_char_ids(vocab.id_to_word(sample))
if not samples:
samples = [sample]
char_ids_samples = [sample_char_ids]
sent += vocab.id_to_word(samples[0]) + ' '
sys.stderr.write('%s\n' % sent)
if (vocab.id_to_word(samples[0]) == '</S>' or
len(sent) > FLAGS.max_sample_words):
break
def _DumpEmb(vocab):
"""Dump the softmax weights and word embeddings to files.
Args:
vocab: Vocabulary. Contains vocabulary size and converts word to ids.
"""
assert FLAGS.save_dir, 'Must specify FLAGS.save_dir for dump_emb.'
inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
softmax_weights = sess.run(t['softmax_weights'])
fname = FLAGS.save_dir + '/embeddings_softmax.npy'
with tf.gfile.Open(fname, mode='w') as f:
np.save(f, softmax_weights)
sys.stderr.write('Finished softmax weights\n')
all_embs = np.zeros([vocab.size, 1024])
for i in range(vocab.size):
input_dict = {t['inputs_in']: inputs,
t['targets_in']: targets,
t['target_weights_in']: weights}
if 'char_inputs_in' in t:
input_dict[t['char_inputs_in']] = (
vocab.word_char_ids[i].reshape([-1, 1, MAX_WORD_LEN]))
embs = sess.run(t['all_embs'], input_dict)
all_embs[i, :] = embs
sys.stderr.write('Finished word embedding %d/%d\n' % (i, vocab.size))
fname = FLAGS.save_dir + '/embeddings_char_cnn.npy'
with tf.gfile.Open(fname, mode='w') as f:
np.save(f, all_embs)
sys.stderr.write('Embedding file saved\n')
def _DumpSentenceEmbedding(sentence, vocab):
"""Predict next words using the given prefix words.
Args:
sentence: Sentence words.
vocab: Vocabulary. Contains max word chard id length and converts between
words and ids.
"""
targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
if sentence.find('<S>') != 0:
sentence = '<S> ' + sentence
word_ids = [vocab.word_to_id(w) for w in sentence.split()]
char_ids = [vocab.word_to_char_ids(w) for w in sentence.split()]
inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
char_ids_inputs = np.zeros(
[BATCH_SIZE, NUM_TIMESTEPS, vocab.max_word_length], np.int32)
for i in xrange(len(word_ids)):
inputs[0, 0] = word_ids[i]
char_ids_inputs[0, 0, :] = char_ids[i]
# Add 'lstm/lstm_0/control_dependency' if you want to dump previous layer
# LSTM.
lstm_emb = sess.run(t['lstm/lstm_1/control_dependency'],
feed_dict={t['char_inputs_in']: char_ids_inputs,
t['inputs_in']: inputs,
t['targets_in']: targets,
t['target_weights_in']: weights})
fname = os.path.join(FLAGS.save_dir, 'lstm_emb_step_%d.npy' % i)
with tf.gfile.Open(fname, mode='w') as f:
np.save(f, lstm_emb)
sys.stderr.write('LSTM embedding step %d file saved\n' % i)
def main(unused_argv):
vocab = data_utils.CharsVocabulary(FLAGS.vocab_file, MAX_WORD_LEN)
if FLAGS.mode == 'eval':
dataset = data_utils.LM1BDataset(FLAGS.input_data, vocab)
_EvalModel(dataset)
elif FLAGS.mode == 'sample':
_SampleModel(FLAGS.prefix, vocab)
elif FLAGS.mode == 'dump_emb':
_DumpEmb(vocab)
elif FLAGS.mode == 'dump_lstm_emb':
_DumpSentenceEmbedding(FLAGS.sentence, vocab)
else:
raise Exception('Mode not supported.')
if __name__ == '__main__':
tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment