Commit ef8f894a authored by Konstantinos Bousmalis's avatar Konstantinos Bousmalis Committed by Konstantinos Bousmalis
Browse files

DSN infrastructure staging

parent 40a5739a
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create a DSN model and add the different losses to it.
Specifically, in this file we define the:
- Shared Encoding Similarity Loss Module, with:
- The MMD Similarity method
- The Correlation Similarity method
- The Gradient Reversal (Domain-Adversarial) method
- Difference Loss Module
- Reconstruction Loss Module
- Task Loss Module
"""
from functools import partial
import tensorflow as tf
import losses
import models
import utils
slim = tf.contrib.slim
################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
  """Returns a step-dependent weight that gates the DSN losses.

  Until the global step reaches `params['domain_separation_startpoint']` the
  coefficient is a negligible 1e-10, effectively disabling the similarity,
  difference, and reconstruction losses; from that step on it is 1.0.

  Args:
    params: A dictionary of parameters. Must contain
      'domain_separation_startpoint'.

  Returns:
    A scalar float tensor equal to either 1e-10 or 1.0.
  """
  startpoint = params['domain_separation_startpoint']
  global_step = slim.get_or_create_global_step()
  before_startpoint = tf.less(global_step, startpoint)
  return tf.where(before_startpoint, 1e-10, 1.0)
################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Builds the shared tower over both domains, adds the source-domain task loss,
  a shared-encoding similarity loss, and (optionally) the private/shared
  autoencoder losses.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height width, channels].
    target_labels: a dictionary with the name, tensor pairs.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar. 'none' disables
      domain adaptation entirely.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder.

  Raises:
    ValueError: if the arch is not one of the available architectures.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # Reuse the tower variables so source and target share one encoder.
  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)
  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain when
  # we use the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params,)
def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Penalizes dissimilarity between the two domains' shared encodings.

  Looks up the similarity loss function named `method_name` on the `losses`
  module (valid options include `dann_loss', `mmd_loss' or
  `correlation_loss') and applies it to the two sample batches, scaled by
  'gamma_weight' times the step-dependent DSN coefficient.

  Args:
    method_name: the name of the encoding similarity method to use.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.
  """
  loss_fn = getattr(losses, method_name)
  similarity_weight = params['gamma_weight'] * dsn_loss_coefficient(params)
  loss_fn(source_samples, target_samples, similarity_weight, scope)
def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss comparing `recons` against `images`.

  Args:
    recon_loss_name: The name of the reconstruction loss; one of
      'sum_of_pairwise_squares' or 'sum_of_squares'.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed (used in the summary
      tag).

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  recon_loss_fns = {
      'sum_of_pairwise_squares': tf.contrib.losses.mean_pairwise_squared_error,
      'sum_of_squares': tf.contrib.losses.mean_squared_error,
  }
  if recon_loss_name not in recon_loss_fns:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)
  loss = recon_loss_fns[recon_loss_name](recons, images, weight)

  # Record the summary only after asserting the loss value is finite.
  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Builds a private encoder per domain, decodes shared+private codes through a
  single shared decoder, and adds the difference and reconstruction losses.
  Also emits image summaries of the reconstructions.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_shared: a tensor with first dimension batch_size
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels]
    target_shared: a tensor with first dimension batch_size
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'
  """

  def normalize_images(images):
    # Rescale pixel values to [0, 1] so reconstructions render cleanly in
    # image summaries.
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    # Shared and private codes are combined additively before decoding.
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Each domain gets its own private encoder (separate variable scopes).
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  # The decoder is shared across all reconstructions (reuse=True below).
  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries. Fix: use list comprehensions rather than map() -- under
  # Python 3, map() returns a lazy iterator and tf.concat requires a
  # sequence of tensors.
  source_reconstructions = tf.concat(
      [normalize_images(t) for t in [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ]], 2)
  target_reconstructions = tf.concat(
      [normalize_images(t) for t in [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ]], 2)
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  # A fourth channel, when present, is treated as depth and plotted
  # separately.
  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)
def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: labels from the source domain, a tensor of size [batch_size].
      or a tuple of (quaternions, class_labels)
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    # Assert the loss is finite before the histogram summary records it.
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = loss
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    # Scale the pose loss and register it with the slim loss collection.
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  # tf.losses.softmax_cross_entropy adds itself to the losses collection.
  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
r"""Evaluation for Domain Separation Networks (DSNs).
To build locally for CPU:
blaze build -c opt --copt=-mavx \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
To build locally for GPU:
blaze build -c opt --copt=-mavx --config=cuda_clang \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
To run locally:
$
./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval
\
--alsologtostderr
"""
# pylint: enable=line-too-long
import math
import numpy as np
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.domain_separation import losses
from domain_adaptation.domain_separation import models
slim = tf.contrib.slim

FLAGS = tf.app.flags.FLAGS

# Evaluation infrastructure flags.
tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
                           'Directory where the model was written to.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/da/',
    'Directory where we should write the tf summaries to.')

# Dataset selection flags.
tf.app.flags.DEFINE_string('dataset_dir', '/cns/ok-d/home/konstantinos/cad_learning/',
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('dataset', 'mnist_m',
                           'Which dataset to test on: "mnist", "mnist_m".')

tf.app.flags.DEFINE_string('split', 'valid',
                           'Which portion to test on: "valid", "test".')

tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')

# Model and metric flags.
tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
                           'The basic tower building block.')

tf.app.flags.DEFINE_bool('enable_precision_recall', False,
                         'If True, precision and recall for each class will '
                         'be added to the metrics.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
def quaternion_metric(predictions, labels):
  """Returns a streaming-mean metric of the batch log quaternion loss.

  Args:
    predictions: predicted quaternions.
    labels: ground-truth quaternions.

  Returns:
    A (value, update_op) streaming-mean metric over the per-batch loss.
  """
  loss_params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
  batch_cost = losses.log_quaternion_loss_batch(predictions, labels,
                                                loss_params)
  return slim.metrics.streaming_mean(batch_cost)
def angle_diff(true_q, pred_q):
  """Returns the row-wise angular error, in degrees, between quaternions.

  Args:
    true_q: array of shape [batch, 4] of ground-truth quaternions.
    pred_q: array of shape [batch, 4] of predicted quaternions.

  Returns:
    An array of shape [batch] of angles in degrees.
  """
  abs_dots = np.abs(np.sum(np.multiply(pred_q, true_q), axis=1))
  return 2.0 * np.degrees(np.arccos(abs_dots))
def provide_batch_fn():
  """Returns the function used to provide batches of input data."""
  return dataset_factory.provide_batch
def main(_):
  """Builds the evaluation graph and runs the slim evaluation loop.

  Loads a batch of the selected dataset split, runs it through the basic
  tower, aggregates streaming accuracy / pose metrics, and evaluates
  checkpoints from FLAGS.checkpoint_dir forever.
  """
  g = tf.Graph()
  with g.as_default():
    # Load the data.
    images, labels = provide_batch_fn()(
        FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
    num_classes = labels['classes'].get_shape().as_list()[1]

    tf.summary.image('eval_images', images, max_outputs=3)

    # Define the model:
    with tf.variable_scope('towers'):
      basic_tower = getattr(models, FLAGS.basic_tower)
      predictions, endpoints = basic_tower(
          images,
          num_classes=num_classes,
          is_training=False,
          batch_norm_params=None)
    metric_names_to_values = {}

    # Define the metrics:
    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
      quaternion_loss = quaternion_metric(labels['quaternions'],
                                          endpoints['quaternion_pred'])

      # angle_diff runs in NumPy via py_func; it returns one float32 output.
      angle_errors, = tf.py_func(
          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
          [tf.float32])

      metric_names_to_values[
          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)

      metric_names_to_values['Quaternion Loss'] = quaternion_loss

    accuracy = tf.contrib.metrics.streaming_accuracy(
        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))

    predictions = tf.argmax(predictions, 1)
    labels = tf.argmax(labels['classes'], 1)
    metric_names_to_values['Accuracy'] = accuracy

    if FLAGS.enable_precision_recall:
      # Fix: use range/items rather than the Python-2-only xrange/iteritems.
      for i in range(num_classes):
        index_map = tf.one_hot(i, depth=num_classes)
        name = 'PR/Precision_{}'.format(i)
        metric_names_to_values[name] = slim.metrics.streaming_precision(
            tf.gather(index_map, predictions), tf.gather(index_map, labels))
        name = 'PR/Recall_{}'.format(i)
        metric_names_to_values[name] = slim.metrics.streaming_recall(
            tf.gather(index_map, predictions), tf.gather(index_map, labels))

    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
        metric_names_to_values)

    # Create the summary ops such that they also print out to std output:
    summary_ops = []
    for metric_name, metric_value in names_to_values.items():
      op = tf.summary.scalar(metric_name, metric_value)
      op = tf.Print(op, [metric_value], metric_name)
      summary_ops.append(op)

    # This ensures that we make a single pass over all of the data.
    # Fix: math.ceil returns a float; num_evals must be an integer.
    num_batches = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))

    # Setup the global step.
    slim.get_or_create_global_step()

    slim.evaluation.evaluation_loop(
        FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        # Fix: pass a concrete list, not a Python 3 dict view.
        eval_op=list(names_to_updates.values()),
        summary_op=tf.summary.merge(summary_ops))
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN model assembly functions."""
import numpy as np
import tensorflow as tf
import dsn
class HelperFunctionsTest(tf.test.TestCase):
  """Tests for the dsn helper functions."""

  def testBasicDomainSeparationStartPoint(self):
    """Checks dsn_loss_coefficient flips from ~0 to 1.0 at the startpoint."""
    with self.test_session() as sess:
      # Test for when global_step < domain_separation_startpoint
      step = tf.contrib.slim.get_or_create_global_step()
      # Fix: tf.initialize_all_variables() is deprecated (removed in later
      # TF releases); use its direct replacement.
      sess.run(tf.global_variables_initializer())  # global_step = 0
      params = {'domain_separation_startpoint': 2}
      weight = dsn.dsn_loss_coefficient(params)
      weight_np = sess.run(weight)
      self.assertAlmostEqual(weight_np, 1e-10)

      step_op = tf.assign_add(step, 1)
      step_np = sess.run(step_op)  # global_step = 1
      weight = dsn.dsn_loss_coefficient(params)
      weight_np = sess.run(weight)
      self.assertAlmostEqual(weight_np, 1e-10)

      # Test for when global_step >= domain_separation_startpoint
      step_np = sess.run(step_op)  # global_step = 2
      tf.logging.info(step_np)
      weight = dsn.dsn_loss_coefficient(params)
      weight_np = sess.run(weight)
      self.assertAlmostEqual(weight_np, 1.0)
class DsnModelAssemblyTest(tf.test.TestCase):
  """Tests that the DSN graph assembles for each similarity-loss variant."""

  def _testBuildDefaultModel(self):
    """Returns the (images, labels, params) fixtures shared by all tests."""
    images = tf.to_float(np.random.rand(32, 28, 28, 1))
    labels = {}
    labels['classes'] = tf.one_hot(
        tf.to_int32(np.random.randint(0, 9, (32))), 10)
    params = {
        'use_separation': True,
        'layers_to_regularize': 'fc3',
        'weight_decay': 0.0,
        'ps_tasks': 1,
        'domain_separation_startpoint': 1,
        'alpha_weight': 1,
        'beta_weight': 1,
        'gamma_weight': 1,
        'recon_loss_name': 'sum_of_squares',
        'decoder_name': 'small_decoder',
        'encoder_name': 'default_encoder',
    }
    return images, labels, params

  def testBuildModelDann(self):
    images, labels, params = self._testBuildDefaultModel()
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelDannSumOfPairwiseSquares(self):
    images, labels, params = self._testBuildDefaultModel()
    # Fix: previously this test never overrode the default reconstruction
    # loss, so it was a byte-for-byte duplicate of testBuildModelDann and the
    # sum-of-pairwise-squares path went untested.
    params['recon_loss_name'] = 'sum_of_pairwise_squares'
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelDannMultiPSTasks(self):
    images, labels, params = self._testBuildDefaultModel()
    params['ps_tasks'] = 10
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelMmd(self):
    images, labels, params = self._testBuildDefaultModel()
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'mmd_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelCorr(self):
    images, labels, params = self._testBuildDefaultModel()
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'correlation_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
    self.assertEqual(len(loss_tensors), 6)

  def testBuildModelNoDomainAdaptation(self):
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
                       params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 1)
      self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 0)

  def testBuildModelNoAdaptationWeightDecay(self):
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    params['weight_decay'] = 1e-5
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
                       params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 1)
      self.assertTrue(len(tf.contrib.losses.get_regularization_losses()) >= 1)

  def testBuildModelNoSeparation(self):
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
    self.assertEqual(len(loss_tensors), 2)
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
r"""Training for Domain Separation Networks (DSNs).
-- Compile:
$ blaze build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_train
-- Run:
$
./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_train
\
--similarity_loss=dann \
--basic_tower=dsn_cropped_linemod \
--source_dataset=pose_synthetic \
--target_dataset=pose_real \
--learning_rate=0.012 \
--alpha_weight=0.26 \
--gamma_weight=0.0115 \
--weight_decay=4e-5 \
--layers_to_regularize=fc3 \
--use_separation \
--alsologtostderr
"""
# pylint: enable=line-too-long
from __future__ import division
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
import dsn
slim = tf.contrib.slim

FLAGS = tf.app.flags.FLAGS

# Input pipeline and infrastructure flags.
tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
                           'Source dataset to train on.')

tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
                           'Target dataset to train on.')

tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
                           'Target dataset to train on.')

tf.app.flags.DEFINE_string('dataset_dir', '/cns/ok-d/home/konstantinos/cad_learning/',
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
                           'Directory where to write event logs.')

# Fix: corrected help-text typo "Comma-seperated" -> "Comma-separated".
tf.app.flags.DEFINE_string(
    'layers_to_regularize', 'fc3',
    'Comma-separated list of layer names to use MMD regularization on.')

# Optimization hyperparameters.
tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate')

tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
                          'The coefficient for scaling the reconstruction '
                          'loss.')

tf.app.flags.DEFINE_float(
    'beta_weight', 1e-6,
    'The coefficient for scaling the private/shared difference loss.')

tf.app.flags.DEFINE_float(
    'gamma_weight', 1e-6,
    'The coefficient for scaling the shared encoding similarity loss.')

tf.app.flags.DEFINE_float('pose_weight', 0.125,
                          'The coefficient for scaling the pose loss.')

tf.app.flags.DEFINE_float(
    'weight_decay', 1e-6,
    'The coefficient for the L2 regularization applied for all weights.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 60,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 60,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'max_number_of_steps', None,
    'The maximum number of gradient steps. Use None to train indefinitely.')

tf.app.flags.DEFINE_integer(
    'domain_separation_startpoint', 1,
    'The global step to add the domain separation losses.')

tf.app.flags.DEFINE_integer(
    'bipartite_assignment_top_k', 3,
    'The number of top-k matches to use in bipartite matching adaptation.')

tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')

tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')

tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')

tf.app.flags.DEFINE_bool('use_separation', False,
                         'Use our domain separation model.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')

# Distributed-training flags.
tf.app.flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
                            'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
                           'The decoder to use.')
tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
                           'The encoder to use.')

################################################################################
# Flags that control the architecture and losses
################################################################################
tf.app.flags.DEFINE_string(
    'similarity_loss', 'grl',
    'The method to use for encouraging the common encoder codes to be '
    'similar, one of "grl", "mmd", "corr".')

tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
                           'The name of the reconstruction loss.')

tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
                           'The basic tower building block.')
def provide_batch_fn():
  """Returns the function used to provide batches of training data."""
  return dataset_factory.provide_batch
def main(_):
  """Builds the DSN training graph and runs slim training.

  Loads source/target batches, optionally mixes labeled target data into the
  source batch (semi-supervised mode), assembles the DSN model via
  dsn.create_model, and trains with momentum + exponential LR decay.
  """
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      # In the unsupervised case all the samples in the labeled
      # domain are from the source domain.
      domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
                                      True)

      # When using the semisupervised model we include labeled target data in
      # the source labelled data.
      if FLAGS.target_labeled_dataset != 'none':
        # 1000 is the maximum number of labelled target samples that exists in
        # the datasets.
        # NOTE(review): `data_provider` is not imported in this file -- this
        # branch raises NameError as written. Confirm the intended provider
        # (dataset_factory?) before enabling a labeled target dataset.
        target_semi_images, target_semi_labels = data_provider.provide(
            FLAGS.target_labeled_dataset, 'train', FLAGS.batch_size)

        # Calculate the proportion of source domain samples in the semi-
        # supervised setting, so that the proportion is set accordingly in the
        # batches.
        proportion = float(source_labels['num_train_samples']) / (
            source_labels['num_train_samples'] +
            target_semi_labels['num_train_samples'])

        rnd_tensor = tf.random_uniform(
            (target_semi_images.get_shape().as_list()[0],))

        # Per-example Bernoulli(proportion) choice between source and labeled
        # target examples.
        domain_selection_mask = rnd_tensor < proportion
        source_images = tf.where(domain_selection_mask, source_images,
                                 target_semi_images)
        source_class_labels = tf.where(domain_selection_mask,
                                       source_labels['classes'],
                                       target_semi_labels['classes'])

        if 'quaternions' in source_labels:
          source_pose_labels = tf.where(domain_selection_mask,
                                        source_labels['quaternions'],
                                        target_semi_labels['quaternions'])
          (source_images, source_class_labels, source_pose_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [
                   source_images, source_class_labels, source_pose_labels,
                   domain_selection_mask
               ],
               FLAGS.batch_size,
               50000,
               5000,
               num_threads=1,
               enqueue_many=True)

        else:
          (source_images, source_class_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [source_images, source_class_labels, domain_selection_mask],
               FLAGS.batch_size,
               50000,
               5000,
               num_threads=1,
               enqueue_many=True)
        source_labels = {}
        source_labels['classes'] = source_class_labels
        # NOTE(review): source_labels was just rebuilt with only 'classes',
        # so this membership test can never be True; it likely should check
        # the original labels dict -- confirm before relying on pose labels
        # in the semi-supervised path.
        if 'quaternions' in source_labels:
          source_labels['quaternions'] = source_pose_labels

      slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)

      # Configure the optimization scheme:
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')

      tf.summary.scalar('learning_rate', learning_rate)
      tf.summary.scalar('total_loss', tf.losses.get_total_loss())

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)
      # Run training.
      loss_tensor = slim.learning.create_train_op(
          slim.losses.get_total_loss(),
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
      slim.learning.train(
          train_op=loss_tensor,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in grl_ops.py."""
import tensorflow as tf
@tf.RegisterGradient("GradientReversal")
def _GradientReversalGrad(_, grad):
  """Gradient function registered for the `GradientReversal` op.

  The forward pass of `gradient_reversal` is the identity, so the op's entire
  effect lives here: the incoming gradient is negated before being propagated
  to the op's input.

  Args:
    _: The `gradient_reversal` `Operation` being differentiated (unused; it
      would give access to the original op's inputs and outputs).
    grad: Gradient with respect to the output of the `gradient_reversal` op.

  Returns:
    Gradient with respect to the input of `gradient_reversal`: simply the
    negation of `grad`.
  """
  reversed_grad = tf.negative(grad)
  return reversed_grad
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains the implementations of the ops registered in
// grl_ops.cc.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {

// GradientReversal is used in domain-adversarial training. It is a no-op
// (identity) during forward propagation, while its registered gradient
// (see grl_op_grads.py) multiplies incoming gradients by -1 during backward
// propagation.
class GradientReversalOp : public OpKernel {
 public:
  explicit GradientReversalOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  // Forward pass is a pure pass-through, mirroring IdentityOp::Compute()
  // from third_party/tensorflow/core/kernels/identity_op.h.
  void Compute(OpKernelContext* context) override {
    if (!IsRefType(context->input_dtype(0))) {
      context->set_output(0, context->input(0));
    } else {
      context->forward_ref_input_to_ref_output(0, 0);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("GradientReversal").Device(DEVICE_CPU),
                        GradientReversalOp);

}  // namespace tensorflow
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shape inference for operators defined in grl_ops.cc."""
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains custom ops.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {

// Registration of the GradientReversal op used by domain-adversarial
// training. Shape inference is UnchangedShape because the op is an
// element-wise identity in the forward pass; its gradient (registered in
// grl_op_grads.py) negates incoming gradients.
REGISTER_OP("GradientReversal")
    .Input("input: float")
    .Output("output: float")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
This op copies the input to the output during forward propagation, and
negates the input during backward propagation.
input: Tensor.
output: Tensor, copied from input.
)doc");

}  // namespace tensorflow
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GradientReversal op Python library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
# Log the directory the runtime searches for the compiled op shared library.
tf.logging.info(tf.resource_loader.get_data_files_path())
# Load the custom GradientReversal op from the pre-built shared object that
# ships alongside this module's data files.
_grl_ops_module = tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(),
                 '_grl_ops.so'))
# Public alias for the op wrapper generated by load_op_library.
gradient_reversal = _grl_ops_module.gradient_reversal
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for grl_ops."""
#from models.domain_adaptation.domain_separation import grl_op_grads # pylint: disable=unused-import
#from models.domain_adaptation.domain_separation import grl_op_shapes # pylint: disable=unused-import
import tensorflow as tf
import grl_op_grads
import grl_ops
FLAGS = tf.app.flags.FLAGS
class GRLOpsTest(tf.test.TestCase):
  """Tests for the custom GradientReversal op and its registered gradient."""

  def testGradientReversalOp(self):
    """Checks identity forward behavior and gradient negation in backprop."""
    with tf.Graph().as_default():
      with self.test_session():
        # Test that in forward prop, gradient reversal op acts as the
        # identity operation.
        examples = tf.constant([5.0, 4.0, 3.0, 2.0, 1.0])
        output = grl_ops.gradient_reversal(examples)
        expected_output = examples
        self.assertAllEqual(output.eval(), expected_output.eval())

        # Test that shape inference works as expected.
        self.assertAllEqual(output.get_shape(), expected_output.get_shape())

        # Test that in backward prop, gradient reversal op multiplies
        # gradients by -1.
        examples = tf.constant([[1.0]])
        w = tf.get_variable(name='w', shape=[1, 1])
        b = tf.get_variable(name='b', shape=[1])
        init_op = tf.global_variables_initializer()
        init_op.run()
        features = tf.nn.xw_plus_b(examples, w, b)
        # Construct two outputs: features layer passes directly to output1, but
        # features layer passes through a gradient reversal layer before
        # reaching output2.
        output1 = features
        output2 = grl_ops.gradient_reversal(features)
        gold = tf.constant([1.0])
        loss1 = gold - output1
        loss2 = gold - output2
        opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        grads_and_vars_1 = opt.compute_gradients(loss1,
                                                 tf.trainable_variables())
        grads_and_vars_2 = opt.compute_gradients(loss2,
                                                 tf.trainable_variables())
        self.assertAllEqual(len(grads_and_vars_1), len(grads_and_vars_2))
        for i in range(len(grads_and_vars_1)):
          g1 = grads_and_vars_1[i][0]
          g2 = grads_and_vars_2[i][0]
          # Verify that gradients of loss1 are the negative of gradients of
          # loss2.
          self.assertAllEqual(tf.negative(g1).eval(), g2.eval())
# Test entry point: defers to the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Domain Adaptation Loss Functions.
The following domain adaptation loss functions are defined:
- Maximum Mean Discrepancy (MMD).
Relevant paper:
Gretton, Arthur, et al.,
"A kernel two-sample test."
The Journal of Machine Learning Research, 2012
- Correlation Loss on a batch.
"""
from functools import partial
import tensorflow as tf
import grl_op_grads # pylint: disable=unused-import
import grl_op_shapes # pylint: disable=unused-import
import grl_ops
import utils
slim = tf.contrib.slim
################################################################################
# SIMILARITY LOSS
################################################################################
def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
  r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.

  Maximum Mean Discrepancy (MMD) is a distance-measure between the samples of
  the distributions of x and y, estimated here with the kernel two-sample
  estimate using the empirical means of both distributions:

  MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
              = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },

  where K = <\phi(x), \phi(y)> is the desired kernel function, in this case a
  radial basis kernel.

  Args:
    x: a tensor of shape [num_samples, num_features]
    y: a tensor of shape [num_samples, num_features]
    kernel: a function which computes the kernel in MMD. Defaults to the
      GaussianKernelMatrix.

  Returns:
    a scalar denoting the squared maximum mean discrepancy loss.
  """
  with tf.name_scope('MaximumMeanDiscrepancy'):
    # \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
    cost = (tf.reduce_mean(kernel(x, x)) + tf.reduce_mean(kernel(y, y)) -
            2 * tf.reduce_mean(kernel(x, y)))

    # Numerical noise can push the estimate slightly below zero; clamp it so
    # the loss never becomes negative.
    cost = tf.where(cost > 0, cost, 0, name='value')
  return cost
def mmd_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the MMD between two representations.

  This Maximum Mean Discrepancy (MMD) loss is calculated with a mixture of
  Gaussian kernels spanning many bandwidths, so distribution differences are
  picked up at multiple scales.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the MMD loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the MMD loss value.
  """
  # Kernel bandwidths covering twelve orders of magnitude.
  sigmas = [
      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
      1e3, 1e4, 1e5, 1e6
  ]
  gaussian_kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

  loss_value = maximum_mean_discrepancy(
      source_samples, target_samples, kernel=gaussian_kernel)
  # Floor the raw MMD at 1e-4 before applying the weight.
  loss_value = tf.maximum(1e-4, loss_value) * weight

  # Fail loudly if the loss ever becomes NaN/Inf.
  assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
  with tf.control_dependencies([assert_op]):
    tag = scope + 'MMD Loss' if scope else 'MMD Loss'
    tf.contrib.deprecated.scalar_summary(tag, loss_value)
    tf.losses.add_loss(loss_value)

  return loss_value
def correlation_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the correlation between two representations.

  Args:
    source_samples: a tensor of shape [num_samples, num_features]
    target_samples: a tensor of shape [num_samples, num_features]
    weight: a scalar weight for the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the correlation loss value.
  """
  with tf.name_scope('corr_loss'):
    # Center each feature dimension, then L2-normalize every sample.
    centered_source = source_samples - tf.reduce_mean(source_samples, 0)
    centered_target = target_samples - tf.reduce_mean(target_samples, 0)
    normed_source = tf.nn.l2_normalize(centered_source, 1)
    normed_target = tf.nn.l2_normalize(centered_target, 1)

    # Second-moment (feature x feature) matrices of both batches; the loss is
    # their mean squared difference.
    source_cov = tf.matmul(normed_source, normed_source, transpose_a=True)
    target_cov = tf.matmul(normed_target, normed_target, transpose_a=True)
    corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight

  # Fail loudly if the loss ever becomes NaN/Inf.
  assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
  with tf.control_dependencies([assert_op]):
    tag = scope + 'Correlation Loss' if scope else 'Correlation Loss'
    tf.contrib.deprecated.scalar_summary(tag, corr_loss)
    tf.losses.add_loss(corr_loss)

  return corr_loss
def dann_loss(source_samples, target_samples, weight, scope=None):
  """Adds the domain adversarial (DANN) loss.

  A small domain classifier (two fully connected layers) is trained to tell
  source samples from target samples; the gradient reversal op placed between
  the features and the classifier makes the feature extractor adversarially
  maximize the loss the classifier minimizes.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the domain adversarial (DANN) loss value.
  """
  with tf.variable_scope('dann'):
    batch_size = tf.shape(source_samples)[0]
    samples = tf.concat([source_samples, target_samples], 0)
    samples = slim.flatten(samples)

    # Domain labels: 0 for source samples, 1 for target samples.
    domain_selection_mask = tf.concat(
        [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], 0)

    # Perform the gradient reversal and be careful with the shape.
    grl = grl_ops.gradient_reversal(samples)
    grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))

    grl = slim.fully_connected(grl, 100, scope='fc1')
    logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')

    domain_predictions = tf.sigmoid(logits)

  domain_loss = tf.losses.log_loss(
      domain_selection_mask, domain_predictions, weights=weight)

  # Accuracy of the domain classifier, reported for monitoring only.
  domain_accuracy = utils.accuracy(
      tf.round(domain_predictions), domain_selection_mask)

  # Fail loudly if the loss ever becomes NaN/Inf.
  assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
  with tf.control_dependencies([assert_op]):
    tag_loss = 'losses/Domain Loss'
    tag_accuracy = 'losses/Domain Accuracy'
    if scope:
      tag_loss = scope + tag_loss
      tag_accuracy = scope + tag_accuracy

    tf.contrib.deprecated.scalar_summary(
        tag_loss, domain_loss, name='domain_loss_summary')
    tf.contrib.deprecated.scalar_summary(
        tag_accuracy, domain_accuracy, name='domain_accuracy_summary')

  return domain_loss
################################################################################
# DIFFERENCE LOSS
################################################################################
def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
  """Adds the difference loss between the private and shared representations.

  Penalizes correlation between the two representations, encouraging them to
  encode different aspects of the input. The loss is registered with
  tf.losses; nothing is returned.

  Args:
    private_samples: a tensor of shape [num_samples, num_features].
    shared_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the incoherence loss.
    name: the name of the tf summary.
  """
  # Center each feature dimension, then L2-normalize every sample.
  private_centered = private_samples - tf.reduce_mean(private_samples, 0)
  shared_centered = shared_samples - tf.reduce_mean(shared_samples, 0)
  private_normed = tf.nn.l2_normalize(private_centered, 1)
  shared_normed = tf.nn.l2_normalize(shared_centered, 1)

  # Cross-correlation between the private and shared features.
  correlation_matrix = tf.matmul(
      private_normed, shared_normed, transpose_a=True)

  cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
  # Numerical noise can push the value slightly below zero; clamp it.
  cost = tf.where(cost > 0, cost, 0, name='value')

  tf.contrib.deprecated.scalar_summary('losses/Difference Loss {}'.format(name),
                                       cost)
  # Fail loudly if the loss ever becomes NaN/Inf.
  assert_op = tf.Assert(tf.is_finite(cost), [cost])
  with tf.control_dependencies([assert_op]):
    tf.losses.add_loss(cost)
################################################################################
# TASK LOSS
################################################################################
def log_quaternion_loss_batch(predictions, labels, params):
  """A helper function to compute the error between quaternions.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size [batch_size], denoting the error between the quaternions.
  """
  use_logging = params['use_logging']
  assertions = []
  if use_logging:
    # Both inputs must be unit quaternions; check their row-wise L2 norms.
    for tensor, message in (
        (predictions,
         'The l2 norm of each prediction quaternion vector should be 1.'),
        (labels,
         'The l2 norm of each label quaternion vector should be 1.')):
      norm_is_one = tf.reduce_all(
          tf.less(
              tf.abs(tf.reduce_sum(tf.square(tensor), [1]) - 1), 1e-4))
      assertions.append(tf.Assert(norm_is_one, [message]))

  with tf.control_dependencies(assertions):
    product = tf.multiply(predictions, labels)
  # |<p, q>| is 1 for identical orientations, so the log term below is
  # minimized when the quaternions agree (1e-4 guards log(0)).
  internal_dot_products = tf.reduce_sum(product, [1])

  if use_logging:
    internal_dot_products = tf.Print(
        internal_dot_products,
        [internal_dot_products, tf.shape(internal_dot_products)],
        'internal_dot_products:')

  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  return logcost
def log_quaternion_loss(predictions, labels, params):
  """A helper function to compute the mean error between batches of quaternions.

  The caller is expected to add the loss to the graph.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size 1, denoting the mean error between batches of quaternions.
  """
  use_logging = params['use_logging']
  batch_size = params['batch_size']

  # Average the per-sample log-quaternion errors over the batch.
  per_sample_cost = log_quaternion_loss_batch(predictions, labels, params)
  total_cost = tf.reduce_sum(per_sample_cost, [0])
  logcost = tf.multiply(
      total_cost, 1.0 / batch_size, name='log_quaternion_loss')
  if use_logging:
    logcost = tf.Print(
        logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
  return logcost
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN losses."""
from functools import partial
import numpy as np
import tensorflow as tf
import losses
import utils
def MaximumMeanDiscrepancySlow(x, y, sigmas):
  """Deliberately O(n^2) reference implementation of the squared MMD.

  Used as a ground truth to validate the vectorized implementation in
  losses.maximum_mean_discrepancy.

  Args:
    x: a tensor of shape [num_samples, num_features].
    y: a tensor of shape [num_samples, num_features].
    sigmas: a list of Gaussian kernel bandwidths.

  Returns:
    a scalar tensor with the pairwise-computed squared MMD estimate.
  """
  num_samples = x.get_shape().as_list()[0]

  def AverageGaussianKernel(a, b, sigmas):
    # Mixture-of-Gaussians kernel value for one sample pair, pre-divided by
    # num_samples^2 so the pairwise sums below become means.
    kernel_sum = 0
    for sigma in sigmas:
      squared_dist = tf.reduce_sum(tf.square(a - b))
      kernel_sum += tf.exp((-1.0 / (2.0 * sigma)) * squared_dist)
    return kernel_sum / num_samples**2

  total = 0
  for i in range(num_samples):
    for j in range(num_samples):
      # E[K(x,x)] + E[K(y,y)] - 2 E[K(x,y)], accumulated pair by pair.
      total += (AverageGaussianKernel(x[i, :], x[j, :], sigmas) +
                AverageGaussianKernel(y[i, :], y[j, :], sigmas) -
                2 * AverageGaussianKernel(x[i, :], y[j, :], sigmas))
  return total
class LogQuaternionLossTest(tf.test.TestCase):
  """Tests for losses.log_quaternion_loss_batch."""

  def test_log_quaternion_loss_batch(self):
    """The batched loss must return one value per batch element."""
    with self.test_session():
      # Unit-normalized random quaternions; identical seeds make predictions
      # equal to labels, but only the output *shape* is asserted here.
      predictions = tf.random_uniform((10, 4), seed=1)
      predictions = tf.nn.l2_normalize(predictions, 1)
      labels = tf.random_uniform((10, 4), seed=1)
      labels = tf.nn.l2_normalize(labels, 1)
      params = {'batch_size': 10, 'use_logging': False}
      x = losses.log_quaternion_loss_batch(predictions, labels, params)
      self.assertTrue(((10,) == tf.shape(x).eval()).all())
class MaximumMeanDiscrepancyTest(tf.test.TestCase):
  """Tests for losses.maximum_mean_discrepancy.

  Validates op naming, the zero-distance case, and agreement with the O(n^2)
  reference implementation MaximumMeanDiscrepancySlow.
  """
  # NOTE: assertEquals is a deprecated unittest alias (removed in Python
  # 3.12); assertEqual is used throughout instead.

  def test_mmd_name(self):
    """The clamped loss op must be named 'MaximumMeanDiscrepancy/value'."""
    with self.test_session():
      x = tf.random_uniform((2, 3), seed=1)
      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
      loss = losses.maximum_mean_discrepancy(x, x, kernel)

      self.assertEqual(loss.op.name, 'MaximumMeanDiscrepancy/value')

  def test_mmd_is_zero_when_inputs_are_same(self):
    with self.test_session():
      x = tf.random_uniform((2, 3), seed=1)
      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
      self.assertEqual(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())

  def test_fast_mmd_is_similar_to_slow_mmd(self):
    """Vectorized MMD must agree with the pairwise reference implementation."""
    with self.test_session():
      x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
      y = tf.constant(np.random.rand(2, 3), tf.float32)

      cost_old = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()
      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
      cost_new = losses.maximum_mean_discrepancy(x, y, kernel).eval()

      self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)

  def test_multiple_sigmas(self):
    """Agreement must also hold for a mixture of kernel bandwidths."""
    with self.test_session():
      x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
      y = tf.constant(np.random.rand(2, 3), tf.float32)

      sigmas = tf.constant([2., 5., 10, 20, 30])
      kernel = partial(utils.gaussian_kernel_matrix, sigmas=sigmas)
      cost_old = MaximumMeanDiscrepancySlow(x, y, [2., 5., 10, 20, 30]).eval()
      cost_new = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()

      self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)

  def test_mmd_is_zero_when_distributions_are_same(self):
    """Two large samples from the same distribution should give MMD ~ 0."""
    with self.test_session():
      x = tf.random_uniform((1000, 10), seed=1)
      y = tf.random_uniform((1000, 10), seed=3)

      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
      loss = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()

      self.assertAlmostEqual(0, loss, delta=1e-4)
# Test entry point: defers to the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains different architectures for the different DSN parts.
We define here the modules that can be used in the different parts of the DSN
model.
- shared encoder (dsn_cropped_linemod, dann_xxxx)
- private encoder (default_encoder)
- decoder (large_decoder, gtsrb_decoder, small_decoder)
"""
import tensorflow as tf
#from models.domain_adaptation.domain_separation
import utils
slim = tf.contrib.slim
def default_batch_norm_params(is_training=False):
  """Returns default batch normalization parameters for DSNs.

  Args:
    is_training: whether or not the model is training.

  Returns:
    a dictionary that maps batch norm parameter names (strings) to values.
  """
  # 'decay' sets the moving-average horizon; 'epsilon' prevents zeros in the
  # variance denominator.
  params = dict(decay=0.5, epsilon=0.001, is_training=is_training)
  return params
################################################################################
# PRIVATE ENCODERS
################################################################################
def default_encoder(images, code_size, batch_norm_params=None,
                    weight_decay=0.0):
  """Encodes the given images to codes of the given size.

  Architecture: two 5x5 conv + 2x2 max-pool stages followed by one fully
  connected code layer.

  Args:
    images: a tensor of size [batch_size, height, width, 1].
    code_size: the number of hidden units in the code layer of the classifier.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    end_points: the code of the input.
  """
  end_points = {}
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    with slim.arg_scope([slim.conv2d], kernel_size=[5, 5], padding='SAME'):
      net = slim.conv2d(images, 32, scope='conv1')
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
      net = slim.conv2d(net, 64, scope='conv2')
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')

      net = slim.flatten(net)
      end_points['flatten'] = net
      # NOTE(review): the code layer is stored under the end_points key 'fc3'
      # although its variable scope is 'fc1' — keep both names in mind when
      # restoring checkpoints.
      net = slim.fully_connected(net, code_size, scope='fc1')
      end_points['fc3'] = net

  return end_points
################################################################################
# DECODERS
################################################################################
def large_decoder(codes,
                  height,
                  width,
                  channels,
                  batch_norm_params=None,
                  weight_decay=0.0):
  """Decodes the codes to a fixed output size.

  The code is projected to a 10x10x6 map, then progressively upsampled
  (10 -> 16 -> 32 -> [height, width]) with nearest-neighbor resizing and
  convolutions in between.

  Args:
    codes: a tensor of size [batch_size, code_size].
    height: the height of the output images.
    width: the width of the output images.
    channels: the number of the output channels.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    recons: the reconstruction tensor of shape [batch_size, height, width, 3].
  """
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    # Project the flat code to 600 units = a 10x10x6 spatial map.
    net = slim.fully_connected(codes, 600, scope='fc1')
    batch_size = net.get_shape().as_list()[0]
    net = tf.reshape(net, [batch_size, 10, 10, 6])

    net = slim.conv2d(net, 32, [5, 5], scope='conv1_1')

    net = tf.image.resize_nearest_neighbor(net, (16, 16))

    net = slim.conv2d(net, 32, [5, 5], scope='conv2_1')

    net = tf.image.resize_nearest_neighbor(net, (32, 32))

    net = slim.conv2d(net, 32, [5, 5], scope='conv3_2')

    output_size = [height, width]
    net = tf.image.resize_nearest_neighbor(net, output_size)

    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
      # Final projection to the requested channel count, no activation.
      net = slim.conv2d(net, channels, activation_fn=None, scope='conv4_1')
  return net
def gtsrb_decoder(codes,
                  height,
                  width,
                  channels,
                  batch_norm_params=None,
                  weight_decay=0.0):
  """Decodes the codes to a fixed output size. This decoder is specific to GTSRB

  Args:
    codes: a tensor of size [batch_size, 100].
    height: the height of the output images.
    width: the width of the output images.
    channels: the number of the output channels.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    recons: the reconstruction tensor of shape [batch_size, height, width, 3].

  Raises:
    ValueError: When the input code size is not 100.
  """
  batch_size, code_size = codes.get_shape().as_list()
  # The reshape below hard-codes a 10x10x1 layout, so the code must be 100-d.
  if code_size != 100:
    raise ValueError('The code size used as an input to the GTSRB decoder is '
                     'expected to be 100.')

  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    net = codes
    # Interpret the 100-dim code as a 10x10 single-channel map.
    net = tf.reshape(net, [batch_size, 10, 10, 1])
    net = slim.conv2d(net, 32, [3, 3], scope='conv1_1')

    # First upsampling 20x20
    net = tf.image.resize_nearest_neighbor(net, [20, 20])

    net = slim.conv2d(net, 32, [3, 3], scope='conv2_1')

    output_size = [height, width]
    # Final upsampling 40 x 40
    net = tf.image.resize_nearest_neighbor(net, output_size)

    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
      net = slim.conv2d(net, 16, scope='conv3_1')
      # Final projection to the requested channel count, no activation.
      net = slim.conv2d(net, channels, activation_fn=None, scope='conv3_2')
  return net
def small_decoder(codes,
                  height,
                  width,
                  channels,
                  batch_norm_params=None,
                  weight_decay=0.0):
  """Decodes the codes to a fixed output size.

  A lighter counterpart of large_decoder: one projection to a 10x10x3 map,
  two convolutions, and a single upsampling step to [height, width].

  Args:
    codes: a tensor of size [batch_size, code_size].
    height: the height of the output images.
    width: the width of the output images.
    channels: the number of the output channels.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    recons: the reconstruction tensor of shape [batch_size, height, width, 3].
  """
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    # Project the flat code to 300 units = a 10x10x3 spatial map.
    net = slim.fully_connected(codes, 300, scope='fc1')
    batch_size = net.get_shape().as_list()[0]
    net = tf.reshape(net, [batch_size, 10, 10, 3])

    net = slim.conv2d(net, 16, [3, 3], scope='conv1_1')
    net = slim.conv2d(net, 16, [3, 3], scope='conv1_2')

    output_size = [height, width]
    net = tf.image.resize_nearest_neighbor(net, output_size)

    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
      net = slim.conv2d(net, 16, scope='conv2_1')
      # Final projection to the requested channel count, no activation.
      net = slim.conv2d(net, channels, activation_fn=None, scope='conv2_2')
  return net
################################################################################
# SHARED ENCODERS
################################################################################
def dann_mnist(images,
               weight_decay=0.0,
               prefix='model',
               num_classes=10,
               **kwargs):
  """Creates a convolution MNIST model.

  Note that this model implements the architecture for MNIST proposed in:
   Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
   JMLR 2015

  Args:
    images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    **kwargs: Placeholder for keyword arguments used by other shared encoders.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}

  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      # conv/pool trunk: 32 then 48 5x5 filters, each halved by max-pooling.
      end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
      end_points['pool1'] = slim.max_pool2d(
          end_points['conv1'], [2, 2], 2, scope='pool1')
      end_points['conv2'] = slim.conv2d(
          end_points['pool1'], 48, [5, 5], scope='conv2')
      end_points['pool2'] = slim.max_pool2d(
          end_points['conv2'], [2, 2], 2, scope='pool2')
      # Two 100-unit fully connected layers on the flattened features.
      end_points['fc3'] = slim.fully_connected(
          slim.flatten(end_points['pool2']), 100, scope='fc3')
      end_points['fc4'] = slim.fully_connected(
          slim.flatten(end_points['fc3']), 100, scope='fc4')

  # Classification head: raw logits, no activation.
  logits = slim.fully_connected(
      end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
  return logits, end_points
def dann_svhn(images,
              weight_decay=0.0,
              prefix='model',
              num_classes=10,
              **kwargs):
  """Creates the convolutional SVHN model.

  Note that this model implements the architecture proposed for SVHN in:
   Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
   JMLR 2015

  Args:
    images: the SVHN digits, a tensor of size [batch_size, 32, 32, 3].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    **kwargs: Placeholder for keyword arguments used by other shared encoders.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}

  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      # conv/pool trunk: 64, 64, then 128 5x5 filters with 3x3/stride-2 pools.
      end_points['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
      end_points['pool1'] = slim.max_pool2d(
          end_points['conv1'], [3, 3], 2, scope='pool1')
      end_points['conv2'] = slim.conv2d(
          end_points['pool1'], 64, [5, 5], scope='conv2')
      end_points['pool2'] = slim.max_pool2d(
          end_points['conv2'], [3, 3], 2, scope='pool2')
      end_points['conv3'] = slim.conv2d(
          end_points['pool2'], 128, [5, 5], scope='conv3')
      # Wide fully connected layers (3072 then 2048 units).
      end_points['fc3'] = slim.fully_connected(
          slim.flatten(end_points['conv3']), 3072, scope='fc3')
      end_points['fc4'] = slim.fully_connected(
          slim.flatten(end_points['fc3']), 2048, scope='fc4')

  # Classification head: raw logits, no activation.
  logits = slim.fully_connected(
      end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
  return logits, end_points
def dann_gtsrb(images,
               weight_decay=0.0,
               prefix='model',
               num_classes=43,
               **kwargs):
  """Builds the convolutional GTSRB shared encoder.

  The architecture follows the one proposed in:
  Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
  JMLR 2015

  Args:
    images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    **kwargs: Placeholder for keyword arguments used by other shared encoders.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}

  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      # Three conv/pool stages with widening channel counts, then one
      # fully-connected layer before the classifier.
      net = slim.conv2d(images, 96, [5, 5], scope='conv1')
      end_points['conv1'] = net
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
      end_points['pool1'] = net
      net = slim.conv2d(net, 144, [3, 3], scope='conv2')
      end_points['conv2'] = net
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
      end_points['pool2'] = net
      net = slim.conv2d(net, 256, [5, 5], scope='conv3')
      end_points['conv3'] = net
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
      end_points['pool3'] = net

      net = slim.fully_connected(slim.flatten(net), 512, scope='fc3')
      end_points['fc3'] = net

      logits = slim.fully_connected(
          net, num_classes, activation_fn=None, scope='fc4')

  return logits, end_points
def dsn_cropped_linemod(images,
                        weight_decay=0.0,
                        prefix='model',
                        num_classes=11,
                        batch_norm_params=None,
                        is_training=False):
  """Builds the convolutional pose estimation model for Cropped Linemod.

  The model produces both class logits and a unit-norm quaternion for pose.

  Args:
    images: the Cropped Linemod samples, a tensor of size
      [batch_size, 64, 64, 4].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    is_training: specifies whether or not we're currently training the model.
      This variable will determine the behaviour of the dropout layer.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}

  # Log the raw inputs so they can be inspected in TensorBoard.
  tf.summary.image('{}/input_images'.format(prefix), images)
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm if batch_norm_params else None,
      normalizer_params=batch_norm_params):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      net = slim.conv2d(images, 32, [5, 5], scope='conv1')
      end_points['conv1'] = net
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
      end_points['pool1'] = net
      net = slim.conv2d(net, 64, [5, 5], scope='conv2')
      end_points['conv2'] = net
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
      end_points['pool2'] = net

      end_points['fc3'] = slim.fully_connected(
          slim.flatten(net), 128, scope='fc3')
      net = slim.dropout(
          end_points['fc3'], 0.5, is_training=is_training, scope='dropout')

      # Pose head: a 4-vector squashed by tanh, then renormalized so it is a
      # valid unit quaternion.
      with tf.variable_scope('quaternion_prediction'):
        quaternion = slim.fully_connected(net, 4, activation_fn=tf.nn.tanh)
        quaternion = tf.nn.l2_normalize(quaternion, 1)

      # Classification head.
      logits = slim.fully_connected(
          net, num_classes, activation_fn=None, scope='fc4')

      end_points['quaternion_pred'] = quaternion

  return logits, end_points
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN components."""
import numpy as np
import tensorflow as tf
#from models.domain_adaptation.domain_separation
import models
class SharedEncodersTest(tf.test.TestCase):
  """Builds each shared encoder once and checks the logits it produces."""

  def _testSharedEncoder(self,
                         input_shape=[5, 28, 28, 1],
                         model=models.dann_mnist,
                         is_training=True):
    """Runs `model` on random input of `input_shape`; returns logits as numpy."""
    images = tf.to_float(np.random.rand(*input_shape))
    with self.test_session() as sess:
      logits, _ = model(images)
      sess.run(tf.global_variables_initializer())
      return sess.run(logits)

  def testBuildGRLMnistModel(self):
    logits = self._testSharedEncoder(model=models.dann_mnist)
    self.assertEqual(logits.shape, (5, 10))
    self.assertTrue(np.any(logits))

  def testBuildGRLSvhnModel(self):
    logits = self._testSharedEncoder(model=models.dann_svhn)
    self.assertEqual(logits.shape, (5, 10))
    self.assertTrue(np.any(logits))

  def testBuildGRLGtsrbModel(self):
    logits = self._testSharedEncoder([5, 40, 40, 3], models.dann_gtsrb)
    self.assertEqual(logits.shape, (5, 43))
    self.assertTrue(np.any(logits))

  def testBuildPoseModel(self):
    logits = self._testSharedEncoder([5, 64, 64, 4],
                                     models.dsn_cropped_linemod)
    self.assertEqual(logits.shape, (5, 11))
    self.assertTrue(np.any(logits))

  def testBuildPoseModelWithBatchNorm(self):
    images = tf.to_float(np.random.rand(10, 64, 64, 4))
    with self.test_session() as sess:
      logits, _ = models.dsn_cropped_linemod(
          images, batch_norm_params=models.default_batch_norm_params(True))
      sess.run(tf.global_variables_initializer())
      logits_np = sess.run(logits)
    self.assertEqual(logits_np.shape, (10, 11))
    self.assertTrue(np.any(logits_np))
class EncoderTest(tf.test.TestCase):
  """Checks that the default encoder yields a finite, non-zero code."""

  def _testEncoder(self, batch_norm_params=None, channels=1):
    """Builds the default encoder on random input and inspects 'fc3'."""
    inputs = tf.to_float(np.random.rand(10, 28, 28, channels))
    with self.test_session() as sess:
      end_points = models.default_encoder(
          inputs, 128, batch_norm_params=batch_norm_params)
      sess.run(tf.global_variables_initializer())
      code = sess.run(end_points['fc3'])
    self.assertEqual(code.shape, (10, 128))
    self.assertTrue(np.any(code))
    self.assertTrue(np.all(np.isfinite(code)))

  def testEncoder(self):
    self._testEncoder()

  def testEncoderMultiChannel(self):
    self._testEncoder(None, 4)

  def testEncoderIsTrainingBatchNorm(self):
    self._testEncoder(models.default_batch_norm_params(True))

  def testEncoderBatchNorm(self):
    self._testEncoder(models.default_batch_norm_params(False))
class DecoderTest(tf.test.TestCase):
  """Tests that each decoder variant builds and emits finite outputs."""

  def _testDecoder(self,
                   height=64,
                   width=64,
                   channels=4,
                   batch_norm_params=None,
                   decoder=models.small_decoder):
    """Runs `decoder` on random codes and checks the reconstruction shape."""
    codes = tf.to_float(np.random.rand(32, 100))
    with self.test_session() as sess:
      output = decoder(
          codes,
          height=height,
          width=width,
          channels=channels,
          batch_norm_params=batch_norm_params)
      # Use the non-deprecated initializer, consistent with the other test
      # classes in this file (tf.initialize_all_variables is deprecated).
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)
    self.assertEqual(output_np.shape, (32, height, width, channels))
    self.assertTrue(np.any(output_np))
    self.assertTrue(np.all(np.isfinite(output_np)))

  def testSmallDecoder(self):
    self._testDecoder(28, 28, 4, None, getattr(models, 'small_decoder'))

  def testSmallDecoderThreeChannels(self):
    self._testDecoder(28, 28, 3)

  def testSmallDecoderBatchNorm(self):
    self._testDecoder(28, 28, 4, models.default_batch_norm_params(False))

  def testSmallDecoderIsTrainingBatchNorm(self):
    self._testDecoder(28, 28, 4, models.default_batch_norm_params(True))

  def testLargeDecoder(self):
    self._testDecoder(32, 32, 4, None, getattr(models, 'large_decoder'))

  def testLargeDecoderThreeChannels(self):
    self._testDecoder(32, 32, 3, None, getattr(models, 'large_decoder'))

  def testLargeDecoderBatchNorm(self):
    self._testDecoder(32, 32, 4,
                      models.default_batch_norm_params(False),
                      getattr(models, 'large_decoder'))

  def testLargeDecoderIsTrainingBatchNorm(self):
    self._testDecoder(32, 32, 4,
                      models.default_batch_norm_params(True),
                      getattr(models, 'large_decoder'))

  def testGtsrbDecoder(self):
    # NOTE(review): this case uses large_decoder while the two cases below
    # use gtsrb_decoder -- possibly a copy-paste slip; confirm intent.
    self._testDecoder(40, 40, 3, None, getattr(models, 'large_decoder'))

  def testGtsrbDecoderBatchNorm(self):
    self._testDecoder(40, 40, 4,
                      models.default_batch_norm_params(False),
                      getattr(models, 'gtsrb_decoder'))

  def testGtsrbDecoderIsTrainingBatchNorm(self):
    self._testDecoder(40, 40, 4,
                      models.default_batch_norm_params(True),
                      getattr(models, 'gtsrb_decoder'))
# Run all tf.test.TestCase classes in this module when executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Auxiliary functions for domain adaptation related losses.
"""
import math
import tensorflow as tf
def create_summaries(end_points, prefix='', max_images=3, use_op_name=False):
  """Creates a tf summary per endpoint.

  4-D endpoints are shown as images (tiled into a grid when the channel count
  is not 1 or 3), 3-D endpoints are shown as single-channel images, and 2-D
  endpoints get histogram summaries.

  Args:
    end_points: a dictionary of name, tf tensor pairs.
    prefix: an optional string to prefix the summary with.
    max_images: the maximum number of images to display per summary.
    use_op_name: Use the op name as opposed to the shorter end_points key.
  """
  for layer_name, activations in end_points.items():
    name = activations.op.name if use_op_name else layer_name
    ndims = len(activations.get_shape().as_list())
    if ndims == 4:
      num_channels = activations.get_shape().as_list()[-1]
      # Grayscale or RGB activations are already displayable images; other
      # channel counts are tiled into one composite image first.
      if num_channels in (1, 3):
        visualization = activations
      else:
        visualization = reshape_feature_maps(activations)
      tf.summary.image(
          '{}/{}'.format(prefix, name),
          visualization,
          max_outputs=max_images)
    elif ndims == 3:
      tf.summary.image(
          '{}/{}'.format(prefix, name),
          tf.expand_dims(activations, 3),
          max_outputs=max_images)
    elif ndims == 2:
      tf.summary.histogram('{}/{}'.format(prefix, name), activations)
def reshape_feature_maps(features_tensor):
  """Reshape activations for tf.summary.image visualization.

  Args:
    features_tensor: a tensor of activations with a square number of feature
      maps, e.g. 4, 9, 16, etc.

  Returns:
    A composite image with all the feature maps that can be passed as an
    argument to tf.summary.image.
  """
  assert len(features_tensor.get_shape().as_list()) == 4
  num_filters = features_tensor.get_shape().as_list()[-1]
  assert num_filters > 0
  num_filters_sqrt = math.sqrt(num_filters)
  assert num_filters_sqrt.is_integer(
  ), 'Number of filters should be a square number but got {}'.format(
      num_filters)
  num_filters_sqrt = int(num_filters_sqrt)
  # Split along the channel axis; each entry is [batch, height, width].
  conv_summary = tf.unstack(features_tensor, axis=3)
  # First grid row: the first num_filters_sqrt maps concatenated side by side.
  # (The previous version also assigned an unused `ind = 1` and duplicated
  # the first row into a temporary before aliasing it.)
  conv_final = tf.concat(conv_summary[0:num_filters_sqrt], 2)
  # Stack each subsequent row of maps below the accumulated grid.
  for ind in range(1, num_filters_sqrt):
    conv_one_row = tf.concat(
        conv_summary[ind * num_filters_sqrt:(ind + 1) * num_filters_sqrt], 2)
    conv_final = tf.concat(
        [tf.squeeze(conv_final), tf.squeeze(conv_one_row)], 1)
  return tf.expand_dims(conv_final, -1)
def accuracy(predictions, labels):
  """Calculates the classification accuracy.

  Args:
    predictions: the predicted values, a tensor whose size matches 'labels'.
    labels: the ground truth values, a tensor of any size.

  Returns:
    a tensor whose value on evaluation returns the total accuracy.
  """
  correct = tf.equal(predictions, labels)
  return tf.reduce_mean(tf.cast(correct, tf.float32))
def compute_upsample_values(input_tensor, upsample_height, upsample_width):
  """Compute values for an upsampling op (ops.BatchCropAndResize).

  Args:
    input_tensor: image tensor with shape [batch, height, width, in_channels]
    upsample_height: integer
    upsample_width: integer

  Returns:
    grid_centers: tensor with shape [batch, 1]
    crop_sizes: tensor with shape [batch, 1]
    output_height: integer
    output_width: integer
  """
  batch, height, width, _ = input_tensor.shape
  # Center the crop grid on the middle of every input image.
  grid_centers = tf.constant(batch * [[height / 2., width / 2.]])
  # Crop the full extent of each image.
  crop_sizes = tf.constant(batch * [[height, width]])
  return (grid_centers, tf.to_float(crop_sizes), height * upsample_height,
          width * upsample_width)
def compute_pairwise_distances(x, y):
  """Computes the squared pairwise Euclidean distances between x and y.

  Args:
    x: a tensor of shape [num_x_samples, num_features]
    y: a tensor of shape [num_y_samples, num_features]

  Returns:
    a distance matrix of dimensions [num_x_samples, num_y_samples].

  Raises:
    ValueError: if the inputs do no matched the specified dimensions.
  """
  if len(x.get_shape()) != 2 or len(y.get_shape()) != 2:
    raise ValueError('Both inputs should be matrices.')
  if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
    raise ValueError('The number of features should be the same.')

  # Broadcasting subtracts every row of y from every row of x:
  # x becomes [num_x_samples, num_features, 1] and transposed y is
  # [num_features, num_y_samples], so the difference tensor has shape
  # [num_x_samples, num_features, num_y_samples].
  diffs = tf.expand_dims(x, 2) - tf.transpose(y)
  # Summing squares over the feature axis yields the distance matrix, which
  # is then transposed (matching the original orientation of this helper).
  squared = tf.reduce_sum(tf.square(diffs), 1)
  return tf.transpose(squared)
def gaussian_kernel_matrix(x, y, sigmas):
  r"""Computes a Gaussian Radial Basis Kernel between the samples of x and y.

  We create a sum of multiple gaussian kernels each having a width sigma_i.

  Args:
    x: a tensor of shape [num_samples, num_features]
    y: a tensor of shape [num_samples, num_features]
    sigmas: a tensor of floats which denote the widths of each of the
      gaussians in the kernel.

  Returns:
    A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
  """
  dist = compute_pairwise_distances(x, y)
  # One row of scaled distances per sigma: beta_i = 1 / (2 * sigma_i).
  beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))
  scaled_dists = tf.matmul(beta, tf.reshape(dist, (1, -1)))
  # Sum the per-sigma kernels and restore the pairwise-distance shape.
  kernel = tf.reduce_sum(tf.exp(-scaled_dists), 0)
  return tf.reshape(kernel, tf.shape(dist))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment