Commit 68609ca7 authored by Christopher Shallue

TF implementation of Skip Thoughts.

parent 51fcc99b
@@ -21,6 +21,7 @@ To propose a model for inclusion please submit a pull request.
- [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
- [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
- [resnet](resnet): deep and wide residual networks.
- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
- [slim](slim): image classification models in TF-Slim.
- [street](street): identify the name of a street (in France) from an image using a Deep RNN.
- [swivel](swivel): the Swivel algorithm for generating word embeddings.
/bazel-bin
/bazel-ci_build-cache
/bazel-genfiles
/bazel-out
/bazel-skip_thoughts
/bazel-testlogs
/bazel-tf
*.pyc
# Skip-Thought Vectors
This is a TensorFlow implementation of the model described in:
Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
Antonio Torralba, Raquel Urtasun, Sanja Fidler.
[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf).
*In NIPS, 2015.*
## Contact
***Code author:*** Chris Shallue
***Pull requests and issues:*** @cshallue
## Contents
* [Model Overview](#model-overview)
* [Getting Started](#getting-started)
* [Install Required Packages](#install-required-packages)
* [Download Pretrained Models (Optional)](#download-pretrained-models-optional)
* [Training a Model](#training-a-model)
* [Prepare the Training Data](#prepare-the-training-data)
* [Run the Training Script](#run-the-training-script)
* [Track Training Progress](#track-training-progress)
* [Expanding the Vocabulary](#expanding-the-vocabulary)
* [Overview](#overview)
* [Preparation](#preparation)
* [Run the Vocabulary Expansion Script](#run-the-vocabulary-expansion-script)
* [Evaluating a Model](#evaluating-a-model)
* [Overview](#overview-1)
* [Preparation](#preparation-1)
* [Run the Evaluation Tasks](#run-the-evaluation-tasks)
* [Encoding Sentences](#encoding-sentences)
## Model Overview
The *Skip-Thoughts* model is a sentence encoder. It learns to encode input
sentences into a fixed-dimensional vector representation that is useful for many
tasks, for example to detect paraphrases or to classify whether a product review
is positive or negative. See the
[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
paper for details of the model architecture and more example applications.
A trained *Skip-Thoughts* model encodes similar sentences near each other in the
embedding vector space. The following examples show the nearest neighbor, by
cosine similarity, of some sentences from the
[movie review dataset](https://www.cs.cornell.edu/people/pabo/movie-review-data/).
| Input sentence | Nearest Neighbor |
|----------------|------------------|
| Simplistic, silly and tedious. | Trite, banal, cliched, mostly inoffensive. |
| Not so much farcical as sour. | Not only unfunny, but downright repellent. |
| A sensitive and astute first feature by Anne-Sophie Birot. | Absorbing character study by André Turpin . |
| An enthralling, entertaining feature. | A slick, engrossing melodrama. |
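As a quick illustration, the cosine similarity between two encoded sentences can
be computed directly from their vectors. The snippet below assumes an `encoder`
has already been loaded as described in [Encoding Sentences](#encoding-sentences);
the input sentences are arbitrary examples.
```python
import scipy.spatial.distance as sd

# Encode two sentences with a loaded encoder (see "Encoding Sentences" below).
vectors = encoder.encode([
    "an enthralling, entertaining feature.",
    "a slick, engrossing melodrama.",
])

# Cosine similarity is 1 minus cosine distance; values near 1 mean the
# sentences are encoded close together.
similarity = 1.0 - sd.cosine(vectors[0], vectors[1])
print("Cosine similarity: %.3f" % similarity)
```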
## Getting Started
### Install Required Packages
First ensure that you have installed the following required packages:
* **Bazel** ([instructions](http://bazel.build/docs/install.html))
* **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
* **NumPy** ([instructions](http://www.scipy.org/install.html))
* **scikit-learn** ([instructions](http://scikit-learn.org/stable/install.html))
* **Natural Language Toolkit (NLTK)**
* First install NLTK ([instructions](http://www.nltk.org/install.html))
* Then install the NLTK data ([instructions](http://www.nltk.org/data.html))
* **gensim** ([instructions](https://radimrehurek.com/gensim/install.html))
* Only required if you will be expanding your vocabulary with the [word2vec](https://code.google.com/archive/p/word2vec/) model.
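If you manage Python packages with pip, the pure-Python dependencies can
typically be installed as follows (TensorFlow and Bazel have their own
installation guides, linked above; adjust for your environment):
```shell
# Install the Python packages used by this implementation.
pip install numpy scikit-learn nltk gensim

# Download the NLTK "punkt" tokenizer data used for sentence splitting.
python -m nltk.downloader punkt
```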
### Download Pretrained Models (Optional)
You can download model checkpoints pretrained on the
[BookCorpus](http://yknzhu.wixsite.com/mbweb) dataset in the following
configurations:
* Unidirectional RNN encoder ("uni-skip" in the paper)
* Bidirectional RNN encoder ("bi-skip" in the paper)
```shell
# Directory to download the pretrained models to.
PRETRAINED_MODELS_DIR="${HOME}/skip_thoughts/pretrained/"
mkdir -p ${PRETRAINED_MODELS_DIR}
cd ${PRETRAINED_MODELS_DIR}
# Download and extract the unidirectional model.
wget "http://download.tensorflow.org/models/skip_thoughts_uni_2017_02_02.tar.gz"
tar -xvf skip_thoughts_uni_2017_02_02.tar.gz
rm skip_thoughts_uni_2017_02_02.tar.gz
# Download and extract the bidirectional model.
wget "http://download.tensorflow.org/models/skip_thoughts_bi_2017_02_16.tar.gz"
tar -xvf skip_thoughts_bi_2017_02_16.tar.gz
rm skip_thoughts_bi_2017_02_16.tar.gz
```
You can now skip to the sections [Evaluating a Model](#evaluating-a-model) and
[Encoding Sentences](#encoding-sentences).
## Training a Model
### Prepare the Training Data
To train a model you will need to provide training data in TFRecord format. The
TFRecord format consists of a set of sharded files containing serialized
`tf.Example` protocol buffers. Each `tf.Example` proto contains three
sentences:
* `encode`: The sentence to encode.
* `decode_pre`: The sentence preceding `encode` in the original text.
* `decode_post`: The sentence following `encode` in the original text.
Each sentence is a list of words. During preprocessing, a dictionary is created
that assigns each word in the vocabulary to an integer-valued id. Each sentence
is encoded as a list of integer word ids in the `tf.Example` protos.
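For reference, a single training example could be constructed roughly as follows
(a minimal sketch; the word ids are illustrative, and in practice the
preprocessing script below builds these records for you):
```python
import tensorflow as tf

def _int64_feature(values):
  # Wraps a list of integer word ids as an int64 feature.
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

# Word ids for three consecutive sentences (illustrative values only).
example = tf.train.Example(features=tf.train.Features(feature={
    "decode_pre": _int64_feature([12, 7, 43, 0]),
    "encode": _int64_feature([5, 18, 9, 0]),
    "decode_post": _int64_feature([22, 3, 1, 0]),
}))
serialized = example.SerializeToString()
```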
We have provided a script to preprocess any set of text files into this format.
You may wish to use the [BookCorpus](http://yknzhu.wixsite.com/mbweb) dataset.
Note that the preprocessing script may take **12 hours** or more to complete
on this large dataset.
```shell
# Comma-separated list of globs matching the input files. The format of
# the input files is assumed to be a list of newline-separated sentences, where
# each sentence is already tokenized.
INPUT_FILES="${HOME}/skip_thoughts/bookcorpus/*.txt"
# Location to save the preprocessed training and validation data.
DATA_DIR="${HOME}/skip_thoughts/data"
# Build the preprocessing script.
bazel build -c opt skip_thoughts/data/preprocess_dataset
# Run the preprocessing script.
bazel-bin/skip_thoughts/data/preprocess_dataset \
--input_files=${INPUT_FILES} \
--output_dir=${DATA_DIR}
```
When the script finishes you will find 100 training files and 1 validation file
in `DATA_DIR`. The files will match the patterns `train-?????-of-00100` and
`validation-00000-of-00001` respectively.
The script will also produce a file named `vocab.txt`. The format of this file
is a list of newline-separated words where the word id is the corresponding
0-based line index. Words are sorted in descending order of frequency in the
input data. Only the top 20,000 words are assigned unique ids; all other words
are assigned the "unknown id" of 1 in the processed data.
### Run the Training Script
Execute the following commands to start the training script. By default it will
run for 500k steps (around 9 days on a GeForce GTX 1080 GPU).
```shell
# Directory containing the preprocessed data.
DATA_DIR="${HOME}/skip_thoughts/data"
# Directory to save the model.
MODEL_DIR="${HOME}/skip_thoughts/model"
# Build the model.
bazel build -c opt skip_thoughts/...
# Run the training script.
bazel-bin/skip_thoughts/train \
--input_file_pattern="${DATA_DIR}/train-?????-of-00100" \
--train_dir="${MODEL_DIR}/train"
```
### Track Training Progress
Optionally, you can run the `track_perplexity` script in a separate process.
This will log per-word perplexity on the validation set which allows training
progress to be monitored on
[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
Note that you may run out of memory if you run this script on the same GPU
as the training script. You can set the environment variable
`CUDA_VISIBLE_DEVICES=""` to force the script to run on CPU. If it runs too
slowly on CPU, you can decrease the value of `--num_eval_examples`.
```shell
DATA_DIR="${HOME}/skip_thoughts/data"
MODEL_DIR="${HOME}/skip_thoughts/model"
# Ignore GPU devices (only necessary if your GPU is currently memory
# constrained, for example, by running the training script).
export CUDA_VISIBLE_DEVICES=""
# Run the evaluation script. This will run in a loop, periodically loading the
# latest model checkpoint file and computing evaluation metrics.
bazel-bin/skip_thoughts/track_perplexity \
--input_file_pattern="${DATA_DIR}/validation-?????-of-00001" \
--checkpoint_dir="${MODEL_DIR}/train" \
--eval_dir="${MODEL_DIR}/val" \
--num_eval_examples=50000
```
If you started the `track_perplexity` script, run a
[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
server in a separate process for real-time monitoring of training summaries and
validation perplexity.
```shell
MODEL_DIR="${HOME}/skip_thoughts/model"
# Run a TensorBoard server.
tensorboard --logdir="${MODEL_DIR}"
```
## Expanding the Vocabulary
### Overview
The vocabulary generated by the preprocessing script contains only 20,000 words,
which is insufficient for many tasks. For example, a sentence from Wikipedia
might contain nouns that do not appear in this vocabulary.
A solution to this problem described in the
[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
paper is to learn a mapping that transfers word representations from one model to
another. This idea is based on the "Translation Matrix" method from the paper
[Exploiting Similarities Among Languages for Machine Translation](https://arxiv.org/abs/1309.4168).
Specifically, we will load the word embeddings from a trained *Skip-Thoughts*
model and from a trained [word2vec model](https://arxiv.org/pdf/1301.3781.pdf)
(which has a much larger vocabulary). We will train a linear regression model
without regularization to learn a linear mapping from the word2vec embedding
space to the *Skip-Thoughts* embedding space. We will then apply the linear
model to all words in the word2vec vocabulary, yielding vectors in the
*Skip-Thoughts* word embedding space for the union of the two vocabularies.
The linear regression task is to learn a parameter matrix *W* to minimize
*|| X - Y \* W ||<sup>2</sup>*, where *X* is a matrix of *Skip-Thoughts*
embeddings of shape `[num_words, dim1]`, *Y* is a matrix of word2vec embeddings
of shape `[num_words, dim2]`, and *W* is a matrix of shape `[dim2, dim1]`.
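A minimal sketch of this regression using scikit-learn (already listed as a
required package) is shown below. The random arrays are placeholders for the
real embedding matrices; the dimensions match the default Skip-Thoughts word
embedding size (620) and the GoogleNews word2vec size (300).
```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Placeholders for the embeddings of the words shared by both vocabularies:
#   X: Skip-Thoughts word embeddings, shape [num_words, dim1].
#   Y: word2vec embeddings of the same words, shape [num_words, dim2].
num_words, dim1, dim2 = 10000, 620, 300
X = np.random.randn(num_words, dim1)
Y = np.random.randn(num_words, dim2)

# Ordinary least squares (no regularization) learns W minimizing ||X - Y * W||^2.
model = LinearRegression(fit_intercept=False)
model.fit(Y, X)

# Map any word2vec embedding into the Skip-Thoughts word embedding space.
mapped = model.predict(Y[:5])  # Shape: [5, dim1].
```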
### Preparation
First you will need to download and unpack a pretrained
[word2vec model](https://arxiv.org/pdf/1301.3781.pdf) from
[this website](https://code.google.com/archive/p/word2vec/)
([direct download link](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing)).
This model was trained on the Google News dataset (about 100 billion words).
Also ensure that you have already [installed gensim](https://radimrehurek.com/gensim/install.html).
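Once downloaded, you can sanity-check the file by loading it with gensim (the
path below is an assumption, and the exact API may vary slightly across gensim
versions):
```python
from gensim.models import KeyedVectors

# Load the binary word2vec model; this may take a few minutes and several GB
# of memory.
w2v = KeyedVectors.load_word2vec_format(
    "/path/to/GoogleNews-vectors-negative300.bin", binary=True)
print(w2v["cat"].shape)  # Each word maps to a 300-dimensional vector.
```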
### Run the Vocabulary Expansion Script
```shell
# Path to checkpoint file or a directory containing checkpoint files (the script
# will select the most recent).
CHECKPOINT_PATH="${HOME}/skip_thoughts/model/train"
# Vocabulary file generated by the preprocessing script.
SKIP_THOUGHTS_VOCAB="${HOME}/skip_thoughts/data/vocab.txt"
# Path to downloaded word2vec model.
WORD2VEC_MODEL="${HOME}/skip_thoughts/googlenews/GoogleNews-vectors-negative300.bin"
# Output directory.
EXP_VOCAB_DIR="${HOME}/skip_thoughts/exp_vocab"
# Build the vocabulary expansion script.
bazel build -c opt skip_thoughts/vocabulary_expansion
# Run the vocabulary expansion script.
bazel-bin/skip_thoughts/vocabulary_expansion \
--skip_thoughts_model=${CHECKPOINT_PATH} \
--skip_thoughts_vocab=${SKIP_THOUGHTS_VOCAB} \
--word2vec_model=${WORD2VEC_MODEL} \
--output_dir=${EXP_VOCAB_DIR}
```
## Evaluating a Model
### Overview
The model can be evaluated using the benchmark tasks described in the
[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
paper. The following tasks are supported (refer to the paper for full details):
* **SICK** semantic relatedness task.
* **MSRP** (Microsoft Research Paraphrase Corpus) paraphrase detection task.
* Binary classification tasks:
* **MR** movie review sentiment task.
* **CR** customer product review task.
* **SUBJ** subjectivity/objectivity task.
* **MPQA** opinion polarity task.
* **TREC** question-type classification task.
### Preparation
You will need to clone or download the
[skip-thoughts GitHub repository](https://github.com/ryankiros/skip-thoughts) by
[ryankiros](https://github.com/ryankiros) (the first author of the Skip-Thoughts
paper):
```shell
# Folder to clone the repository to.
ST_KIROS_DIR="${HOME}/skip_thoughts/skipthoughts_kiros"
# Clone the repository.
git clone git@github.com:ryankiros/skip-thoughts.git "${ST_KIROS_DIR}/skipthoughts"
# Make the package importable.
export PYTHONPATH="${ST_KIROS_DIR}/:${PYTHONPATH}"
```
You will also need to download the data needed for each evaluation task. See the
instructions [here](https://github.com/ryankiros/skip-thoughts).
For example, the CR (customer review) dataset is found [here](http://nlp.stanford.edu/~sidaw/home/projects:nbsvm). For this task we want the
files `custrev.pos` and `custrev.neg`.
### Run the Evaluation Tasks
In the following example we will evaluate a unidirectional model ("uni-skip" in
the paper) on the CR task. To use a bidirectional model ("bi-skip" in the
paper), simply pass the flags `--bi_vocab_file`, `--bi_embeddings_file` and
`--bi_checkpoint_path` instead. To use the "combine-skip" model described in the
paper you will need to pass both the unidirectional and bidirectional flags.
```shell
# Path to checkpoint file or a directory containing checkpoint files (the script
# will select the most recent).
CHECKPOINT_PATH="${HOME}/skip_thoughts/model/train"
# Vocabulary file generated by the vocabulary expansion script.
VOCAB_FILE="${HOME}/skip_thoughts/exp_vocab/vocab.txt"
# Embeddings file generated by the vocabulary expansion script.
EMBEDDINGS_FILE="${HOME}/skip_thoughts/exp_vocab/embeddings.npy"
# Directory containing files custrev.pos and custrev.neg.
EVAL_DATA_DIR="${HOME}/skip_thoughts/eval_data"
# Build the evaluation script.
bazel build -c opt skip_thoughts/evaluate
# Run the evaluation script.
bazel-bin/skip_thoughts/evaluate \
--eval_task=CR \
--data_dir=${EVAL_DATA_DIR} \
--uni_vocab_file=${VOCAB_FILE} \
--uni_embeddings_file=${EMBEDDINGS_FILE} \
--uni_checkpoint_path=${CHECKPOINT_PATH}
```
Output:
```python
[0.82539682539682535, 0.84084880636604775, 0.83023872679045096,
0.86206896551724133, 0.83554376657824936, 0.85676392572944293,
0.84084880636604775, 0.83023872679045096, 0.85145888594164454,
0.82758620689655171]
```
The output is a list of accuracies of 10 cross-validation classification models.
To get a single number, simply take the average:
```python
ipython # Launch IPython.
In [0]:
import numpy as np
np.mean([0.82539682539682535, 0.84084880636604775, 0.83023872679045096,
0.86206896551724133, 0.83554376657824936, 0.85676392572944293,
0.84084880636604775, 0.83023872679045096, 0.85145888594164454,
0.82758620689655171])
Out [0]: 0.84009936423729525
```
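As noted above, the "combine-skip" configuration is evaluated by passing the
unidirectional and bidirectional flags together. A sketch (the `bi_*` paths
below are assumptions about where the bidirectional model's files are stored):
```shell
bazel-bin/skip_thoughts/evaluate \
  --eval_task=CR \
  --data_dir=${EVAL_DATA_DIR} \
  --uni_vocab_file=${VOCAB_FILE} \
  --uni_embeddings_file=${EMBEDDINGS_FILE} \
  --uni_checkpoint_path=${CHECKPOINT_PATH} \
  --bi_vocab_file="${HOME}/skip_thoughts/exp_vocab_bi/vocab.txt" \
  --bi_embeddings_file="${HOME}/skip_thoughts/exp_vocab_bi/embeddings.npy" \
  --bi_checkpoint_path="${HOME}/skip_thoughts/bi_model/train"
```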
## Encoding Sentences
In this example we will encode data from the
[movie review dataset](https://www.cs.cornell.edu/people/pabo/movie-review-data/)
(specifically the [sentence polarity dataset v1.0](https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz)).
```python
ipython # Launch IPython.
In [0]:
# Imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os.path
import scipy.spatial.distance as sd
from skip_thoughts import configuration
from skip_thoughts import encoder_manager
In [1]:
# Set paths to the model.
VOCAB_FILE = "/path/to/vocab.txt"
EMBEDDING_MATRIX_FILE = "/path/to/embeddings.npy"
CHECKPOINT_PATH = "/path/to/model.ckpt-9999"
# The following directory should contain files rt-polarity.neg and
# rt-polarity.pos.
MR_DATA_DIR = "/dir/containing/mr/data"
In [2]:
# Set up the encoder. Here we are using a single unidirectional model.
# To use a bidirectional model as well, call load_model() again with
# configuration.model_config(bidirectional_encoder=True) and paths to the
# bidirectional model's files. The encoder will use the concatenation of
# all loaded models.
encoder = encoder_manager.EncoderManager()
encoder.load_model(configuration.model_config(),
                   vocabulary_file=VOCAB_FILE,
                   embedding_matrix_file=EMBEDDING_MATRIX_FILE,
                   checkpoint_path=CHECKPOINT_PATH)
In [3]:
# Load the movie review dataset.
data = []
with open(os.path.join(MR_DATA_DIR, 'rt-polarity.neg'), 'rb') as f:
  data.extend([line.decode('latin-1').strip() for line in f])
with open(os.path.join(MR_DATA_DIR, 'rt-polarity.pos'), 'rb') as f:
  data.extend([line.decode('latin-1').strip() for line in f])
In [4]:
# Generate Skip-Thought Vectors for each sentence in the dataset.
encodings = encoder.encode(data)
In [5]:
# Define a helper function to generate nearest neighbors.
def get_nn(ind, num=10):
  encoding = encodings[ind]
  scores = sd.cdist([encoding], encodings, "cosine")[0]
  sorted_ids = np.argsort(scores)
  print("Sentence:")
  print("", data[ind])
  print("\nNearest neighbors:")
  for i in range(1, num + 1):
    print(" %d. %s (%.3f)" %
          (i, data[sorted_ids[i]], scores[sorted_ids[i]]))
In [6]:
# Compute nearest neighbors of the first sentence in the dataset.
get_nn(0)
```
Output:
```
Sentence:
simplistic , silly and tedious .
Nearest neighbors:
1. trite , banal , cliched , mostly inoffensive . (0.247)
2. banal and predictable . (0.253)
3. witless , pointless , tasteless and idiotic . (0.272)
4. loud , silly , stupid and pointless . (0.295)
5. grating and tedious . (0.299)
6. idiotic and ugly . (0.330)
7. black-and-white and unrealistic . (0.335)
8. hopelessly inane , humorless and under-inspired . (0.335)
9. shallow , noisy and pretentious . (0.340)
10. . . . unlikable , uninteresting , unfunny , and completely , utterly inept . (0.346)
```
package(default_visibility = [":internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//skip_thoughts/...",
],
)
py_library(
name = "configuration",
srcs = ["configuration.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "skip_thoughts_model",
srcs = ["skip_thoughts_model.py"],
srcs_version = "PY2AND3",
deps = [
"//skip_thoughts/ops:gru_cell",
"//skip_thoughts/ops:input_ops",
],
)
py_test(
name = "skip_thoughts_model_test",
size = "large",
srcs = ["skip_thoughts_model_test.py"],
deps = [
":configuration",
":skip_thoughts_model",
],
)
py_binary(
name = "train",
srcs = ["train.py"],
srcs_version = "PY2AND3",
deps = [
":configuration",
":skip_thoughts_model",
],
)
py_binary(
name = "track_perplexity",
srcs = ["track_perplexity.py"],
srcs_version = "PY2AND3",
deps = [
":configuration",
":skip_thoughts_model",
],
)
py_binary(
name = "vocabulary_expansion",
srcs = ["vocabulary_expansion.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "skip_thoughts_encoder",
srcs = ["skip_thoughts_encoder.py"],
srcs_version = "PY2AND3",
deps = [
":skip_thoughts_model",
"//skip_thoughts/data:special_words",
],
)
py_library(
name = "encoder_manager",
srcs = ["encoder_manager.py"],
srcs_version = "PY2AND3",
deps = [
":skip_thoughts_encoder",
],
)
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
srcs_version = "PY2AND3",
deps = [
":encoder_manager",
"//skip_thoughts:configuration",
],
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default configuration for model architecture and training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class _HParams(object):
"""Wrapper for configuration parameters."""
pass
def model_config(input_file_pattern=None,
input_queue_capacity=640000,
num_input_reader_threads=1,
shuffle_input_data=True,
uniform_init_scale=0.1,
vocab_size=20000,
batch_size=128,
word_embedding_dim=620,
bidirectional_encoder=False,
encoder_dim=2400):
"""Creates a model configuration object.
Args:
input_file_pattern: File pattern of sharded TFRecord files containing
tf.Example protobufs.
input_queue_capacity: Number of examples to keep in the input queue.
num_input_reader_threads: Number of threads for prefetching input
tf.Examples.
shuffle_input_data: Whether to shuffle the input data.
uniform_init_scale: Scale of random uniform initializer.
vocab_size: Number of unique words in the vocab.
batch_size: Batch size (training and evaluation only).
word_embedding_dim: Word embedding dimension.
bidirectional_encoder: Whether to use a bidirectional or unidirectional
encoder RNN.
encoder_dim: Number of output dimensions of the sentence encoder.
Returns:
An object containing model configuration parameters.
"""
config = _HParams()
config.input_file_pattern = input_file_pattern
config.input_queue_capacity = input_queue_capacity
config.num_input_reader_threads = num_input_reader_threads
config.shuffle_input_data = shuffle_input_data
config.uniform_init_scale = uniform_init_scale
config.vocab_size = vocab_size
config.batch_size = batch_size
config.word_embedding_dim = word_embedding_dim
config.bidirectional_encoder = bidirectional_encoder
config.encoder_dim = encoder_dim
return config
def training_config(learning_rate=0.0008,
learning_rate_decay_factor=0.5,
learning_rate_decay_steps=400000,
number_of_steps=500000,
clip_gradient_norm=5.0,
save_model_secs=600,
save_summaries_secs=600):
"""Creates a training configuration object.
Args:
learning_rate: Initial learning rate.
learning_rate_decay_factor: If > 0, the learning rate decay factor.
learning_rate_decay_steps: The number of steps before the learning rate
decays by learning_rate_decay_factor.
number_of_steps: The total number of training steps to run. Passing None
will cause the training script to run indefinitely.
clip_gradient_norm: If not None, then clip gradients to this value.
save_model_secs: How often (in seconds) to save model checkpoints.
save_summaries_secs: How often (in seconds) to save model summaries.
Returns:
An object containing training configuration parameters.
Raises:
ValueError: If learning_rate_decay_factor is set and
learning_rate_decay_steps is unset.
"""
if learning_rate_decay_factor and not learning_rate_decay_steps:
raise ValueError(
"learning_rate_decay_factor requires learning_rate_decay_steps.")
config = _HParams()
config.learning_rate = learning_rate
config.learning_rate_decay_factor = learning_rate_decay_factor
config.learning_rate_decay_steps = learning_rate_decay_steps
config.number_of_steps = number_of_steps
config.clip_gradient_norm = clip_gradient_norm
config.save_model_secs = save_model_secs
config.save_summaries_secs = save_summaries_secs
return config
package(default_visibility = ["//skip_thoughts:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "special_words",
srcs = ["special_words.py"],
srcs_version = "PY2AND3",
deps = [],
)
py_binary(
name = "preprocess_dataset",
srcs = [
"preprocess_dataset.py",
],
srcs_version = "PY2AND3",
deps = [
":special_words",
],
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a set of text files to TFRecord format with Example protos.
Each Example proto in the output contains the following fields:
decode_pre: list of int64 ids corresponding to the "previous" sentence.
encode: list of int64 ids corresponding to the "current" sentence.
decode_post: list of int64 ids corresponding to the "post" sentence.
In addition, the following files are generated:
vocab.txt: List of "<word> <id>" pairs, where <id> is the integer
encoding of <word> in the Example protos.
word_counts.txt: List of "<word> <count>" pairs, where <count> is the number
of occurrences of <word> in the input files.
The vocabulary of word ids is constructed from the top --num_words by word
count. All other words get the <unk> word id.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import numpy as np
import tensorflow as tf
from skip_thoughts.data import special_words
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("input_files", None,
"Comma-separated list of globs matching the input "
"files. The format of the input files is assumed to be "
"a list of newline-separated sentences, where each "
"sentence is already tokenized.")
tf.flags.DEFINE_string("vocab_file", "",
"(Optional) existing vocab file. Otherwise, a new vocab "
"file is created and written to the output directory. "
"The file format is a list of newline-separated words, "
"where the word id is the corresponding 0-based index "
"in the file.")
tf.flags.DEFINE_string("output_dir", None, "Output directory.")
tf.flags.DEFINE_integer("train_output_shards", 100,
"Number of output shards for the training set.")
tf.flags.DEFINE_integer("validation_output_shards", 1,
"Number of output shards for the validation set.")
tf.flags.DEFINE_integer("num_validation_sentences", 50000,
"Number of output shards for the validation set.")
tf.flags.DEFINE_integer("num_words", 20000,
"Number of words to include in the output.")
tf.flags.DEFINE_integer("max_sentences", 0,
"If > 0, the maximum number of sentences to output.")
tf.flags.DEFINE_integer("max_sentence_length", 30,
"If > 0, exclude sentences whose encode, decode_pre OR"
"decode_post sentence exceeds this length.")
tf.flags.DEFINE_boolean("add_eos", True,
"Whether to add end-of-sentence ids to the output.")
tf.logging.set_verbosity(tf.logging.INFO)
def _build_vocabulary(input_files):
"""Loads or builds the model vocabulary.
Args:
input_files: List of pre-tokenized input .txt files.
Returns:
vocab: A dictionary of word to id.
"""
if FLAGS.vocab_file:
tf.logging.info("Loading existing vocab file.")
vocab = collections.OrderedDict()
with tf.gfile.GFile(FLAGS.vocab_file, mode="r") as f:
for i, line in enumerate(f):
word = line.decode("utf-8").strip()
assert word not in vocab, "Attempting to add word twice: %s" % word
vocab[word] = i
tf.logging.info("Read vocab of size %d from %s",
len(vocab), FLAGS.vocab_file)
return vocab
tf.logging.info("Creating vocabulary.")
num = 0
wordcount = collections.Counter()
for input_file in input_files:
tf.logging.info("Processing file: %s", input_file)
for sentence in tf.gfile.FastGFile(input_file):
wordcount.update(sentence.split())
num += 1
if num % 1000000 == 0:
tf.logging.info("Processed %d sentences", num)
tf.logging.info("Processed %d sentences total", num)
words = list(wordcount.keys())
freqs = list(wordcount.values())
sorted_indices = np.argsort(freqs)[::-1]
vocab = collections.OrderedDict()
vocab[special_words.EOS] = special_words.EOS_ID
vocab[special_words.UNK] = special_words.UNK_ID
for w_id, w_index in enumerate(sorted_indices[0:FLAGS.num_words - 2]):
vocab[words[w_index]] = w_id + 2 # 0: EOS, 1: UNK.
tf.logging.info("Created vocab with %d words", len(vocab))
vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
with tf.gfile.FastGFile(vocab_file, "w") as f:
f.write("\n".join(vocab.keys()))
tf.logging.info("Wrote vocab file to %s", vocab_file)
word_counts_file = os.path.join(FLAGS.output_dir, "word_counts.txt")
with tf.gfile.FastGFile(word_counts_file, "w") as f:
for i in sorted_indices:
f.write("%s %d\n" % (words[i], freqs[i]))
tf.logging.info("Wrote word counts file to %s", word_counts_file)
return vocab
def _int64_feature(value):
"""Helper for creating an Int64 Feature."""
return tf.train.Feature(int64_list=tf.train.Int64List(
value=[int(v) for v in value]))
def _sentence_to_ids(sentence, vocab):
"""Helper for converting a sentence (list of words) to a list of ids."""
ids = [vocab.get(w, special_words.UNK_ID) for w in sentence]
if FLAGS.add_eos:
ids.append(special_words.EOS_ID)
return ids
def _create_serialized_example(predecessor, current, successor, vocab):
"""Helper for creating a serialized Example proto."""
example = tf.train.Example(features=tf.train.Features(feature={
"decode_pre": _int64_feature(_sentence_to_ids(predecessor, vocab)),
"encode": _int64_feature(_sentence_to_ids(current, vocab)),
"decode_post": _int64_feature(_sentence_to_ids(successor, vocab)),
}))
return example.SerializeToString()
def _process_input_file(filename, vocab, stats):
"""Processes the sentences in an input file.
Args:
filename: Path to a pre-tokenized input .txt file.
vocab: A dictionary of word to id.
stats: A Counter object for statistics.
Returns:
processed: A list of serialized Example protos
"""
tf.logging.info("Processing input file: %s", filename)
processed = []
predecessor = None # Predecessor sentence (list of words).
current = None # Current sentence (list of words).
successor = None # Successor sentence (list of words).
for successor_str in tf.gfile.FastGFile(filename):
stats.update(["sentences_seen"])
successor = successor_str.split()
# The first 2 sentences per file will be skipped.
if predecessor and current and successor:
stats.update(["sentences_considered"])
# Note that we are going to insert <EOS> later, so we only allow
# sentences with strictly less than max_sentence_length to pass.
if FLAGS.max_sentence_length and (
len(predecessor) >= FLAGS.max_sentence_length or len(current) >=
FLAGS.max_sentence_length or len(successor) >=
FLAGS.max_sentence_length):
stats.update(["sentences_too_long"])
else:
serialized = _create_serialized_example(predecessor, current, successor,
vocab)
processed.append(serialized)
stats.update(["sentences_output"])
predecessor = current
current = successor
sentences_seen = stats["sentences_seen"]
sentences_output = stats["sentences_output"]
if sentences_seen and sentences_seen % 100000 == 0:
tf.logging.info("Processed %d sentences (%d output)", sentences_seen,
sentences_output)
if FLAGS.max_sentences and sentences_output >= FLAGS.max_sentences:
break
tf.logging.info("Completed processing file %s", filename)
return processed
def _write_shard(filename, dataset, indices):
"""Writes a TFRecord shard."""
with tf.python_io.TFRecordWriter(filename) as writer:
for j in indices:
writer.write(dataset[j])
def _write_dataset(name, dataset, indices, num_shards):
"""Writes a sharded TFRecord dataset.
Args:
name: Name of the dataset (e.g. "train").
dataset: List of serialized Example protos.
indices: List of indices of 'dataset' to be written.
num_shards: The number of output shards.
"""
tf.logging.info("Writing dataset %s", name)
borders = np.int32(np.linspace(0, len(indices), num_shards + 1))
for i in range(num_shards):
filename = os.path.join(FLAGS.output_dir, "%s-%.5d-of-%.5d" % (name, i,
num_shards))
shard_indices = indices[borders[i]:borders[i + 1]]
_write_shard(filename, dataset, shard_indices)
tf.logging.info("Wrote dataset indices [%d, %d) to output shard %s",
borders[i], borders[i + 1], filename)
tf.logging.info("Finished writing %d sentences in dataset %s.",
len(indices), name)
def main(unused_argv):
if not FLAGS.input_files:
raise ValueError("--input_files is required.")
if not FLAGS.output_dir:
raise ValueError("--output_dir is required.")
if not tf.gfile.IsDirectory(FLAGS.output_dir):
tf.gfile.MakeDirs(FLAGS.output_dir)
input_files = []
for pattern in FLAGS.input_files.split(","):
match = tf.gfile.Glob(pattern)
if not match:
raise ValueError("Found no files matching %s" % pattern)
input_files.extend(match)
tf.logging.info("Found %d input files.", len(input_files))
vocab = _build_vocabulary(input_files)
tf.logging.info("Generating dataset.")
stats = collections.Counter()
dataset = []
for filename in input_files:
dataset.extend(_process_input_file(filename, vocab, stats))
if FLAGS.max_sentences and stats["sentences_output"] >= FLAGS.max_sentences:
break
tf.logging.info("Generated dataset with %d sentences.", len(dataset))
for k, v in stats.items():
tf.logging.info("%s: %d", k, v)
tf.logging.info("Shuffling dataset.")
np.random.seed(123)
shuffled_indices = np.random.permutation(len(dataset))
val_indices = shuffled_indices[:FLAGS.num_validation_sentences]
train_indices = shuffled_indices[FLAGS.num_validation_sentences:]
_write_dataset("train", dataset, train_indices, FLAGS.train_output_shards)
_write_dataset("validation", dataset, val_indices,
FLAGS.validation_output_shards)
if __name__ == "__main__":
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Special word constants.
NOTE: The ids of the EOS and UNK constants should not be modified. It is assumed
that these always occupy the first two ids.
"""
# End of sentence.
EOS = "<eos>"
EOS_ID = 0
# Unknown.
UNK = "<unk>"
UNK_ID = 1
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Manager class for loading and encoding with multiple skip-thoughts models.
If multiple models are loaded at once then the encode() function returns the
concatenation of the outputs of each model.
Example usage:
manager = EncoderManager()
manager.load_model(model_config_1, vocabulary_file_1, embedding_matrix_file_1,
checkpoint_path_1)
manager.load_model(model_config_2, vocabulary_file_2, embedding_matrix_file_2,
checkpoint_path_2)
encodings = manager.encode(data)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import tensorflow as tf
from skip_thoughts import skip_thoughts_encoder
class EncoderManager(object):
"""Manager class for loading and encoding with skip-thoughts models."""
def __init__(self):
self.encoders = []
self.sessions = []
def load_model(self, model_config, vocabulary_file, embedding_matrix_file,
checkpoint_path):
"""Loads a skip-thoughts model.
Args:
model_config: Object containing parameters for building the model.
vocabulary_file: Path to vocabulary file containing a list of newline-
separated words where the word id is the corresponding 0-based index in
the file.
embedding_matrix_file: Path to a serialized numpy array of shape
[vocab_size, embedding_dim].
checkpoint_path: SkipThoughtsModel checkpoint file or a directory
containing a checkpoint file.
"""
tf.logging.info("Reading vocabulary from %s", vocabulary_file)
with tf.gfile.GFile(vocabulary_file, mode="r") as f:
lines = list(f.readlines())
reverse_vocab = [line.decode("utf-8").strip() for line in lines]
tf.logging.info("Loaded vocabulary with %d words.", len(reverse_vocab))
tf.logging.info("Loading embedding matrix from %s", embedding_matrix_file)
# Note: tf.gfile.GFile doesn't work here because np.load() calls f.seek()
# with 3 arguments.
with open(embedding_matrix_file, "rb") as f:
embedding_matrix = np.load(f)
tf.logging.info("Loaded embedding matrix with shape %s",
embedding_matrix.shape)
word_embeddings = collections.OrderedDict(
zip(reverse_vocab, embedding_matrix))
g = tf.Graph()
with g.as_default():
encoder = skip_thoughts_encoder.SkipThoughtsEncoder(word_embeddings)
restore_model = encoder.build_graph_from_config(model_config,
checkpoint_path)
sess = tf.Session(graph=g)
restore_model(sess)
self.encoders.append(encoder)
self.sessions.append(sess)
def encode(self,
data,
use_norm=True,
verbose=False,
batch_size=128,
use_eos=False):
"""Encodes a sequence of sentences as skip-thought vectors.
Args:
data: A list of input strings.
use_norm: If True, normalize output skip-thought vectors to unit L2 norm.
verbose: Whether to log every batch.
batch_size: Batch size for the RNN encoders.
use_eos: If True, append the end-of-sentence word to each input sentence.
Returns:
thought_vectors: A list of numpy arrays corresponding to 'data'.
Raises:
ValueError: If called before calling load_model.
"""
if not self.encoders:
raise ValueError(
"Must call load_model at least once before calling encode.")
encoded = []
for encoder, sess in zip(self.encoders, self.sessions):
encoded.append(
np.array(
encoder.encode(
sess,
data,
use_norm=use_norm,
verbose=verbose,
batch_size=batch_size,
use_eos=use_eos)))
return np.concatenate(encoded, axis=1)
def close(self):
"""Closes the active TensorFlow Sessions."""
for sess in self.sessions:
sess.close()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to evaluate a skip-thoughts model.
This script can evaluate a model with a unidirectional encoder ("uni-skip" in
the paper); or a model with a bidirectional encoder ("bi-skip"); or the
combination of a model with a unidirectional encoder and a model with a
bidirectional encoder ("combine-skip").
The uni-skip model (if it exists) is specified by the flags
--uni_vocab_file, --uni_embeddings_file, --uni_checkpoint_path.
The bi-skip model (if it exists) is specified by the flags
--bi_vocab_file, --bi_embeddings_file, --bi_checkpoint_path.
The evaluation tasks have different running times. SICK may take 5-10 minutes.
MSRP, TREC and CR may take 20-60 minutes. SUBJ, MPQA and MR may take 2+ hours.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from skipthoughts import eval_classification
from skipthoughts import eval_msrp
from skipthoughts import eval_sick
from skipthoughts import eval_trec
import tensorflow as tf
from skip_thoughts import configuration
from skip_thoughts import encoder_manager
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("eval_task", "CR",
"Name of the evaluation task to run. Available tasks: "
"MR, CR, SUBJ, MPQA, SICK, MSRP, TREC.")
tf.flags.DEFINE_string("data_dir", None, "Directory containing training data.")
tf.flags.DEFINE_string("uni_vocab_file", None,
"Path to vocabulary file containing a list of newline-"
"separated words where the word id is the "
"corresponding 0-based index in the file.")
tf.flags.DEFINE_string("bi_vocab_file", None,
"Path to vocabulary file containing a list of newline-"
"separated words where the word id is the "
"corresponding 0-based index in the file.")
tf.flags.DEFINE_string("uni_embeddings_file", None,
"Path to serialized numpy array of shape "
"[vocab_size, embedding_dim].")
tf.flags.DEFINE_string("bi_embeddings_file", None,
"Path to serialized numpy array of shape "
"[vocab_size, embedding_dim].")
tf.flags.DEFINE_string("uni_checkpoint_path", None,
"Checkpoint file or directory containing a checkpoint "
"file.")
tf.flags.DEFINE_string("bi_checkpoint_path", None,
"Checkpoint file or directory containing a checkpoint "
"file.")
tf.logging.set_verbosity(tf.logging.INFO)
def main(unused_argv):
if not FLAGS.data_dir:
raise ValueError("--data_dir is required.")
encoder = encoder_manager.EncoderManager()
# Maybe load unidirectional encoder.
if FLAGS.uni_checkpoint_path:
print("Loading unidirectional model...")
uni_config = configuration.model_config()
encoder.load_model(uni_config, FLAGS.uni_vocab_file,
FLAGS.uni_embeddings_file, FLAGS.uni_checkpoint_path)
# Maybe load bidirectional encoder.
if FLAGS.bi_checkpoint_path:
print("Loading bidirectional model...")
bi_config = configuration.model_config(bidirectional_encoder=True)
encoder.load_model(bi_config, FLAGS.bi_vocab_file, FLAGS.bi_embeddings_file,
FLAGS.bi_checkpoint_path)
if FLAGS.eval_task in ["MR", "CR", "SUBJ", "MPQA"]:
eval_classification.eval_nested_kfold(
encoder, FLAGS.eval_task, FLAGS.data_dir, use_nb=False)
elif FLAGS.eval_task == "SICK":
eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir)
elif FLAGS.eval_task == "MSRP":
eval_msrp.evaluate(
encoder, evalcv=True, evaltest=True, use_feats=True, loc=FLAGS.data_dir)
elif FLAGS.eval_task == "TREC":
eval_trec.evaluate(encoder, evalcv=True, evaltest=True, loc=FLAGS.data_dir)
else:
raise ValueError("Unrecognized eval_task: %s" % FLAGS.eval_task)
encoder.close()
if __name__ == "__main__":
tf.app.run()
package(default_visibility = ["//skip_thoughts:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "input_ops",
srcs = ["input_ops.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "gru_cell",
srcs = ["gru_cell.py"],
srcs_version = "PY2AND3",
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GRU cell implementation for the skip-thought vectors model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
_layer_norm = tf.contrib.layers.layer_norm
class LayerNormGRUCell(tf.contrib.rnn.RNNCell):
"""GRU cell with layer normalization.
The layer normalization implementation is based on:
https://arxiv.org/abs/1607.06450.
"Layer Normalization"
Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
"""
def __init__(self,
num_units,
w_initializer,
u_initializer,
b_initializer,
activation=tf.nn.tanh):
"""Initializes the cell.
Args:
num_units: Number of cell units.
w_initializer: Initializer for the "W" (input) parameter matrices.
u_initializer: Initializer for the "U" (recurrent) parameter matrices.
b_initializer: Initializer for the "b" (bias) parameter vectors.
activation: Cell activation function.
"""
self._num_units = num_units
self._w_initializer = w_initializer
self._u_initializer = u_initializer
self._b_initializer = b_initializer
self._activation = activation
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def _w_h_initializer(self):
"""Returns an initializer for the "W_h" parameter matrix.
See equation (23) in the paper. The "W_h" parameter matrix is the
concatenation of two parameter submatrices. The matrix returned is
[U_z, U_r].
Returns:
A Tensor with shape [num_units, 2 * num_units] as described above.
"""
def _initializer(shape, dtype=tf.float32, partition_info=None):
num_units = self._num_units
assert shape == [num_units, 2 * num_units]
u_z = self._u_initializer([num_units, num_units], dtype, partition_info)
u_r = self._u_initializer([num_units, num_units], dtype, partition_info)
return tf.concat([u_z, u_r], 1)
return _initializer
def _w_x_initializer(self, input_dim):
"""Returns an initializer for the "W_x" parameter matrix.
See equation (23) in the paper. The "W_x" parameter matrix is the
concatenation of two parameter submatrices. The matrix returned is
[W_z, W_r].
Args:
input_dim: The dimension of the cell inputs.
Returns:
A Tensor with shape [input_dim, 2 * num_units] as described above.
"""
def _initializer(shape, dtype=tf.float32, partition_info=None):
num_units = self._num_units
assert shape == [input_dim, 2 * num_units]
w_z = self._w_initializer([input_dim, num_units], dtype, partition_info)
w_r = self._w_initializer([input_dim, num_units], dtype, partition_info)
return tf.concat([w_z, w_r], 1)
return _initializer
def __call__(self, inputs, state, scope=None):
"""GRU cell with layer normalization."""
input_dim = inputs.get_shape().as_list()[1]
num_units = self._num_units
with tf.variable_scope(scope or "gru_cell"):
with tf.variable_scope("gates"):
w_h = tf.get_variable(
"w_h", [num_units, 2 * num_units],
initializer=self._w_h_initializer())
w_x = tf.get_variable(
"w_x", [input_dim, 2 * num_units],
initializer=self._w_x_initializer(input_dim))
z_and_r = (_layer_norm(tf.matmul(state, w_h), scope="layer_norm/w_h") +
_layer_norm(tf.matmul(inputs, w_x), scope="layer_norm/w_x"))
z, r = tf.split(tf.sigmoid(z_and_r), 2, 1)
with tf.variable_scope("candidate"):
w = tf.get_variable(
"w", [input_dim, num_units], initializer=self._w_initializer)
u = tf.get_variable(
"u", [num_units, num_units], initializer=self._u_initializer)
h_hat = (r * _layer_norm(tf.matmul(state, u), scope="layer_norm/u") +
_layer_norm(tf.matmul(inputs, w), scope="layer_norm/w"))
new_h = (1 - z) * state + z * self._activation(h_hat)
return new_h, new_h
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow as tf
# A SentenceBatch is a pair of Tensors:
# ids: Batch of input sentences represented as sequences of word ids: an int64
# Tensor with shape [batch_size, padded_length].
# mask: Boolean mask distinguishing real words (1) from padded words (0): an
# int32 Tensor with shape [batch_size, padded_length].
SentenceBatch = collections.namedtuple("SentenceBatch", ("ids", "mask"))
def parse_example_batch(serialized):
"""Parses a batch of tf.Example protos.
Args:
serialized: A 1-D string Tensor; a batch of serialized tf.Example protos.
Returns:
encode: A SentenceBatch of encode sentences.
decode_pre: A SentenceBatch of "previous" sentences to decode.
decode_post: A SentenceBatch of "post" sentences to decode.
"""
features = tf.parse_example(
serialized,
features={
"encode": tf.VarLenFeature(dtype=tf.int64),
"decode_pre": tf.VarLenFeature(dtype=tf.int64),
"decode_post": tf.VarLenFeature(dtype=tf.int64),
})
def _sparse_to_batch(sparse):
ids = tf.sparse_tensor_to_dense(sparse) # Padding with zeroes.
mask = tf.sparse_to_dense(sparse.indices, sparse.dense_shape,
tf.ones_like(sparse.values, dtype=tf.int32))
return SentenceBatch(ids=ids, mask=mask)
output_names = ("encode", "decode_pre", "decode_post")
return tuple(_sparse_to_batch(features[x]) for x in output_names)
def prefetch_input_data(reader,
file_pattern,
shuffle,
capacity,
num_reader_threads=1):
"""Prefetches string values from disk into an input queue.
Args:
reader: Instance of tf.ReaderBase.
file_pattern: Comma-separated list of file patterns (e.g.
"/tmp/train_data-?????-of-00100", where '?' acts as a wildcard that
matches any character).
shuffle: Boolean; whether to randomly shuffle the input data.
capacity: Queue capacity (number of records).
num_reader_threads: Number of reader threads feeding into the queue.
Returns:
A Queue containing prefetched string values.
"""
data_files = []
for pattern in file_pattern.split(","):
data_files.extend(tf.gfile.Glob(pattern))
if not data_files:
tf.logging.fatal("Found no input files matching %s", file_pattern)
else:
tf.logging.info("Prefetching values from %d files matching %s",
len(data_files), file_pattern)
filename_queue = tf.train.string_input_producer(
data_files, shuffle=shuffle, capacity=16, name="filename_queue")
if shuffle:
min_after_dequeue = int(0.6 * capacity)
values_queue = tf.RandomShuffleQueue(
capacity=capacity,
min_after_dequeue=min_after_dequeue,
dtypes=[tf.string],
shapes=[[]],
name="random_input_queue")
else:
values_queue = tf.FIFOQueue(
capacity=capacity,
dtypes=[tf.string],
shapes=[[]],
name="fifo_input_queue")
enqueue_ops = []
for _ in range(num_reader_threads):
_, value = reader.read(filename_queue)
enqueue_ops.append(values_queue.enqueue([value]))
tf.train.queue_runner.add_queue_runner(
tf.train.queue_runner.QueueRunner(values_queue, enqueue_ops))
tf.summary.scalar("queue/%s/fraction_of_%d_full" % (values_queue.name,
capacity),
tf.cast(values_queue.size(), tf.float32) * (1.0 / capacity))
return values_queue
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class for encoding text using a trained SkipThoughtsModel.
Example usage:
g = tf.Graph()
with g.as_default():
encoder = SkipThoughtsEncoder(embeddings)
restore_fn = encoder.build_graph_from_config(model_config, checkpoint_path)
with tf.Session(graph=g) as sess:
restore_fn(sess)
skip_thought_vectors = encoder.encode(sess, data)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import nltk
import nltk.tokenize
import numpy as np
import tensorflow as tf
from skip_thoughts import skip_thoughts_model
from skip_thoughts.data import special_words
def _pad(seq, target_len):
"""Pads a sequence of word embeddings up to the target length.
Args:
seq: Sequence of word embeddings.
target_len: Desired padded sequence length.
Returns:
embeddings: Input sequence padded with zero embeddings up to the target
length.
mask: A 0/1 vector with zeros corresponding to padded embeddings.
Raises:
ValueError: If len(seq) is not in the interval (0, target_len].
"""
seq_len = len(seq)
if seq_len <= 0 or seq_len > target_len:
raise ValueError("Expected 0 < len(seq) <= %d, got %d" % (target_len,
seq_len))
emb_dim = seq[0].shape[0]
padded_seq = np.zeros(shape=(target_len, emb_dim), dtype=seq[0].dtype)
mask = np.zeros(shape=(target_len,), dtype=np.int8)
for i in range(seq_len):
padded_seq[i] = seq[i]
mask[i] = 1
return padded_seq, mask
def _batch_and_pad(sequences):
"""Batches and pads sequences of word embeddings into a 2D array.
Args:
sequences: A list of batch_size sequences of word embeddings.
Returns:
embeddings: A numpy array with shape [batch_size, padded_length, emb_dim].
mask: A numpy 0/1 array with shape [batch_size, padded_length] with zeros
corresponding to padded elements.
"""
batch_embeddings = []
batch_mask = []
batch_len = max([len(seq) for seq in sequences])
for seq in sequences:
embeddings, mask = _pad(seq, batch_len)
batch_embeddings.append(embeddings)
batch_mask.append(mask)
return np.array(batch_embeddings), np.array(batch_mask)
class SkipThoughtsEncoder(object):
"""Skip-thoughts sentence encoder."""
def __init__(self, embeddings):
"""Initializes the encoder.
Args:
embeddings: Dictionary of word to embedding vector (1D numpy array).
"""
self._sentence_detector = nltk.data.load("tokenizers/punkt/english.pickle")
self._embeddings = embeddings
def _create_restore_fn(self, checkpoint_path, saver):
"""Creates a function that restores a model from checkpoint.
Args:
checkpoint_path: Checkpoint file or a directory containing a checkpoint
file.
saver: Saver for restoring variables from the checkpoint file.
Returns:
restore_fn: A function such that restore_fn(sess) loads model variables
from the checkpoint file.
Raises:
ValueError: If checkpoint_path does not refer to a checkpoint file or a
directory containing a checkpoint file.
"""
if tf.gfile.IsDirectory(checkpoint_path):
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if not latest_checkpoint:
raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
checkpoint_path = latest_checkpoint
def _restore_fn(sess):
tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
saver.restore(sess, checkpoint_path)
tf.logging.info("Successfully loaded checkpoint: %s",
os.path.basename(checkpoint_path))
return _restore_fn
def build_graph_from_config(self, model_config, checkpoint_path):
"""Builds the inference graph from a configuration object.
Args:
model_config: Object containing configuration for building the model.
checkpoint_path: Checkpoint file or a directory containing a checkpoint
file.
Returns:
restore_fn: A function such that restore_fn(sess) loads model variables
from the checkpoint file.
"""
tf.logging.info("Building model.")
model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="encode")
model.build()
saver = tf.train.Saver()
return self._create_restore_fn(checkpoint_path, saver)
def build_graph_from_proto(self, graph_def_file, saver_def_file,
checkpoint_path):
"""Builds the inference graph from serialized GraphDef and SaverDef protos.
Args:
graph_def_file: File containing a serialized GraphDef proto.
saver_def_file: File containing a serialized SaverDef proto.
checkpoint_path: Checkpoint file or a directory containing a checkpoint
file.
Returns:
restore_fn: A function such that restore_fn(sess) loads model variables
from the checkpoint file.
"""
# Load the Graph.
tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(graph_def_file, "rb") as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")
# Load the Saver.
tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
saver_def = tf.train.SaverDef()
with tf.gfile.FastGFile(saver_def_file, "rb") as f:
saver_def.ParseFromString(f.read())
saver = tf.train.Saver(saver_def=saver_def)
return self._create_restore_fn(checkpoint_path, saver)
def _tokenize(self, item):
"""Tokenizes an input string into a list of words."""
tokenized = []
for s in self._sentence_detector.tokenize(item):
tokenized.extend(nltk.tokenize.word_tokenize(s))
return tokenized
def _word_to_embedding(self, w):
"""Returns the embedding of a word."""
return self._embeddings.get(w, self._embeddings[special_words.UNK])
def _preprocess(self, data, use_eos):
"""Preprocesses text for the encoder.
Args:
data: A list of input strings.
use_eos: Whether to append the end-of-sentence word to each sentence.
Returns:
embeddings: A list of word embedding sequences corresponding to the input
strings.
"""
preprocessed_data = []
for item in data:
tokenized = self._tokenize(item)
if use_eos:
tokenized.append(special_words.EOS)
preprocessed_data.append([self._word_to_embedding(w) for w in tokenized])
return preprocessed_data
def encode(self,
sess,
data,
use_norm=True,
verbose=True,
batch_size=128,
use_eos=False):
"""Encodes a sequence of sentences as skip-thought vectors.
Args:
sess: TensorFlow Session.
data: A list of input strings.
use_norm: Whether to normalize skip-thought vectors to unit L2 norm.
verbose: Whether to log every batch.
batch_size: Batch size for the encoder.
use_eos: Whether to append the end-of-sentence word to each input
sentence.
Returns:
thought_vectors: A list of numpy arrays corresponding to the skip-thought
encodings of sentences in 'data'.
"""
data = self._preprocess(data, use_eos)
thought_vectors = []
batch_indices = np.arange(0, len(data), batch_size)
for batch, start_index in enumerate(batch_indices):
if verbose:
tf.logging.info("Batch %d / %d.", batch, len(batch_indices))
embeddings, mask = _batch_and_pad(
data[start_index:start_index + batch_size])
feed_dict = {
"encode_emb:0": embeddings,
"encode_mask:0": mask,
}
thought_vectors.extend(
sess.run("encoder/thought_vectors:0", feed_dict=feed_dict))
if use_norm:
thought_vectors = [v / np.linalg.norm(v) for v in thought_vectors]
return thought_vectors
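# Illustrative usage sketch (not part of the original file); 'model_config',
# 'embeddings' and the checkpoint path below are hypothetical placeholders.
#
#   encoder = SkipThoughtsEncoder(embeddings)  # Dict of word -> 1D numpy vector.
#   restore_fn = encoder.build_graph_from_config(model_config, "/path/to/ckpt")
#   with tf.Session() as sess:
#     restore_fn(sess)
#     vectors = encoder.encode(sess, ["I liked the movie.", "The film was great."])
#   # With use_norm=True (the default), each returned vector has unit L2 norm.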
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-Thoughts model for learning sentence vectors.
The model is based on the paper:
"Skip-Thought Vectors"
Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
Antonio Torralba, Raquel Urtasun, Sanja Fidler.
https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf
Layer normalization is applied based on the paper:
"Layer Normalization"
Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
https://arxiv.org/abs/1607.06450
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from skip_thoughts.ops import gru_cell
from skip_thoughts.ops import input_ops
def random_orthonormal_initializer(shape, dtype=tf.float32,
partition_info=None): # pylint: disable=unused-argument
"""Variable initializer that produces a random orthonormal matrix."""
if len(shape) != 2 or shape[0] != shape[1]:
raise ValueError("Expecting square shape, got %s" % shape)
_, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
return u
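# Illustrative sanity check (not in the original file): the returned matrix is
# orthonormal, so U^T U is numerically the identity (a numpy import is assumed).
#
#   u = random_orthonormal_initializer([4, 4])
#   with tf.Session() as sess:
#     u_val = sess.run(u)
#   np.testing.assert_allclose(u_val.T.dot(u_val), np.eye(4), atol=1e-5)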
class SkipThoughtsModel(object):
"""Skip-thoughts model."""
def __init__(self, config, mode="train", input_reader=None):
"""Basic setup. The actual TensorFlow graph is constructed in build().
Args:
config: Object containing configuration parameters.
mode: "train", "eval" or "encode".
input_reader: Subclass of tf.ReaderBase for reading the input serialized
tf.Example protocol buffers. Defaults to TFRecordReader.
Raises:
ValueError: If mode is invalid.
"""
if mode not in ["train", "eval", "encode"]:
raise ValueError("Unrecognized mode: %s" % mode)
self.config = config
self.mode = mode
self.reader = input_reader if input_reader else tf.TFRecordReader()
# Initializer used for non-recurrent weights.
self.uniform_initializer = tf.random_uniform_initializer(
minval=-self.config.uniform_init_scale,
maxval=self.config.uniform_init_scale)
# Input sentences represented as sequences of word ids. "encode" is the
# source sentence, "decode_pre" is the previous sentence and "decode_post"
# is the next sentence.
# Each is an int64 Tensor with shape [batch_size, padded_length].
self.encode_ids = None
self.decode_pre_ids = None
self.decode_post_ids = None
# Boolean masks distinguishing real words (1) from padded words (0).
# Each is an int32 Tensor with shape [batch_size, padded_length].
self.encode_mask = None
self.decode_pre_mask = None
self.decode_post_mask = None
# Input sentences represented as sequences of word embeddings.
# Each is a float32 Tensor with shape [batch_size, padded_length, emb_dim].
self.encode_emb = None
self.decode_pre_emb = None
self.decode_post_emb = None
# The output from the sentence encoder.
# A float32 Tensor with shape [batch_size, num_gru_units].
self.thought_vectors = None
# The cross entropy losses and corresponding weights of the decoders. Used
# for evaluation.
self.target_cross_entropy_losses = []
self.target_cross_entropy_loss_weights = []
# The total loss to optimize.
self.total_loss = None
def build_inputs(self):
"""Builds the ops for reading input data.
Outputs:
self.encode_ids
self.decode_pre_ids
self.decode_post_ids
self.encode_mask
self.decode_pre_mask
self.decode_post_mask
"""
if self.mode == "encode":
# Word embeddings are fed from an external vocabulary which has possibly
# been expanded (see vocabulary_expansion.py).
encode_ids = None
decode_pre_ids = None
decode_post_ids = None
encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
decode_pre_mask = None
decode_post_mask = None
else:
# Prefetch serialized tf.Example protos.
input_queue = input_ops.prefetch_input_data(
self.reader,
self.config.input_file_pattern,
shuffle=self.config.shuffle_input_data,
capacity=self.config.input_queue_capacity,
num_reader_threads=self.config.num_input_reader_threads)
# Deserialize a batch.
serialized = input_queue.dequeue_many(self.config.batch_size)
encode, decode_pre, decode_post = input_ops.parse_example_batch(
serialized)
encode_ids = encode.ids
decode_pre_ids = decode_pre.ids
decode_post_ids = decode_post.ids
encode_mask = encode.mask
decode_pre_mask = decode_pre.mask
decode_post_mask = decode_post.mask
self.encode_ids = encode_ids
self.decode_pre_ids = decode_pre_ids
self.decode_post_ids = decode_post_ids
self.encode_mask = encode_mask
self.decode_pre_mask = decode_pre_mask
self.decode_post_mask = decode_post_mask
def build_word_embeddings(self):
"""Builds the word embeddings.
Inputs:
self.encode_ids
self.decode_pre_ids
self.decode_post_ids
Outputs:
self.encode_emb
self.decode_pre_emb
self.decode_post_emb
"""
if self.mode == "encode":
# Word embeddings are fed from an external vocabulary which has possibly
# been expanded (see vocabulary_expansion.py).
encode_emb = tf.placeholder(tf.float32, (
None, None, self.config.word_embedding_dim), "encode_emb")
# No sequences to decode.
decode_pre_emb = None
decode_post_emb = None
else:
word_emb = tf.get_variable(
name="word_embedding",
shape=[self.config.vocab_size, self.config.word_embedding_dim],
initializer=self.uniform_initializer)
encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
decode_pre_emb = tf.nn.embedding_lookup(word_emb, self.decode_pre_ids)
decode_post_emb = tf.nn.embedding_lookup(word_emb, self.decode_post_ids)
self.encode_emb = encode_emb
self.decode_pre_emb = decode_pre_emb
self.decode_post_emb = decode_post_emb
def _initialize_gru_cell(self, num_units):
"""Initializes a GRU cell.
The Variables of the GRU cell are initialized in a way that exactly matches
the skip-thoughts paper: recurrent weights are initialized from random
orthonormal matrices and non-recurrent weights are initialized from random
uniform matrices.
Args:
num_units: Number of output units.
Returns:
cell: An instance of RNNCell with variable initializers that match the
skip-thoughts paper.
"""
return gru_cell.LayerNormGRUCell(
num_units,
w_initializer=self.uniform_initializer,
u_initializer=random_orthonormal_initializer,
b_initializer=tf.constant_initializer(0.0))
def build_encoder(self):
"""Builds the sentence encoder.
Inputs:
self.encode_emb
self.encode_mask
Outputs:
self.thought_vectors
Raises:
ValueError: if config.bidirectional_encoder is True and config.encoder_dim
is odd.
"""
with tf.variable_scope("encoder") as scope:
length = tf.to_int32(tf.reduce_sum(self.encode_mask, 1), name="length")
if self.config.bidirectional_encoder:
if self.config.encoder_dim % 2:
raise ValueError(
"encoder_dim must be even when using a bidirectional encoder.")
num_units = self.config.encoder_dim // 2
cell_fw = self._initialize_gru_cell(num_units) # Forward encoder
cell_bw = self._initialize_gru_cell(num_units) # Backward encoder
_, states = tf.nn.bidirectional_dynamic_rnn(
cell_fw=cell_fw,
cell_bw=cell_bw,
inputs=self.encode_emb,
sequence_length=length,
dtype=tf.float32,
scope=scope)
thought_vectors = tf.concat(states, 1, name="thought_vectors")
else:
cell = self._initialize_gru_cell(self.config.encoder_dim)
_, state = tf.nn.dynamic_rnn(
cell=cell,
inputs=self.encode_emb,
sequence_length=length,
dtype=tf.float32,
scope=scope)
# Use an identity operation to name the Tensor in the Graph.
thought_vectors = tf.identity(state, name="thought_vectors")
self.thought_vectors = thought_vectors
def _build_decoder(self, name, embeddings, targets, mask, initial_state,
reuse_logits):
"""Builds a sentence decoder.
Args:
name: Decoder name.
embeddings: Batch of sentences to decode; a float32 Tensor with shape
[batch_size, padded_length, emb_dim].
targets: Batch of target word ids; an int64 Tensor with shape
[batch_size, padded_length].
mask: A 0/1 Tensor with shape [batch_size, padded_length].
initial_state: Initial state of the GRU. A float32 Tensor with shape
[batch_size, num_gru_cells].
reuse_logits: Whether to reuse the logits weights.
"""
# Decoder RNN.
cell = self._initialize_gru_cell(self.config.encoder_dim)
with tf.variable_scope(name) as scope:
# Add a padding word at the start of each sentence (to correspond to the
# prediction of the first word) and remove the last word.
decoder_input = tf.pad(
embeddings[:, :-1, :], [[0, 0], [1, 0], [0, 0]], name="input")
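# For example, a length-3 sentence with embeddings [w1, w2, w3] gives the
# decoder input [0, w1, w2]: the output at each step is trained to predict
# the word at that same position in the target sentence.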
length = tf.reduce_sum(mask, 1, name="length")
decoder_output, _ = tf.nn.dynamic_rnn(
cell=cell,
inputs=decoder_input,
sequence_length=length,
initial_state=initial_state,
scope=scope)
# Stack batch vertically.
decoder_output = tf.reshape(decoder_output, [-1, self.config.encoder_dim])
targets = tf.reshape(targets, [-1])
weights = tf.to_float(tf.reshape(mask, [-1]))
# Logits.
with tf.variable_scope("logits", reuse=reuse_logits) as scope:
logits = tf.contrib.layers.fully_connected(
inputs=decoder_output,
num_outputs=self.config.vocab_size,
activation_fn=None,
weights_initializer=self.uniform_initializer,
scope=scope)
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=targets, logits=logits)
batch_loss = tf.reduce_sum(losses * weights)
tf.losses.add_loss(batch_loss)
tf.summary.scalar("losses/" + name, batch_loss)
self.target_cross_entropy_losses.append(losses)
self.target_cross_entropy_loss_weights.append(weights)
def build_decoders(self):
"""Builds the sentence decoders.
Inputs:
self.decode_pre_emb
self.decode_post_emb
self.decode_pre_ids
self.decode_post_ids
self.decode_pre_mask
self.decode_post_mask
self.thought_vectors
Outputs:
self.target_cross_entropy_losses
self.target_cross_entropy_loss_weights
"""
if self.mode != "encode":
# Pre-sentence decoder.
self._build_decoder("decoder_pre", self.decode_pre_emb,
self.decode_pre_ids, self.decode_pre_mask,
self.thought_vectors, False)
# Post-sentence decoder. Logits weights are reused.
self._build_decoder("decoder_post", self.decode_post_emb,
self.decode_post_ids, self.decode_post_mask,
self.thought_vectors, True)
def build_loss(self):
"""Builds the loss Tensor.
Outputs:
self.total_loss
"""
if self.mode != "encode":
total_loss = tf.losses.get_total_loss()
tf.summary.scalar("losses/total", total_loss)
self.total_loss = total_loss
def build_global_step(self):
"""Builds the global step Tensor.
Outputs:
self.global_step
"""
self.global_step = tf.contrib.framework.create_global_step()
def build(self):
"""Creates all ops for training, evaluation or encoding."""
self.build_inputs()
self.build_word_embeddings()
self.build_encoder()
self.build_decoders()
self.build_loss()
self.build_global_step()
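# Illustrative usage sketch (not part of the original file); 'model_config' is
# assumed to come from configuration.model_config() with input_file_pattern set.
#
#   model = SkipThoughtsModel(model_config, mode="train")
#   model.build()
#   # model.total_loss is then the scalar loss that the training script minimizes.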
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.skip_thoughts.skip_thoughts_model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from skip_thoughts import configuration
from skip_thoughts import skip_thoughts_model
class SkipThoughtsModel(skip_thoughts_model.SkipThoughtsModel):
"""Subclass of SkipThoughtsModel without the disk I/O."""
def build_inputs(self):
if self.mode == "encode":
# Encode mode doesn't read from disk, so defer to parent.
return super(SkipThoughtsModel, self).build_inputs()
else:
# Replace disk I/O with random Tensors.
self.encode_ids = tf.random_uniform(
[self.config.batch_size, 15],
minval=0,
maxval=self.config.vocab_size,
dtype=tf.int64)
self.decode_pre_ids = tf.random_uniform(
[self.config.batch_size, 15],
minval=0,
maxval=self.config.vocab_size,
dtype=tf.int64)
self.decode_post_ids = tf.random_uniform(
[self.config.batch_size, 15],
minval=0,
maxval=self.config.vocab_size,
dtype=tf.int64)
self.encode_mask = tf.ones_like(self.encode_ids)
self.decode_pre_mask = tf.ones_like(self.decode_pre_ids)
self.decode_post_mask = tf.ones_like(self.decode_post_ids)
class SkipThoughtsModelTest(tf.test.TestCase):
def setUp(self):
super(SkipThoughtsModelTest, self).setUp()
self._model_config = configuration.model_config()
def _countModelParameters(self):
"""Counts the number of parameters in the model at top level scope."""
counter = {}
for v in tf.global_variables():
name = v.op.name.split("/")[0]
num_params = v.get_shape().num_elements()
if not num_params:
self.fail("Could not infer num_elements from Variable %s" % v.op.name)
counter[name] = counter.get(name, 0) + num_params
return counter
def _checkModelParameters(self):
"""Verifies the number of parameters in the model."""
param_counts = self._countModelParameters()
expected_param_counts = {
# vocab_size * embedding_size
"word_embedding": 12400000,
# GRU Cells
"encoder": 21772800,
"decoder_pre": 21772800,
"decoder_post": 21772800,
# (encoder_dim + 1) * vocab_size
"logits": 48020000,
"global_step": 1,
}
self.assertDictEqual(expected_param_counts, param_counts)
def _checkOutputs(self, expected_shapes, feed_dict=None):
"""Verifies that the model produces expected outputs.
Args:
expected_shapes: A dict mapping Tensor or Tensor name to expected output
shape.
feed_dict: Values of Tensors to feed into Session.run().
"""
fetches = list(expected_shapes.keys())  # Materialize for indexing below (Python 3 compatible).
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
outputs = sess.run(fetches, feed_dict)
for index, output in enumerate(outputs):
tensor = fetches[index]
expected = expected_shapes[tensor]
actual = output.shape
if expected != actual:
self.fail("Tensor %s has shape %s (expected %s)." % (tensor, actual,
expected))
def testBuildForTraining(self):
model = SkipThoughtsModel(self._model_config, mode="train")
model.build()
self._checkModelParameters()
expected_shapes = {
# [batch_size, length]
model.encode_ids: (128, 15),
model.decode_pre_ids: (128, 15),
model.decode_post_ids: (128, 15),
model.encode_mask: (128, 15),
model.decode_pre_mask: (128, 15),
model.decode_post_mask: (128, 15),
# [batch_size, length, word_embedding_dim]
model.encode_emb: (128, 15, 620),
model.decode_pre_emb: (128, 15, 620),
model.decode_post_emb: (128, 15, 620),
# [batch_size, encoder_dim]
model.thought_vectors: (128, 2400),
# [batch_size * length]
model.target_cross_entropy_losses[0]: (1920,),
model.target_cross_entropy_losses[1]: (1920,),
# [batch_size * length]
model.target_cross_entropy_loss_weights[0]: (1920,),
model.target_cross_entropy_loss_weights[1]: (1920,),
# Scalar
model.total_loss: (),
}
self._checkOutputs(expected_shapes)
def testBuildForEval(self):
model = SkipThoughtsModel(self._model_config, mode="eval")
model.build()
self._checkModelParameters()
expected_shapes = {
# [batch_size, length]
model.encode_ids: (128, 15),
model.decode_pre_ids: (128, 15),
model.decode_post_ids: (128, 15),
model.encode_mask: (128, 15),
model.decode_pre_mask: (128, 15),
model.decode_post_mask: (128, 15),
# [batch_size, length, word_embedding_dim]
model.encode_emb: (128, 15, 620),
model.decode_pre_emb: (128, 15, 620),
model.decode_post_emb: (128, 15, 620),
# [batch_size, encoder_dim]
model.thought_vectors: (128, 2400),
# [batch_size * length]
model.target_cross_entropy_losses[0]: (1920,),
model.target_cross_entropy_losses[1]: (1920,),
# [batch_size * length]
model.target_cross_entropy_loss_weights[0]: (1920,),
model.target_cross_entropy_loss_weights[1]: (1920,),
# Scalar
model.total_loss: (),
}
self._checkOutputs(expected_shapes)
def testBuildForEncode(self):
model = SkipThoughtsModel(self._model_config, mode="encode")
model.build()
# Test feeding a batch of word embeddings to get skip thought vectors.
encode_emb = np.random.rand(64, 15, 620)
encode_mask = np.ones((64, 15), dtype=np.int64)
feed_dict = {model.encode_emb: encode_emb, model.encode_mask: encode_mask}
expected_shapes = {
# [batch_size, encoder_dim]
model.thought_vectors: (64, 2400),
}
self._checkOutputs(expected_shapes, feed_dict)
if __name__ == "__main__":
tf.test.main()
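# To run these tests (the Bazel target name is an assumption, not taken from this file):
#   bazel test skip_thoughts:skip_thoughts_model_test
# or directly with Python, with the repository root on PYTHONPATH:
#   python -m skip_thoughts.skip_thoughts_model_test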