Commit 68609ca7 authored Mar 16, 2017 by Christopher Shallue

TF implementation of Skip Thoughts.

parent 51fcc99b
Showing 3 changed files with 501 additions and 0 deletions:
  skip_thoughts/skip_thoughts/track_perplexity.py      +199 −0
  skip_thoughts/skip_thoughts/train.py                   +99 −0
  skip_thoughts/skip_thoughts/vocabulary_expansion.py   +203 −0
skip_thoughts/skip_thoughts/track_perplexity.py (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tracks training progress via per-word perplexity.
This script should be run concurrently with training so that summaries show up
in TensorBoard.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os.path
import time

import numpy as np
import tensorflow as tf

from skip_thoughts import configuration
from skip_thoughts import skip_thoughts_model

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string("input_file_pattern", None,
                       "File pattern of sharded TFRecord input files.")
tf.flags.DEFINE_string("checkpoint_dir", None,
                       "Directory containing model checkpoints.")
tf.flags.DEFINE_string("eval_dir", None, "Directory to write event logs to.")

tf.flags.DEFINE_integer("eval_interval_secs", 600,
                        "Interval between evaluation runs.")
tf.flags.DEFINE_integer("num_eval_examples", 50000,
                        "Number of examples for evaluation.")

tf.flags.DEFINE_integer("min_global_step", 100,
                        "Minimum global step to run evaluation.")

tf.logging.set_verbosity(tf.logging.INFO)


def evaluate_model(sess, losses, weights, num_batches, global_step,
                   summary_writer, summary_op):
  """Computes perplexity-per-word over the evaluation dataset.

  Summaries and perplexity-per-word are written out to the eval directory.

  Args:
    sess: Session object.
    losses: A Tensor of any shape; the target cross entropy losses for the
      current batch.
    weights: A Tensor of weights corresponding to losses.
    num_batches: Integer; the number of evaluation batches.
    global_step: Integer; global step of the model checkpoint.
    summary_writer: Instance of SummaryWriter.
    summary_op: Op for generating model summaries.
  """
  # Log model summaries on a single batch.
  summary_str = sess.run(summary_op)
  summary_writer.add_summary(summary_str, global_step)

  start_time = time.time()
  sum_losses = 0.0
  sum_weights = 0.0
  for i in xrange(num_batches):
    batch_losses, batch_weights = sess.run([losses, weights])
    sum_losses += np.sum(batch_losses * batch_weights)
    sum_weights += np.sum(batch_weights)
    if not i % 100:
      tf.logging.info("Computed losses for %d of %d batches.", i + 1,
                      num_batches)
  eval_time = time.time() - start_time

  perplexity = math.exp(sum_losses / sum_weights)
  tf.logging.info("Perplexity = %f (%.2f sec)", perplexity, eval_time)

  # Log perplexity to the SummaryWriter.
  summary = tf.Summary()
  value = summary.value.add()
  value.simple_value = perplexity
  value.tag = "perplexity"
  summary_writer.add_summary(summary, global_step)

  # Write the Events file to the eval directory.
  summary_writer.flush()
  tf.logging.info("Finished processing evaluation at global step %d.",
                  global_step)


def run_once(model, losses, weights, saver, summary_writer, summary_op):
  """Evaluates the latest model checkpoint.

  Args:
    model: Instance of SkipThoughtsModel; the model to evaluate.
    losses: Tensor; the target cross entropy losses for the current batch.
    weights: A Tensor of weights corresponding to losses.
    saver: Instance of tf.train.Saver for restoring model Variables.
    summary_writer: Instance of FileWriter.
    summary_op: Op for generating model summaries.
  """
  model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  if not model_path:
    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
                    FLAGS.checkpoint_dir)
    return

  with tf.Session() as sess:
    # Load model from checkpoint.
    tf.logging.info("Loading model from checkpoint: %s", model_path)
    saver.restore(sess, model_path)
    global_step = tf.train.global_step(sess, model.global_step.name)
    tf.logging.info("Successfully loaded %s at global step = %d.",
                    os.path.basename(model_path), global_step)
    if global_step < FLAGS.min_global_step:
      tf.logging.info("Skipping evaluation. Global step = %d < %d",
                      global_step, FLAGS.min_global_step)
      return

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    num_eval_batches = int(
        math.ceil(FLAGS.num_eval_examples / model.config.batch_size))

    # Run evaluation on the latest checkpoint.
    try:
      evaluate_model(sess, losses, weights, num_eval_batches, global_step,
                     summary_writer, summary_op)
    except tf.InvalidArgumentError:
      tf.logging.error(
          "Evaluation raised InvalidArgumentError (e.g. due to Nans).")
    finally:
      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=10)


def main(unused_argv):
  if not FLAGS.input_file_pattern:
    raise ValueError("--input_file_pattern is required.")
  if not FLAGS.checkpoint_dir:
    raise ValueError("--checkpoint_dir is required.")
  if not FLAGS.eval_dir:
    raise ValueError("--eval_dir is required.")

  # Create the evaluation directory if it doesn't exist.
  eval_dir = FLAGS.eval_dir
  if not tf.gfile.IsDirectory(eval_dir):
    tf.logging.info("Creating eval directory: %s", eval_dir)
    tf.gfile.MakeDirs(eval_dir)

  g = tf.Graph()
  with g.as_default():
    # Build the model for evaluation.
    model_config = configuration.model_config(
        input_file_pattern=FLAGS.input_file_pattern,
        input_queue_capacity=FLAGS.num_eval_examples,
        shuffle_input_data=False)
    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="eval")
    model.build()

    losses = tf.concat(model.target_cross_entropy_losses, 0)
    weights = tf.concat(model.target_cross_entropy_loss_weights, 0)

    # Create the Saver to restore model Variables.
    saver = tf.train.Saver()

    # Create the summary operation and the summary writer.
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(eval_dir)

    g.finalize()

    # Run a new evaluation run every eval_interval_secs.
    while True:
      start = time.time()
      tf.logging.info("Starting evaluation at " + time.strftime(
          "%Y-%m-%d-%H:%M:%S", time.localtime()))
      run_once(model, losses, weights, saver, summary_writer, summary_op)
      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
      if time_to_next_eval > 0:
        time.sleep(time_to_next_eval)


if __name__ == "__main__":
  tf.app.run()
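Note: the per-word perplexity this script tracks is the exponential of the weighted mean cross-entropy, exp(sum_losses / sum_weights), as computed in evaluate_model above. A minimal NumPy sketch of the same arithmetic on made-up batch values (the numbers below are purely illustrative, standing in for sess.run([losses, weights]) outputs):

import math

import numpy as np

# Synthetic per-token cross-entropy losses and padding weights for two
# "batches". Padding positions have weight 0 and do not contribute.
batches = [
    (np.array([2.3, 1.9, 4.1]), np.array([1.0, 1.0, 0.0])),
    (np.array([2.7, 3.2, 2.0]), np.array([1.0, 1.0, 1.0])),
]

sum_losses = 0.0
sum_weights = 0.0
for batch_losses, batch_weights in batches:
  sum_losses += np.sum(batch_losses * batch_weights)
  sum_weights += np.sum(batch_weights)

# Perplexity is the exponential of the weighted mean cross-entropy.
perplexity = math.exp(sum_losses / sum_weights)
print("Perplexity = %f" % perplexity)  # exp((2.3+1.9+2.7+3.2+2.0)/5) ~ 11.25

In use, this script runs concurrently with train.py, with --checkpoint_dir pointed at the training directory and --eval_dir at a separate directory for the eval events (actual paths depend on your setup).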
skip_thoughts/skip_thoughts/train.py (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train the skip-thoughts model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from skip_thoughts import configuration
from skip_thoughts import skip_thoughts_model

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string("input_file_pattern", None,
                       "File pattern of sharded TFRecord files containing "
                       "tf.Example protos.")
tf.flags.DEFINE_string("train_dir", None,
                       "Directory for saving and loading checkpoints.")

tf.logging.set_verbosity(tf.logging.INFO)


def _setup_learning_rate(config, global_step):
  """Sets up the learning rate with optional exponential decay.

  Args:
    config: Object containing learning rate configuration parameters.
    global_step: Tensor; the global step.

  Returns:
    learning_rate: Tensor; the learning rate with exponential decay.
  """
  if config.learning_rate_decay_factor > 0:
    learning_rate = tf.train.exponential_decay(
        learning_rate=float(config.learning_rate),
        global_step=global_step,
        decay_steps=config.learning_rate_decay_steps,
        decay_rate=config.learning_rate_decay_factor,
        staircase=False)
  else:
    learning_rate = tf.constant(config.learning_rate)
  return learning_rate


def main(unused_argv):
  if not FLAGS.input_file_pattern:
    raise ValueError("--input_file_pattern is required.")
  if not FLAGS.train_dir:
    raise ValueError("--train_dir is required.")

  model_config = configuration.model_config(
      input_file_pattern=FLAGS.input_file_pattern)
  training_config = configuration.training_config()

  tf.logging.info("Building training graph.")
  g = tf.Graph()
  with g.as_default():
    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="train")
    model.build()

    learning_rate = _setup_learning_rate(training_config, model.global_step)
    optimizer = tf.train.AdamOptimizer(learning_rate)

    train_tensor = tf.contrib.slim.learning.create_train_op(
        total_loss=model.total_loss,
        optimizer=optimizer,
        global_step=model.global_step,
        clip_gradient_norm=training_config.clip_gradient_norm)

    saver = tf.train.Saver()

  tf.contrib.slim.learning.train(
      train_op=train_tensor,
      logdir=FLAGS.train_dir,
      graph=g,
      global_step=model.global_step,
      number_of_steps=training_config.number_of_steps,
      save_summaries_secs=training_config.save_summaries_secs,
      saver=saver,
      save_interval_secs=training_config.save_model_secs)


if __name__ == "__main__":
  tf.app.run()
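For intuition about _setup_learning_rate: with staircase=False, tf.train.exponential_decay produces the smooth schedule learning_rate * decay_rate ** (global_step / decay_steps). A small sketch of that formula with illustrative numbers (the real values come from configuration.training_config(), which is not part of this commit):

# Illustrative config values only, not the project defaults.
initial_learning_rate = 0.008
learning_rate_decay_factor = 0.5
learning_rate_decay_steps = 400000


def decayed_learning_rate(global_step):
  # Non-staircase exponential decay, matching
  # tf.train.exponential_decay(..., staircase=False).
  return initial_learning_rate * (
      learning_rate_decay_factor **
      (global_step / float(learning_rate_decay_steps)))


for step in (0, 200000, 400000, 800000):
  print(step, decayed_learning_rate(step))
# 0      -> 0.008
# 200000 -> 0.008 * 0.5**0.5 ~ 0.005657
# 400000 -> 0.004
# 800000 -> 0.002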
skip_thoughts/skip_thoughts/vocabulary_expansion.py (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute an expanded vocabulary of embeddings using a word2vec model.
This script loads the word embeddings from a trained skip-thoughts model and
from a trained word2vec model (typically with a larger vocabulary). It trains a
linear regression model without regularization to learn a linear mapping from
the word2vec embedding space to the skip-thoughts embedding space. The model is
then applied to all words in the word2vec vocabulary, yielding vectors in the
skip-thoughts word embedding space for the union of the two vocabularies.
The linear regression task is to learn a parameter matrix W to minimize
|| X - Y * W ||^2,
where X is a matrix of skip-thoughts embeddings of shape [num_words, dim1],
Y is a matrix of word2vec embeddings of shape [num_words, dim2], and W is a
matrix of shape [dim2, dim1].
This is based on the "Translation Matrix" method from the paper:
"Exploiting Similarities among Languages for Machine Translation"
Tomas Mikolov, Quoc V. Le, Ilya Sutskever
https://arxiv.org/abs/1309.4168
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path

import gensim.models
import numpy as np
import sklearn.linear_model
import tensorflow as tf

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string("skip_thoughts_model", None,
                       "Checkpoint file or directory containing a checkpoint "
                       "file.")
tf.flags.DEFINE_string("skip_thoughts_vocab", None,
                       "Path to vocabulary file containing a list of newline-"
                       "separated words where the word id is the "
                       "corresponding 0-based index in the file.")
tf.flags.DEFINE_string("word2vec_model", None,
                       "File containing a word2vec model in binary format.")
tf.flags.DEFINE_string("output_dir", None, "Output directory.")

tf.logging.set_verbosity(tf.logging.INFO)


def _load_skip_thoughts_embeddings(checkpoint_path):
  """Loads the embedding matrix from a skip-thoughts model checkpoint.

  Args:
    checkpoint_path: Model checkpoint file or directory containing a checkpoint
      file.

  Returns:
    word_embedding: A numpy array of shape [vocab_size, embedding_dim].

  Raises:
    ValueError: If no checkpoint file matches checkpoint_path.
  """
  if tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
    if not checkpoint_file:
      raise ValueError("No checkpoint file found in %s" % checkpoint_path)
  else:
    checkpoint_file = checkpoint_path

  tf.logging.info("Loading skip-thoughts embedding matrix from %s",
                  checkpoint_file)
  reader = tf.train.NewCheckpointReader(checkpoint_file)
  word_embedding = reader.get_tensor("word_embedding")
  tf.logging.info("Loaded skip-thoughts embedding matrix of shape %s",
                  word_embedding.shape)

  return word_embedding


def _load_vocabulary(filename):
  """Loads a vocabulary file.

  Args:
    filename: Path to text file containing newline-separated words.

  Returns:
    vocab: A dictionary mapping word to word id.
  """
  tf.logging.info("Reading vocabulary from %s", filename)
  vocab = collections.OrderedDict()
  with tf.gfile.GFile(filename, mode="r") as f:
    for i, line in enumerate(f):
      word = line.decode("utf-8").strip()
      assert word not in vocab, "Attempting to add word twice: %s" % word
      vocab[word] = i
  tf.logging.info("Read vocabulary of size %d", len(vocab))
  return vocab


def _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab, word2vec):
  """Runs vocabulary expansion on a skip-thoughts model using a word2vec model.

  Args:
    skip_thoughts_emb: A numpy array of shape [skip_thoughts_vocab_size,
      skip_thoughts_embedding_dim].
    skip_thoughts_vocab: A dictionary of word to id.
    word2vec: An instance of gensim.models.Word2Vec.

  Returns:
    combined_emb: A dictionary mapping words to embedding vectors.
  """
  # Find words shared between the two vocabularies.
  tf.logging.info("Finding shared words")
  shared_words = [w for w in word2vec.vocab if w in skip_thoughts_vocab]

  # Select embedding vectors for shared words.
  tf.logging.info("Selecting embeddings for %d shared words",
                  len(shared_words))
  shared_st_emb = skip_thoughts_emb[[
      skip_thoughts_vocab[w] for w in shared_words
  ]]
  shared_w2v_emb = word2vec[shared_words]

  # Train a linear regression model on the shared embedding vectors.
  tf.logging.info("Training linear regression model")
  model = sklearn.linear_model.LinearRegression()
  model.fit(shared_w2v_emb, shared_st_emb)

  # Create the expanded vocabulary.
  tf.logging.info("Creating embeddings for expanded vocabulary")
  combined_emb = collections.OrderedDict()
  for w in word2vec.vocab:
    # Ignore words with underscores (spaces).
    if "_" not in w:
      w_emb = model.predict(word2vec[w].reshape(1, -1))
      combined_emb[w] = w_emb.reshape(-1)

  for w in skip_thoughts_vocab:
    combined_emb[w] = skip_thoughts_emb[skip_thoughts_vocab[w]]

  tf.logging.info("Created expanded vocabulary of %d words", len(combined_emb))

  return combined_emb


def main(unused_argv):
  if not FLAGS.skip_thoughts_model:
    raise ValueError("--skip_thoughts_model is required.")
  if not FLAGS.skip_thoughts_vocab:
    raise ValueError("--skip_thoughts_vocab is required.")
  if not FLAGS.word2vec_model:
    raise ValueError("--word2vec_model is required.")
  if not FLAGS.output_dir:
    raise ValueError("--output_dir is required.")

  if not tf.gfile.IsDirectory(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # Load the skip-thoughts embeddings and vocabulary.
  skip_thoughts_emb = _load_skip_thoughts_embeddings(FLAGS.skip_thoughts_model)
  skip_thoughts_vocab = _load_vocabulary(FLAGS.skip_thoughts_vocab)

  # Load the Word2Vec model.
  word2vec = gensim.models.Word2Vec.load_word2vec_format(
      FLAGS.word2vec_model, binary=True)

  # Run vocabulary expansion.
  embedding_map = _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab,
                                     word2vec)

  # Save the output.
  vocab = embedding_map.keys()
  vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
  with tf.gfile.GFile(vocab_file, "w") as f:
    f.write("\n".join(vocab))
  tf.logging.info("Wrote vocabulary file to %s", vocab_file)

  embeddings = np.array(embedding_map.values())
  embeddings_file = os.path.join(FLAGS.output_dir, "embeddings.npy")
  np.save(embeddings_file, embeddings)
  tf.logging.info("Wrote embeddings file to %s", embeddings_file)


if __name__ == "__main__":
  tf.app.run()
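The mapping learned in _expand_vocabulary is ordinary least squares from the word2vec space (Y) onto the skip-thoughts space (X), per the || X - Y * W ||^2 objective in the module docstring. A toy sketch with synthetic matrices (all dimensions and data below are made up; only the fit/predict pattern mirrors the script):

import numpy as np
import sklearn.linear_model

rng = np.random.RandomState(0)

# Synthetic stand-ins for the shared-word embedding matrices.
num_shared, dim_w2v, dim_st = 1000, 300, 620
shared_w2v_emb = rng.randn(num_shared, dim_w2v)  # Y: word2vec vectors
shared_st_emb = rng.randn(num_shared, dim_st)    # X: skip-thoughts vectors

# Learn W minimizing ||X - Y*W||^2 (plus a fitted intercept, sklearn's
# default; the docstring's objective has no intercept term).
model = sklearn.linear_model.LinearRegression()
model.fit(shared_w2v_emb, shared_st_emb)

# Map a word2vec-only word into the skip-thoughts embedding space.
new_w2v_vector = rng.randn(dim_w2v)
st_vector = model.predict(new_w2v_vector.reshape(1, -1)).reshape(-1)
print(st_vector.shape)  # (620,)

Every word2vec-only word gets such a predicted vector, while words already in the skip-thoughts vocabulary keep their original embeddings, as in the second loop of _expand_vocabulary.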