Unverified Commit 5a5d3305 authored by Lukasz Kaiser, committed by GitHub

Merge pull request #2969 from coreylynch/master

Adding TCN.
parents 69cf6fca aa3d4422
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for svtcn_loss.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from estimators import svtcn_loss
import tensorflow as tf
class SVTCNLoss(tf.test.TestCase):
def testSVTCNLoss(self):
with self.test_session():
num_data = 64
num_sequences = 2
num_data_per_seq = num_data // num_sequences
feat_dim = 6
margin = 1.0
times = np.tile(np.arange(num_data_per_seq, dtype=np.int32),
num_sequences)
times = np.reshape(times, [times.shape[0], 1])
sequence_ids = np.concatenate(
[np.ones(num_data_per_seq)*i for i in range(num_sequences)])
sequence_ids = np.reshape(sequence_ids, [sequence_ids.shape[0], 1])
pos_radius = 6
neg_radius = 12
embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
# Compute the loss in NP
# Get a positive mask, i.e. indices for each time index
# that are inside the positive range.
in_pos_range = np.less_equal(
np.abs(times - times.transpose()), pos_radius)
# Get a negative mask, i.e. indices for each time index
# that are inside the negative range (> t + neg_radius
# or < t - neg_radius).
in_neg_range = np.greater(np.abs(times - times.transpose()), neg_radius)
sequence_adjacency = sequence_ids == sequence_ids.T
sequence_adjacency_not = np.logical_not(sequence_adjacency)
pdist_matrix = euclidean_distances(embedding, squared=True)
loss_np = 0.0
num_positives = 0.0
for i in range(num_data):
for j in range(num_data):
if in_pos_range[i, j] and i != j and sequence_adjacency[i, j]:
num_positives += 1.0
pos_distance = pdist_matrix[i][j]
neg_distances = []
for k in range(num_data):
if in_neg_range[i, k] or sequence_adjacency_not[i, k]:
neg_distances.append(pdist_matrix[i][k])
neg_distances.sort() # sort by distance
chosen_neg_distance = neg_distances[0]
for l in range(len(neg_distances)):
chosen_neg_distance = neg_distances[l]
if chosen_neg_distance > pos_distance:
break
loss_np += np.maximum(
0.0, margin - chosen_neg_distance + pos_distance)
loss_np /= num_positives
# Compute the loss in TF
loss_tf = svtcn_loss.singleview_tcn_loss(
embeddings=tf.convert_to_tensor(embedding),
timesteps=tf.convert_to_tensor(times),
pos_radius=pos_radius,
neg_radius=neg_radius,
margin=margin,
sequence_ids=tf.convert_to_tensor(sequence_ids),
multiseq=True
)
loss_tf = loss_tf.eval()
self.assertAllClose(loss_np, loss_tf)
if __name__ == '__main__':
tf.test.main()
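For reference, here is a minimal standalone NumPy sketch (not part of this change) of the positive/negative time-window logic the test above verifies, using the same pos_radius and neg_radius values:

import numpy as np

pos_radius, neg_radius = 6, 12
times = np.arange(32, dtype=np.int32).reshape(-1, 1)
# Positives for an anchor frame: frames within pos_radius timesteps of it.
in_pos_range = np.abs(times - times.T) <= pos_radius
# Negatives: frames more than neg_radius timesteps away.
in_neg_range = np.abs(times - times.T) > neg_radius
# For the anchor at t=10, positives are t in [4, 16] and negatives are t in [23, 31].
print(np.where(in_pos_range[10])[0])
print(np.where(in_neg_range[10])[0])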
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Calculates running validation of TCN models (and baseline comparisons)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from estimators.get_estimator import get_estimator
from utils import util
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
tf.flags.DEFINE_string(
'config_paths', '',
"""
Path to YAML configuration files defining FLAG values. Multiple files
can be separated by the `#` symbol. Files are merged recursively. Setting
a key in these files is equivalent to setting the FLAG value with
the same name.
""")
tf.flags.DEFINE_string(
'model_params', '{}', 'YAML configuration string for the model parameters.')
tf.app.flags.DEFINE_string('master', 'local',
'BNS name of the TensorFlow master to use')
tf.app.flags.DEFINE_string(
'logdir', '/tmp/tcn', 'Directory where to write event logs.')
FLAGS = tf.app.flags.FLAGS
def main(_):
"""Runs main eval loop."""
# Parse config dict from yaml config files / command line flags.
logdir = FLAGS.logdir
config = util.ParseConfigsToLuaTable(FLAGS.config_paths, FLAGS.model_params)
# Choose an estimator based on training strategy.
estimator = get_estimator(config, logdir)
# Wait for the first checkpoint file to be written.
while not tf.train.latest_checkpoint(logdir):
tf.logging.info('Waiting for a checkpoint file...')
time.sleep(10)
# Run validation.
while True:
estimator.evaluate()
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Generates imitation videos.
Generate single pairwise imitation videos:
blaze build -c opt --config=cuda --copt=-mavx \
learning/brain/research/tcn/generate_videos && \
blaze-bin/learning/brain/research/tcn/generate_videos \
--logtostderr \
--config_paths $config_paths \
--checkpointdir $checkpointdir \
--checkpoint_iter $checkpoint_iter \
--query_records_dir $query_records_dir \
--target_records_dir $target_records_dir \
--outdir $outdir \
--mode single \
--num_query_sequences 1 \
--num_target_sequences -1
Generate imitation videos with multiple sequences in the target set:
blaze build -c opt --config=cuda --copt=-mavx \
learning/brain/research/tcn/generate_videos && \
blaze-bin/learning/brain/research/tcn/generate_videos \
--logtostderr \
--config_paths $config_paths \
--checkpointdir $checkpointdir \
--checkpoint_iter $checkpoint_iter \
--query_records_dir $query_records_dir \
--target_records_dir $target_records_dir \
--outdir $outdir \
--num_multi_targets 1 \
--mode multi
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import tensorflow as tf
import os
import matplotlib
matplotlib.use("pdf")
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from estimators.get_estimator import get_estimator
from utils import util
tf.logging.set_verbosity(tf.logging.INFO)
tf.flags.DEFINE_string(
'config_paths', '',
"""
Path to YAML configuration files defining FLAG values. Multiple files
can be separated by the `#` symbol. Files are merged recursively. Setting
a key in these files is equivalent to setting the FLAG value with
the same name.
""")
tf.flags.DEFINE_string(
'model_params', '{}', 'YAML configuration string for the model parameters.')
tf.app.flags.DEFINE_string(
'checkpointdir', '/tmp/tcn', 'Path to model checkpoints.')
tf.app.flags.DEFINE_string(
'checkpoint_iter', '', 'Checkpoint iter to use.')
tf.app.flags.DEFINE_integer(
'num_multi_targets', -1,
'Number of target sequences in the target set per imitation video.')
tf.app.flags.DEFINE_string(
'outdir', '/tmp/tcn', 'Path to write imitation videos to.')
tf.app.flags.DEFINE_string(
'mode', 'single', 'single | multi | same. Single means generate imitation '
'vids where the query is being imitated by a single sequence. Multi '
'means generate imitation vids where the query is being imitated by '
'multiple sequences. Same means the query is imitated by a different '
'view of the same sequence.')
tf.app.flags.DEFINE_string('query_records_dir', '',
'Directory of image tfrecords.')
tf.app.flags.DEFINE_string('target_records_dir', '',
'Directory of image tfrecords.')
tf.app.flags.DEFINE_integer('query_view', 1,
'Viewpoint of the query video.')
tf.app.flags.DEFINE_integer('target_view', 0,
'Viewpoint of the imitation video.')
tf.app.flags.DEFINE_integer('smoothing_window', 5,
'Number of frames to smooth over.')
tf.app.flags.DEFINE_integer('num_query_sequences', -1,
'Number of query sequences to embed.')
tf.app.flags.DEFINE_integer('num_target_sequences', -1,
'Number of target sequences to embed.')
FLAGS = tf.app.flags.FLAGS
def SmoothEmbeddings(embs):
"""Temporally smoothes a sequence of embeddings."""
new_embs = []
window = int(FLAGS.smoothing_window)
for i in range(len(embs)):
min_i = max(i-window, 0)
max_i = min(i+window, len(embs))
new_embs.append(np.mean(embs[min_i:max_i, :], axis=0))
return np.array(new_embs)
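# Illustrative note (not part of the original file): with smoothing_window=2,
# frame i=5 of a 10-frame sequence is replaced by the mean of embs[3:7], i.e.
# an up-to-2*window frame slice that is clipped at the sequence boundaries
# (frame 0 averages embs[0:2]).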
def MakeImitationVideo(
outdir, vidname, query_im_strs, knn_im_strs, height=640, width=360):
"""Creates a KNN imitation video.
For each frame in query_im_strs, pair it with the corresponding frame in
knn_im_strs, then write the combined video to disk.
Args:
outdir: String, directory to write videos.
vidname: String, name of video.
query_im_strs: Numpy array holding query image strings.
knn_im_strs: Numpy array holding knn image strings.
height: Int, height of raw images.
width: Int, width of raw images.
"""
if not tf.gfile.Exists(outdir):
tf.gfile.MakeDirs(outdir)
vid_path = os.path.join(outdir, vidname)
combined = zip(query_im_strs, knn_im_strs)
# Create and write the video.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_aspect('equal')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
im = ax.imshow(
np.zeros((height, width*2, 3)), cmap='gray', interpolation='nearest')
im.set_clim([0, 1])
plt.tight_layout(pad=0, w_pad=0, h_pad=0)
# pylint: disable=invalid-name
def update_img(pair):
"""Decode pairs of image strings, update a video."""
im_i, im_j = pair
nparr_i = np.fromstring(str(im_i), np.uint8)
img_np_i = cv2.imdecode(nparr_i, 1)
img_np_i = img_np_i[..., [2, 1, 0]]
nparr_j = np.fromstring(str(im_j), np.uint8)
img_np_j = cv2.imdecode(nparr_j, 1)
img_np_j = img_np_j[..., [2, 1, 0]]
# Optionally reshape the images to be same size.
frame = np.concatenate([img_np_i, img_np_j], axis=1)
im.set_data(frame)
return im
ani = animation.FuncAnimation(fig, update_img, combined, interval=15)
writer = animation.writers['ffmpeg'](fps=15)
dpi = 100
tf.logging.info('Writing video to:\n %s \n' % vid_path)
ani.save('%s.mp4' % vid_path, writer=writer, dpi=dpi)
def GenerateImitationVideo(
vid_name, query_ims, query_embs, target_ims, target_embs, height, width):
"""Generates a single cross-sequence imitation video.
For each frame in some query sequence, find the nearest neighbor from
some target sequence in embedding space.
Args:
vid_name: String, the name of the video.
query_ims: Numpy array of shape [query sequence length, height, width, 3].
query_embs: Numpy array of shape [query sequence length, embedding size].
target_ims: Numpy array of shape [target sequence length, height, width,
3].
target_embs: Numpy array of shape [target sequence length, embedding
size].
height: Int, height of the raw image.
width: Int, width of the raw image.
"""
# For each query frame, find the index of the nearest neighbor in the
# target video.
knn_indices = [util.KNNIds(q, target_embs, k=1)[0] for q in query_embs]
# Create and write out the video.
assert knn_indices
knn_ims = np.array([target_ims[k] for k in knn_indices])
MakeImitationVideo(FLAGS.outdir, vid_name, query_ims, knn_ims, height, width)
def SingleImitationVideos(
query_records, target_records, config, height, width):
"""Generates pairwise imitation videos.
This creates all pairs of target imitating query videos, where each frame
on the left is matched to a nearest neighbor coming from a single
embedded target video.
Args:
query_records: List of Strings, paths to tfrecord datasets to use as
queries.
target_records: List of Strings, paths to tfrecord datasets to use as
targets.
config: A T object describing training config.
height: Int, height of the raw image.
width: Int, width of the raw image.
"""
# Embed query and target data.
(query_sequences_to_data,
target_sequences_to_data) = EmbedQueryTargetData(
query_records, target_records, config)
qview = FLAGS.query_view
tview = FLAGS.target_view
# Loop over query videos.
for task_i, data_i in query_sequences_to_data.iteritems():
for task_j, data_j in target_sequences_to_data.iteritems():
i_ims = data_i['images']
i_embs = data_i['embeddings']
query_embs = SmoothEmbeddings(i_embs[qview])
query_ims = i_ims[qview]
j_ims = data_j['images']
j_embs = data_j['embeddings']
target_embs = SmoothEmbeddings(j_embs[tview])
target_ims = j_ims[tview]
tf.logging.info('Generating %s imitating %s video.' % (task_j, task_i))
vid_name = 'q%sv%s_im%sv%s' % (task_i, qview, task_j, tview)
vid_name = vid_name.replace('/', '_')
GenerateImitationVideo(vid_name, query_ims, query_embs,
target_ims, target_embs, height, width)
def MultiImitationVideos(
query_records, target_records, config, height, width):
"""Creates multi-imitation videos.
This creates videos where every frame on the left is matched to a nearest
neighbor coming from a set of multiple embedded target videos.
Args:
query_records: List of Strings, paths to tfrecord datasets to use as
queries.
target_records: List of Strings, paths to tfrecord datasets to use as
targets.
config: A T object describing training config.
height: Int, height of the raw image.
width: Int, width of the raw image.
"""
# Embed query and target data.
(query_sequences_to_data,
target_sequences_to_data) = EmbedQueryTargetData(
query_records, target_records, config)
qview = FLAGS.query_view
tview = FLAGS.target_view
# Loop over query videos.
for task_i, data_i in query_sequences_to_data.iteritems():
i_ims = data_i['images']
i_embs = data_i['embeddings']
query_embs = SmoothEmbeddings(i_embs[qview])
query_ims = i_ims[qview]
all_target_embs = []
all_target_ims = []
# If num_multi_targets is -1, add all seq embeddings to the target set.
if FLAGS.num_multi_targets == -1:
num_multi_targets = len(target_sequences_to_data)
else:
# Else, add some specified number of seq embeddings to the target set.
num_multi_targets = FLAGS.num_multi_targets
for j in range(num_multi_targets):
task_j = target_sequences_to_data.keys()[j]
data_j = target_sequences_to_data[task_j]
print('Adding %s to target set' % task_j)
j_ims = data_j['images']
j_embs = data_j['embeddings']
target_embs = SmoothEmbeddings(j_embs[tview])
target_ims = j_ims[tview]
all_target_embs.extend(target_embs)
all_target_ims.extend(target_ims)
# Generate a "j imitating i" video.
tf.logging.info('Generating all imitating %s video.' % task_i)
vid_name = 'q%sv%s_multiv%s' % (task_i, qview, tview)
vid_name = vid_name.replace('/', '_')
GenerateImitationVideo(vid_name, query_ims, query_embs,
all_target_ims, all_target_embs, height, width)
def SameSequenceVideos(query_records, config, height, width):
"""Generate same sequence, cross-view imitation videos."""
batch_size = config.data.embed_batch_size
# Choose an estimator based on training strategy.
estimator = get_estimator(config, FLAGS.checkpointdir)
# Choose a checkpoint path to restore.
checkpointdir = FLAGS.checkpointdir
checkpoint_path = os.path.join(checkpointdir,
'model.ckpt-%s' % FLAGS.checkpoint_iter)
# Embed num_query_sequences query sequences; store embeddings and image
# strings in sequences_to_data.
sequences_to_data = {}
for (view_embeddings, view_raw_image_strings, seqname) in estimator.inference(
query_records, checkpoint_path, batch_size,
num_sequences=FLAGS.num_query_sequences):
sequences_to_data[seqname] = {
'embeddings': view_embeddings,
'images': view_raw_image_strings,
}
# Loop over query videos.
qview = FLAGS.query_view
tview = FLAGS.target_view
for task_i, data_i in sequences_to_data.iteritems():
ims = data_i['images']
embs = data_i['embeddings']
query_embs = SmoothEmbeddings(embs[qview])
query_ims = ims[qview]
target_embs = SmoothEmbeddings(embs[tview])
target_ims = ims[tview]
tf.logging.info('Generating %s imitating %s video.' % (task_i, task_i))
vid_name = 'q%sv%s_im%sv%s' % (task_i, qview, task_i, tview)
vid_name = vid_name.replace('/', '_')
GenerateImitationVideo(vid_name, query_ims, query_embs,
target_ims, target_embs, height, width)
def EmbedQueryTargetData(query_records, target_records, config):
"""Embeds the full set of query_records and target_records.
Args:
query_records: List of Strings, paths to tfrecord datasets to use as
queries.
target_records: List of Strings, paths to tfrecord datasets to use as
targets.
config: A T object describing training config.
Returns:
query_sequences_to_data: A dict holding 'embeddings' and 'images'
target_sequences_to_data: A dict holding 'embeddings' and 'images'
"""
batch_size = config.data.embed_batch_size
# Choose an estimator based on training strategy.
estimator = get_estimator(config, FLAGS.checkpointdir)
# Choose a checkpoint path to restore.
checkpointdir = FLAGS.checkpointdir
checkpoint_path = os.path.join(checkpointdir,
'model.ckpt-%s' % FLAGS.checkpoint_iter)
# Embed num_sequences query sequences, store embeddings and image strings in
# query_sequences_to_data.
num_query_sequences = FLAGS.num_query_sequences
num_target_sequences = FLAGS.num_target_sequences
query_sequences_to_data = {}
for (view_embeddings, view_raw_image_strings, seqname) in estimator.inference(
query_records, checkpoint_path, batch_size,
num_sequences=num_query_sequences):
query_sequences_to_data[seqname] = {
'embeddings': view_embeddings,
'images': view_raw_image_strings,
}
if (query_records == target_records) and (
num_query_sequences == num_target_sequences):
target_sequences_to_data = query_sequences_to_data
else:
# Embed num_sequences target sequences, store embeddings and image strings
# in sequences_to_data.
target_sequences_to_data = {}
for (view_embeddings, view_raw_image_strings,
seqname) in estimator.inference(
target_records, checkpoint_path, batch_size,
num_sequences=num_target_sequences):
target_sequences_to_data[seqname] = {
'embeddings': view_embeddings,
'images': view_raw_image_strings,
}
return query_sequences_to_data, target_sequences_to_data
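# Illustrative structure of the returned dicts (not part of the original file):
# query_sequences_to_data = {
#     'sequence_name': {
#         'embeddings': view_embeddings,        # indexed by view downstream
#         'images': view_raw_image_strings,     # indexed by view downstream
#     },
#     ...
# }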
def main(_):
# Parse config dict from yaml config files / command line flags.
config = util.ParseConfigsToLuaTable(FLAGS.config_paths, FLAGS.model_params)
# Get tables to embed.
query_records_dir = FLAGS.query_records_dir
query_records = util.GetFilesRecursively(query_records_dir)
target_records_dir = FLAGS.target_records_dir
target_records = util.GetFilesRecursively(target_records_dir)
height = config.data.raw_height
width = config.data.raw_width
mode = FLAGS.mode
if mode == 'multi':
# Generate videos where target set is composed of multiple videos.
MultiImitationVideos(query_records, target_records, config,
height, width)
elif mode == 'single':
# Generate videos where target set is a single video.
SingleImitationVideos(query_records, target_records, config,
height, width)
elif mode == 'same':
# Generate videos where target set is the same as query, but diff view.
SameSequenceVideos(query_records, config, height, width)
else:
raise ValueError('Unknown mode %s' % mode)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates test Recall@K statistics on labeled classification problems."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import os
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances
import data_providers
from estimators.get_estimator import get_estimator
from utils import util
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
tf.flags.DEFINE_string(
'config_paths', '',
"""
Path to YAML configuration files defining FLAG values. Multiple files
can be separated by the `#` symbol. Files are merged recursively. Setting
a key in these files is equivalent to setting the FLAG value with
the same name.
""")
tf.flags.DEFINE_string(
'model_params', '{}', 'YAML configuration string for the model parameters.')
tf.app.flags.DEFINE_string(
'mode', 'validation',
'Which dataset to evaluate: `validation` | `test`.')
tf.app.flags.DEFINE_string('master', 'local',
'BNS name of the TensorFlow master to use')
tf.app.flags.DEFINE_string(
'checkpoint_iter', '', 'Evaluate this specific checkpoint.')
tf.app.flags.DEFINE_string(
'checkpointdir', '/tmp/tcn', 'Path to model checkpoints.')
tf.app.flags.DEFINE_string('outdir', '/tmp/tcn', 'Path to write summaries to.')
FLAGS = tf.app.flags.FLAGS
def nearest_cross_sequence_neighbors(data, tasks, n_neighbors=1):
"""Computes the n_neighbors nearest neighbors for every row in data.
Args:
data: A np.float32 array of shape [num_data, embedding size] holding
an embedded validation / test dataset.
tasks: A list of strings of size [num_data] holding the task or sequence
name that each row belongs to.
n_neighbors: The number of knn indices to return for each row.
Returns:
indices: an np.int32 array of size [num_data, n_neighbors] holding the
n_neighbors nearest indices for every row in data. These are
restricted to be from different named sequences (as defined in `tasks`).
"""
# Compute the pairwise sequence adjacency matrix from `tasks`.
num_data = data.shape[0]
tasks = np.array(tasks)
tasks = np.reshape(tasks, (num_data, 1))
assert len(tasks.shape) == 2
not_adjacent = (tasks != tasks.T)
# Compute the symmetric pairwise distance matrix.
pdist = pairwise_distances(data, metric='sqeuclidean')
# For every row in the pairwise distance matrix, only consider
# cross-sequence columns.
indices = np.zeros((num_data, n_neighbors), dtype=np.int32)
for idx in range(num_data):
# Restrict to cross_sequence neighbors.
distances = [(
pdist[idx][i], i) for i in xrange(num_data) if not_adjacent[idx][i]]
_, nearest_indices = zip(*sorted(
distances, key=lambda x: x[0])[:n_neighbors])
indices[idx] = nearest_indices
return indices
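# Illustrative example (not part of the original file): with
# tasks = ['seq_a', 'seq_a', 'seq_b', 'seq_b'], row 0 can only retrieve
# neighbors from columns 2 and 3, because columns 0 and 1 belong to the same
# sequence ('seq_a') and are masked out by `not_adjacent`.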
def compute_cross_sequence_recall_at_k(retrieved_labels, labels, k_list):
"""Compute recall@k for a given list of k values.
Recall is one if an example of the same class is retrieved among the
top k nearest neighbors given a query example and zero otherwise.
These per-example recalls are averaged within each class and then across
classes to produce the recall@k score.
Args:
retrieved_labels: 2-D Numpy array of KNN labels for every embedding.
labels: 1-D Numpy array of shape [number of data].
k_list: List of k values to evaluate recall@k.
Returns:
recall_list: List of recall@k values.
"""
kvalue_to_recall = dict(zip(k_list, np.zeros(len(k_list))))
# For each value of K.
for k in k_list:
matches = defaultdict(float)
counts = defaultdict(float)
# For each (row index, label value) in the query labels.
for i, label_value in enumerate(labels):
# Loop over the K nearest retrieved labels.
if label_value in retrieved_labels[i][:k]:
matches[label_value] += 1.
# Increment the denominator.
counts[label_value] += 1.
kvalue_to_recall[k] = np.mean(
[matches[l]/counts[l] for l in matches])
return [kvalue_to_recall[i] for i in k_list]
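# Worked example (illustrative, not part of the original file): with
# labels = [0, 1, 0] and top-1 retrieved labels [[0], [0], [1]], class 0
# matches on 1 of its 2 queries and class 1 on 0 of its 1 query, so recall@1
# as computed above is mean([1/2]) = 0.5; note that classes with no matches
# never enter `matches` and are therefore excluded from the mean.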
def compute_cross_sequence_recalls_at_k(
embeddings, labels, label_attr_keys, tasks, k_list, summary_writer,
training_step):
"""Computes and reports the recall@k for each classification problem.
This takes an embedding matrix and an array of multiclass labels
with size [num_data, number of classification problems], then
computes the average recall@k for each classification problem
as well as the average across problems.
Args:
embeddings: A np.float32 array of size [num_data, embedding_size]
representing the embedded validation or test dataset.
labels: A np.int32 array of size [num_data, num_classification_problems]
holding multiclass labels for each embedding for each problem.
label_attr_keys: List of strings, holds the names of the classification
problems.
tasks: A list of strings describing the video sequence each row
belongs to. This is used to restrict the recall@k computation
to cross-sequence examples.
k_list: A list of ints, the k values to evaluate recall@k.
summary_writer: A tf.summary.FileWriter.
training_step: Int, the current training step we're evaluating.
"""
num_data = float(embeddings.shape[0])
assert labels.shape[0] == num_data
# Compute knn indices.
indices = nearest_cross_sequence_neighbors(
embeddings, tasks, n_neighbors=max(k_list))
retrieved_labels = labels[indices]
# Compute the recall@k for each classification problem.
recall_lists = []
for idx, label_attr in enumerate(label_attr_keys):
problem_labels = labels[:, idx]
# Take all indices, all k labels for the problem indexed by idx.
problem_retrieved = retrieved_labels[:, :, idx]
recall_list = compute_cross_sequence_recall_at_k(
retrieved_labels=problem_retrieved,
labels=problem_labels,
k_list=k_list)
recall_lists.append(recall_list)
for (k, recall) in zip(k_list, recall_list):
recall_error = 1-recall
summ = tf.Summary(value=[tf.Summary.Value(
tag='validation/classification/%s error@top%d' % (
label_attr, k),
simple_value=recall_error)])
print('%s error@top%d' % (label_attr, k), recall_error)
summary_writer.add_summary(summ, int(training_step))
# Report an average recall@k across problems.
recall_lists = np.array(recall_lists)
for i in range(recall_lists.shape[1]):
average_recall = np.mean(recall_lists[:, i])
recall_error = 1 - average_recall
summ = tf.Summary(value=[tf.Summary.Value(
tag='validation/classification/average error@top%d' % k_list[i],
simple_value=recall_error)])
print('Average error@top%d' % k_list[i], recall_error)
summary_writer.add_summary(summ, int(training_step))
def evaluate_once(
estimator, input_fn_by_view, batch_size, checkpoint_path,
label_attr_keys, embedding_size, num_views, k_list):
"""Compute the recall@k for a given checkpoint path.
Args:
estimator: an `Estimator` object to evaluate.
input_fn_by_view: An input_fn to an `Estimator's` predict method. Takes
a view index and returns a dict holding ops for getting raw images for
the view.
batch_size: Int, size of the labeled eval batch.
checkpoint_path: String, path to the specific checkpoint being evaluated.
label_attr_keys: A list of Strings, holding each attribute name.
embedding_size: Int, the size of the embedding.
num_views: Int, number of views in the dataset.
k_list: List of ints, list of K values to compute recall at K for.
"""
feat_matrix = np.zeros((0, embedding_size))
label_vect = np.zeros((0, len(label_attr_keys)))
tasks = []
eval_tensor_keys = ['embeddings', 'tasks', 'classification_labels']
# Iterate all views in the dataset.
for view_index in range(num_views):
# Set up a graph for embedding entire dataset.
predictions = estimator.inference(
input_fn_by_view(view_index), checkpoint_path,
batch_size, predict_keys=eval_tensor_keys)
# Enumerate predictions.
for i, p in enumerate(predictions):
if i % 100 == 0:
tf.logging.info('Embedded %d images for view %d' % (i, view_index))
label = p['classification_labels']
task = p['tasks']
embedding = p['embeddings']
# Collect (embedding, label, task) data.
feat_matrix = np.append(feat_matrix, [embedding], axis=0)
label_vect = np.append(label_vect, [label], axis=0)
tasks.append(task)
# Compute recall statistics.
ckpt_step = int(checkpoint_path.split('-')[-1])
summary_dir = os.path.join(FLAGS.outdir, 'labeled_eval_summaries')
summary_writer = tf.summary.FileWriter(summary_dir)
compute_cross_sequence_recalls_at_k(
feat_matrix, label_vect, label_attr_keys, tasks, k_list,
summary_writer, ckpt_step)
def get_labeled_tables(config):
"""Gets either labeled test or validation tables, based on flags."""
# Get a list of filenames corresponding to labeled data.
mode = FLAGS.mode
if mode == 'validation':
labeled_tables = util.GetFilesRecursively(config.data.labeled.validation)
elif mode == 'test':
labeled_tables = util.GetFilesRecursively(config.data.labeled.test)
else:
raise ValueError('Unknown dataset: %s' % mode)
return labeled_tables
def main(_):
"""Runs main labeled eval loop."""
# Parse config dict from yaml config files / command line flags.
config = util.ParseConfigsToLuaTable(FLAGS.config_paths, FLAGS.model_params)
# Choose an estimator based on training strategy.
checkpointdir = FLAGS.checkpointdir
estimator = get_estimator(config, checkpointdir)
# Get data configs.
image_attr_keys = config.data.labeled.image_attr_keys
label_attr_keys = config.data.labeled.label_attr_keys
embedding_size = config.embedding_size
num_views = config.data.num_views
k_list = config.val.recall_at_k_list
batch_size = config.data.batch_size
# Get either labeled validation or test tables.
labeled_tables = get_labeled_tables(config)
def input_fn_by_view(view_index):
"""Returns an input_fn for use with a tf.Estimator by view."""
def input_fn():
# Get raw labeled images.
(preprocessed_images, labels,
tasks) = data_providers.labeled_data_provider(
labeled_tables,
estimator.preprocess_data, view_index, image_attr_keys,
label_attr_keys, batch_size=batch_size)
return {
'batch_preprocessed': preprocessed_images,
'tasks': tasks,
'classification_labels': labels,
}, None
return input_fn
# If evaluating a specific checkpoint, do that.
if FLAGS.checkpoint_iter:
checkpoint_path = os.path.join(
'%s/model.ckpt-%s' % (checkpointdir, FLAGS.checkpoint_iter))
evaluate_once(
estimator, input_fn_by_view, batch_size, checkpoint_path,
label_attr_keys, embedding_size, num_views, k_list)
else:
for checkpoint_path in tf.contrib.training.checkpoints_iterator(
checkpointdir):
evaluate_once(
estimator, input_fn_by_view, batch_size, checkpoint_path,
label_attr_keys, embedding_size, num_views, k_list)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains TCN models (and baseline comparisons)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from estimators.get_estimator import get_estimator
from utils import util
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
tf.flags.DEFINE_string(
'config_paths', '',
"""
Path to YAML configuration files defining FLAG values. Multiple files
can be separated by the `#` symbol. Files are merged recursively. Setting
a key in these files is equivalent to setting the FLAG value with
the same name.
""")
tf.flags.DEFINE_string(
'model_params', '{}', 'YAML configuration string for the model parameters.')
tf.app.flags.DEFINE_string('master', 'local',
'BNS name of the TensorFlow master to use')
tf.app.flags.DEFINE_string(
'logdir', '/tmp/tcn', 'Directory where to write event logs.')
tf.app.flags.DEFINE_integer(
'task', 0, 'Task id of the replica running the training.')
tf.app.flags.DEFINE_integer(
'ps_tasks', 0, 'Number of tasks in the ps job. If 0 no ps job is used.')
FLAGS = tf.app.flags.FLAGS
def main(_):
"""Runs main training loop."""
# Parse config dict from yaml config files / command line flags.
config = util.ParseConfigsToLuaTable(
FLAGS.config_paths, FLAGS.model_params, save=True, logdir=FLAGS.logdir)
# Choose an estimator based on training strategy.
estimator = get_estimator(config, FLAGS.logdir)
# Run training
estimator.train()
if __name__ == '__main__':
tf.app.run()
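# Hypothetical invocation of this training script (the script name and config
# path are placeholders, not taken from this change):
#   python train.py --config_paths=configs/tcn_example.yml --logdir=/tmp/tcn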