"vscode:/vscode.git/clone" did not exist on "b48083f33f6784374106f7fcfc29f682af74a6ed"
Commit 68609ca7 authored by Christopher Shallue

TF implementation of Skip Thoughts.

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tracks training progress via per-word perplexity.
This script should be run concurrently with training so that summaries show up
in TensorBoard.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os.path
import time

import numpy as np
import tensorflow as tf

from skip_thoughts import configuration
from skip_thoughts import skip_thoughts_model

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string("input_file_pattern", None,
                       "File pattern of sharded TFRecord input files.")
tf.flags.DEFINE_string("checkpoint_dir", None,
                       "Directory containing model checkpoints.")
tf.flags.DEFINE_string("eval_dir", None, "Directory to write event logs to.")
tf.flags.DEFINE_integer("eval_interval_secs", 600,
                        "Interval between evaluation runs.")
tf.flags.DEFINE_integer("num_eval_examples", 50000,
                        "Number of examples for evaluation.")
tf.flags.DEFINE_integer("min_global_step", 100,
                        "Minimum global step to run evaluation.")

tf.logging.set_verbosity(tf.logging.INFO)
def evaluate_model(sess, losses, weights, num_batches, global_step,
                   summary_writer, summary_op):
  """Computes perplexity-per-word over the evaluation dataset.

  Summaries and perplexity-per-word are written out to the eval directory.

  Args:
    sess: Session object.
    losses: A Tensor of any shape; the target cross entropy losses for the
      current batch.
    weights: A Tensor of weights corresponding to losses.
    num_batches: Integer; the number of evaluation batches.
    global_step: Integer; global step of the model checkpoint.
    summary_writer: Instance of FileWriter.
    summary_op: Op for generating model summaries.
  """
  # Log model summaries on a single batch.
  summary_str = sess.run(summary_op)
  summary_writer.add_summary(summary_str, global_step)

  start_time = time.time()
  sum_losses = 0.0
  sum_weights = 0.0
  for i in range(num_batches):  # range, not xrange, so Python 3 also works.
    batch_losses, batch_weights = sess.run([losses, weights])
    sum_losses += np.sum(batch_losses * batch_weights)
    sum_weights += np.sum(batch_weights)
    if not i % 100:
      tf.logging.info("Computed losses for %d of %d batches.", i + 1,
                      num_batches)
  eval_time = time.time() - start_time

  perplexity = math.exp(sum_losses / sum_weights)
  tf.logging.info("Perplexity = %f (%.2f sec)", perplexity, eval_time)

  # Log perplexity to the FileWriter.
  summary = tf.Summary()
  value = summary.value.add()
  value.simple_value = perplexity
  value.tag = "perplexity"
  summary_writer.add_summary(summary, global_step)

  # Write the Events file to the eval directory.
  summary_writer.flush()
  tf.logging.info("Finished processing evaluation at global step %d.",
                  global_step)
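
# A worked example of the perplexity computation above, with illustrative
# numbers (not from any real run): if the summed weighted cross-entropy over
# all eval batches is 200000.0 nats and the summed token weights are 50000.0,
# then perplexity = exp(200000.0 / 50000.0) = exp(4.0) ~= 54.6 per word.
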
def run_once(model, losses, weights, saver, summary_writer, summary_op):
  """Evaluates the latest model checkpoint.

  Args:
    model: Instance of SkipThoughtsModel; the model to evaluate.
    losses: Tensor; the target cross entropy losses for the current batch.
    weights: A Tensor of weights corresponding to losses.
    saver: Instance of tf.train.Saver for restoring model Variables.
    summary_writer: Instance of FileWriter.
    summary_op: Op for generating model summaries.
  """
  model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  if not model_path:
    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
                    FLAGS.checkpoint_dir)
    return

  with tf.Session() as sess:
    # Load model from checkpoint.
    tf.logging.info("Loading model from checkpoint: %s", model_path)
    saver.restore(sess, model_path)
    global_step = tf.train.global_step(sess, model.global_step.name)
    tf.logging.info("Successfully loaded %s at global step = %d.",
                    os.path.basename(model_path), global_step)
    if global_step < FLAGS.min_global_step:
      tf.logging.info("Skipping evaluation. Global step = %d < %d", global_step,
                      FLAGS.min_global_step)
      return

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    num_eval_batches = int(
        math.ceil(FLAGS.num_eval_examples / model.config.batch_size))

    # Run evaluation on the latest checkpoint.
    try:
      evaluate_model(sess, losses, weights, num_eval_batches, global_step,
                     summary_writer, summary_op)
    except tf.errors.InvalidArgumentError:
      # tf.errors is the stable home of TensorFlow's exception classes.
      tf.logging.error(
          "Evaluation raised InvalidArgumentError (e.g. due to NaNs).")
    finally:
      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=10)
def main(unused_argv):
  if not FLAGS.input_file_pattern:
    raise ValueError("--input_file_pattern is required.")
  if not FLAGS.checkpoint_dir:
    raise ValueError("--checkpoint_dir is required.")
  if not FLAGS.eval_dir:
    raise ValueError("--eval_dir is required.")

  # Create the evaluation directory if it doesn't exist.
  eval_dir = FLAGS.eval_dir
  if not tf.gfile.IsDirectory(eval_dir):
    tf.logging.info("Creating eval directory: %s", eval_dir)
    tf.gfile.MakeDirs(eval_dir)

  g = tf.Graph()
  with g.as_default():
    # Build the model for evaluation.
    model_config = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern,
        input_queue_capacity=FLAGS.num_eval_examples,
        shuffle_input_data=False)
    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="eval")
    model.build()

    losses = tf.concat(model.target_cross_entropy_losses, 0)
    weights = tf.concat(model.target_cross_entropy_loss_weights, 0)

    # Create the Saver to restore model Variables.
    saver = tf.train.Saver()

    # Create the summary operation and the summary writer.
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(eval_dir)

    g.finalize()

    # Run a new evaluation run every eval_interval_secs.
    while True:
      start = time.time()
      tf.logging.info("Starting evaluation at " + time.strftime(
          "%Y-%m-%d-%H:%M:%S", time.localtime()))
      run_once(model, losses, weights, saver, summary_writer, summary_op)
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)


if __name__ == "__main__":
  tf.app.run()
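
# Example invocation of this evaluation script (script and directory names
# are illustrative, not from the original source):
#
#   python track_perplexity.py \
#     --input_file_pattern="/tmp/skip_thoughts/validation-?????-of-00001" \
#     --checkpoint_dir="/tmp/skip_thoughts/train" \
#     --eval_dir="/tmp/skip_thoughts/eval"
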
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train the skip-thoughts model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from skip_thoughts import configuration
from skip_thoughts import skip_thoughts_model
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("input_file_pattern", None,
"File pattern of sharded TFRecord files containing "
"tf.Example protos.")
tf.flags.DEFINE_string("train_dir", None,
"Directory for saving and loading checkpoints.")
tf.logging.set_verbosity(tf.logging.INFO)
def _setup_learning_rate(config, global_step):
  """Sets up the learning rate with optional exponential decay.

  Args:
    config: Object containing learning rate configuration parameters.
    global_step: Tensor; the global step.

  Returns:
    learning_rate: Tensor; the learning rate with exponential decay.
  """
  if config.learning_rate_decay_factor > 0:
    learning_rate = tf.train.exponential_decay(
        learning_rate=float(config.learning_rate),
        global_step=global_step,
        decay_steps=config.learning_rate_decay_steps,
        decay_rate=config.learning_rate_decay_factor,
        staircase=False)
  else:
    learning_rate = tf.constant(config.learning_rate)
  return learning_rate
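
# With staircase=False, tf.train.exponential_decay computes
#
#   lr(step) = learning_rate * decay_rate ** (step / decay_steps).
#
# An illustrative example (numbers assumed, not the model's actual config):
# with learning_rate=0.008, decay_rate=0.5 and decay_steps=400000, the rate
# at global step 800000 is 0.008 * 0.5**2 = 0.002.
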
def main(unused_argv):
  if not FLAGS.input_file_pattern:
    raise ValueError("--input_file_pattern is required.")
  if not FLAGS.train_dir:
    raise ValueError("--train_dir is required.")

  model_config = configuration.model_config(
      input_file_pattern=FLAGS.input_file_pattern)
  training_config = configuration.training_config()

  tf.logging.info("Building training graph.")
  g = tf.Graph()
  with g.as_default():
    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="train")
    model.build()

    learning_rate = _setup_learning_rate(training_config, model.global_step)
    optimizer = tf.train.AdamOptimizer(learning_rate)

    train_tensor = tf.contrib.slim.learning.create_train_op(
        total_loss=model.total_loss,
        optimizer=optimizer,
        global_step=model.global_step,
        clip_gradient_norm=training_config.clip_gradient_norm)

    saver = tf.train.Saver()

  tf.contrib.slim.learning.train(
      train_op=train_tensor,
      logdir=FLAGS.train_dir,
      graph=g,
      global_step=model.global_step,
      number_of_steps=training_config.number_of_steps,
      save_summaries_secs=training_config.save_summaries_secs,
      saver=saver,
      save_interval_secs=training_config.save_model_secs)


if __name__ == "__main__":
  tf.app.run()
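
# Example invocation of this training script (script and directory names are
# illustrative):
#
#   python train.py \
#     --input_file_pattern="/tmp/skip_thoughts/train-?????-of-00100" \
#     --train_dir="/tmp/skip_thoughts/train"
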
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute an expanded vocabulary of embeddings using a word2vec model.
This script loads the word embeddings from a trained skip-thoughts model and
from a trained word2vec model (typically with a larger vocabulary). It trains a
linear regression model without regularization to learn a linear mapping from
the word2vec embedding space to the skip-thoughts embedding space. The model is
then applied to all words in the word2vec vocabulary, yielding vectors in the
skip-thoughts word embedding space for the union of the two vocabularies.
The linear regression task is to learn a parameter matrix W to minimize
|| X - Y * W ||^2,
where X is a matrix of skip-thoughts embeddings of shape [num_words, dim1],
Y is a matrix of word2vec embeddings of shape [num_words, dim2], and W is a
matrix of shape [dim2, dim1].
This is based on the "Translation Matrix" method from the paper:
"Exploiting Similarities among Languages for Machine Translation"
Tomas Mikolov, Quoc V. Le, Ilya Sutskever
https://arxiv.org/abs/1309.4168
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path

import gensim.models
import numpy as np
import sklearn.linear_model
import tensorflow as tf

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string("skip_thoughts_model", None,
                       "Checkpoint file or directory containing a checkpoint "
                       "file.")
tf.flags.DEFINE_string("skip_thoughts_vocab", None,
                       "Path to vocabulary file containing a list of newline-"
                       "separated words where the word id is the "
                       "corresponding 0-based index in the file.")
tf.flags.DEFINE_string("word2vec_model", None,
                       "File containing a word2vec model in binary format.")
tf.flags.DEFINE_string("output_dir", None, "Output directory.")

tf.logging.set_verbosity(tf.logging.INFO)
def _load_skip_thoughts_embeddings(checkpoint_path):
  """Loads the embedding matrix from a skip-thoughts model checkpoint.

  Args:
    checkpoint_path: Model checkpoint file or directory containing a checkpoint
      file.

  Returns:
    word_embedding: A numpy array of shape [vocab_size, embedding_dim].

  Raises:
    ValueError: If no checkpoint file matches checkpoint_path.
  """
  if tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
    if not checkpoint_file:
      raise ValueError("No checkpoint file found in %s" % checkpoint_path)
  else:
    checkpoint_file = checkpoint_path

  tf.logging.info("Loading skip-thoughts embedding matrix from %s",
                  checkpoint_file)
  reader = tf.train.NewCheckpointReader(checkpoint_file)
  word_embedding = reader.get_tensor("word_embedding")
  tf.logging.info("Loaded skip-thoughts embedding matrix of shape %s",
                  word_embedding.shape)

  return word_embedding
def _load_vocabulary(filename):
  """Loads a vocabulary file.

  Args:
    filename: Path to text file containing newline-separated words.

  Returns:
    vocab: A dictionary mapping word to word id.
  """
  tf.logging.info("Reading vocabulary from %s", filename)
  vocab = collections.OrderedDict()
  with tf.gfile.GFile(filename, mode="r") as f:
    for i, line in enumerate(f):
      # tf.compat.as_text decodes bytes under Python 2 and is a no-op on
      # Python 3 text, unlike a bare line.decode("utf-8").
      word = tf.compat.as_text(line).strip()
      assert word not in vocab, "Attempting to add word twice: %s" % word
      vocab[word] = i
  tf.logging.info("Read vocabulary of size %d", len(vocab))
  return vocab
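
# For example, a vocabulary file containing the three lines "the", "cat" and
# "sat" yields the mapping {"the": 0, "cat": 1, "sat": 2}.
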
def _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab, word2vec):
  """Runs vocabulary expansion on a skip-thoughts model using a word2vec model.

  Args:
    skip_thoughts_emb: A numpy array of shape [skip_thoughts_vocab_size,
      skip_thoughts_embedding_dim].
    skip_thoughts_vocab: A dictionary of word to id.
    word2vec: An instance of gensim.models.Word2Vec.

  Returns:
    combined_emb: A dictionary mapping words to embedding vectors.
  """
  # Find words shared between the two vocabularies.
  tf.logging.info("Finding shared words")
  shared_words = [w for w in word2vec.vocab if w in skip_thoughts_vocab]

  # Select embedding vectors for shared words.
  tf.logging.info("Selecting embeddings for %d shared words", len(shared_words))
  shared_st_emb = skip_thoughts_emb[[
      skip_thoughts_vocab[w] for w in shared_words
  ]]
  shared_w2v_emb = word2vec[shared_words]

  # Train a linear regression model on the shared embedding vectors.
  tf.logging.info("Training linear regression model")
  model = sklearn.linear_model.LinearRegression()
  model.fit(shared_w2v_emb, shared_st_emb)

  # Create the expanded vocabulary.
  tf.logging.info("Creating embeddings for expanded vocabulary")
  combined_emb = collections.OrderedDict()
  for w in word2vec.vocab:
    # Ignore words with underscores (spaces).
    if "_" not in w:
      w_emb = model.predict(word2vec[w].reshape(1, -1))
      combined_emb[w] = w_emb.reshape(-1)

  for w in skip_thoughts_vocab:
    combined_emb[w] = skip_thoughts_emb[skip_thoughts_vocab[w]]

  tf.logging.info("Created expanded vocabulary of %d words", len(combined_emb))

  return combined_emb
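
# The LinearRegression fit above solves the least-squares problem from the
# module docstring, || X - Y * W ||^2. A minimal numpy sketch of the same
# mapping (ignoring the intercept term that sklearn fits by default; variable
# names refer to the function above):
#
#   W = np.linalg.lstsq(shared_w2v_emb, shared_st_emb)[0]  # shape [dim2, dim1]
#   predicted_st_emb = shared_w2v_emb.dot(W)               # approximates X
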
def main(unused_argv):
  if not FLAGS.skip_thoughts_model:
    raise ValueError("--skip_thoughts_model is required.")
  if not FLAGS.skip_thoughts_vocab:
    raise ValueError("--skip_thoughts_vocab is required.")
  if not FLAGS.word2vec_model:
    raise ValueError("--word2vec_model is required.")
  if not FLAGS.output_dir:
    raise ValueError("--output_dir is required.")

  if not tf.gfile.IsDirectory(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # Load the skip-thoughts embeddings and vocabulary.
  skip_thoughts_emb = _load_skip_thoughts_embeddings(FLAGS.skip_thoughts_model)
  skip_thoughts_vocab = _load_vocabulary(FLAGS.skip_thoughts_vocab)

  # Load the Word2Vec model.
  word2vec = gensim.models.Word2Vec.load_word2vec_format(
      FLAGS.word2vec_model, binary=True)

  # Run vocabulary expansion.
  embedding_map = _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab,
                                     word2vec)

  # Save the output. list() keeps this working under Python 3, where keys()
  # and values() return views rather than lists.
  vocab = list(embedding_map.keys())
  vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
  with tf.gfile.GFile(vocab_file, "w") as f:
    f.write("\n".join(vocab))
  tf.logging.info("Wrote vocabulary file to %s", vocab_file)

  embeddings = np.array(list(embedding_map.values()))
  embeddings_file = os.path.join(FLAGS.output_dir, "embeddings.npy")
  np.save(embeddings_file, embeddings)
  tf.logging.info("Wrote embeddings file to %s", embeddings_file)


if __name__ == "__main__":
  tf.app.run()
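
# Example invocation (script and path names are illustrative):
#
#   python vocabulary_expansion.py \
#     --skip_thoughts_model="/tmp/skip_thoughts/train" \
#     --skip_thoughts_vocab="/tmp/skip_thoughts/vocab.txt" \
#     --word2vec_model="/tmp/GoogleNews-vectors-negative300.bin" \
#     --output_dir="/tmp/skip_thoughts/exp_vocab"
#
# Note: newer gensim releases moved the binary loader to
# gensim.models.KeyedVectors.load_word2vec_format; the call above matches the
# gensim API current when this script was written.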