Commit 356c98bd authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents d31aba8a b9785623
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GradientReversal op Python library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
tf.logging.info(tf.resource_loader.get_data_files_path())
_grl_ops_module = tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(),
'_grl_ops.so'))
gradient_reversal = _grl_ops_module.gradient_reversal
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for grl_ops."""
#from models.domain_adaptation.domain_separation import grl_op_grads # pylint: disable=unused-import
#from models.domain_adaptation.domain_separation import grl_op_shapes # pylint: disable=unused-import
import tensorflow as tf
import grl_op_grads
import grl_ops
FLAGS = tf.app.flags.FLAGS
class GRLOpsTest(tf.test.TestCase):
def testGradientReversalOp(self):
with tf.Graph().as_default():
with self.test_session():
# Test that in forward prop, gradient reversal op acts as the
# identity operation.
examples = tf.constant([5.0, 4.0, 3.0, 2.0, 1.0])
output = grl_ops.gradient_reversal(examples)
expected_output = examples
self.assertAllEqual(output.eval(), expected_output.eval())
# Test that shape inference works as expected.
self.assertAllEqual(output.get_shape(), expected_output.get_shape())
# Test that in backward prop, gradient reversal op multiplies
# gradients by -1.
examples = tf.constant([[1.0]])
w = tf.get_variable(name='w', shape=[1, 1])
b = tf.get_variable(name='b', shape=[1])
init_op = tf.global_variables_initializer()
init_op.run()
features = tf.nn.xw_plus_b(examples, w, b)
# Construct two outputs: features layer passes directly to output1, but
# features layer passes through a gradient reversal layer before
# reaching output2.
output1 = features
output2 = grl_ops.gradient_reversal(features)
gold = tf.constant([1.0])
loss1 = gold - output1
loss2 = gold - output2
opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
grads_and_vars_1 = opt.compute_gradients(loss1,
tf.trainable_variables())
grads_and_vars_2 = opt.compute_gradients(loss2,
tf.trainable_variables())
self.assertAllEqual(len(grads_and_vars_1), len(grads_and_vars_2))
for i in range(len(grads_and_vars_1)):
g1 = grads_and_vars_1[i][0]
g2 = grads_and_vars_2[i][0]
# Verify that gradients of loss1 are the negative of gradients of
# loss2.
self.assertAllEqual(tf.negative(g1).eval(), g2.eval())
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Domain Adaptation Loss Functions.
The following domain adaptation loss functions are defined:
- Maximum Mean Discrepancy (MMD).
Relevant paper:
Gretton, Arthur, et al.,
"A kernel two-sample test."
The Journal of Machine Learning Research, 2012
- Correlation Loss on a batch.
"""
from functools import partial
import tensorflow as tf
import grl_op_grads # pylint: disable=unused-import
import grl_op_shapes # pylint: disable=unused-import
import grl_ops
import utils
slim = tf.contrib.slim
################################################################################
# SIMILARITY LOSS
################################################################################
def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.
Maximum Mean Discrepancy (MMD) is a distance-measure between the samples of
the distributions of x and y. Here we use the kernel two sample estimate
using the empirical mean of the two distributions.
MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
= \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },
where K = <\phi(x), \phi(y)>,
is the desired kernel function, in this case a radial basis kernel.
Args:
x: a tensor of shape [num_samples, num_features]
y: a tensor of shape [num_samples, num_features]
kernel: a function which computes the kernel in MMD. Defaults to the
GaussianKernelMatrix.
Returns:
a scalar denoting the squared maximum mean discrepancy loss.
"""
with tf.name_scope('MaximumMeanDiscrepancy'):
# \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
cost = tf.reduce_mean(kernel(x, x))
cost += tf.reduce_mean(kernel(y, y))
cost -= 2 * tf.reduce_mean(kernel(x, y))
# We do not allow the loss to become negative.
cost = tf.where(cost > 0, cost, 0, name='value')
return cost
def mmd_loss(source_samples, target_samples, weight, scope=None):
"""Adds a similarity loss term, the MMD between two representations.
This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
different Gaussian kernels.
Args:
source_samples: a tensor of shape [num_samples, num_features].
target_samples: a tensor of shape [num_samples, num_features].
weight: the weight of the MMD loss.
scope: optional name scope for summary tags.
Returns:
a scalar tensor representing the MMD loss value.
"""
sigmas = [
1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
1e3, 1e4, 1e5, 1e6
]
gaussian_kernel = partial(
utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))
loss_value = maximum_mean_discrepancy(
source_samples, target_samples, kernel=gaussian_kernel)
loss_value = tf.maximum(1e-4, loss_value) * weight
assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
with tf.control_dependencies([assert_op]):
tag = 'MMD Loss'
if scope:
tag = scope + tag
tf.summary.scalar(tag, loss_value)
tf.losses.add_loss(loss_value)
return loss_value
def correlation_loss(source_samples, target_samples, weight, scope=None):
"""Adds a similarity loss term, the correlation between two representations.
Args:
source_samples: a tensor of shape [num_samples, num_features]
target_samples: a tensor of shape [num_samples, num_features]
weight: a scalar weight for the loss.
scope: optional name scope for summary tags.
Returns:
a scalar tensor representing the correlation loss value.
"""
with tf.name_scope('corr_loss'):
source_samples -= tf.reduce_mean(source_samples, 0)
target_samples -= tf.reduce_mean(target_samples, 0)
source_samples = tf.nn.l2_normalize(source_samples, 1)
target_samples = tf.nn.l2_normalize(target_samples, 1)
source_cov = tf.matmul(tf.transpose(source_samples), source_samples)
target_cov = tf.matmul(tf.transpose(target_samples), target_samples)
corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight
assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
with tf.control_dependencies([assert_op]):
tag = 'Correlation Loss'
if scope:
tag = scope + tag
tf.summary.scalar(tag, corr_loss)
tf.losses.add_loss(corr_loss)
return corr_loss
def dann_loss(source_samples, target_samples, weight, scope=None):
"""Adds the domain adversarial (DANN) loss.
Args:
source_samples: a tensor of shape [num_samples, num_features].
target_samples: a tensor of shape [num_samples, num_features].
weight: the weight of the loss.
scope: optional name scope for summary tags.
Returns:
a scalar tensor representing the correlation loss value.
"""
with tf.variable_scope('dann'):
batch_size = tf.shape(source_samples)[0]
samples = tf.concat(axis=0, values=[source_samples, target_samples])
samples = slim.flatten(samples)
domain_selection_mask = tf.concat(
axis=0, values=[tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))])
# Perform the gradient reversal and be careful with the shape.
grl = grl_ops.gradient_reversal(samples)
grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))
grl = slim.fully_connected(grl, 100, scope='fc1')
logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')
domain_predictions = tf.sigmoid(logits)
domain_loss = tf.losses.log_loss(
domain_selection_mask, domain_predictions, weights=weight)
domain_accuracy = utils.accuracy(
tf.round(domain_predictions), domain_selection_mask)
assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
with tf.control_dependencies([assert_op]):
tag_loss = 'losses/domain_loss'
tag_accuracy = 'losses/domain_accuracy'
if scope:
tag_loss = scope + tag_loss
tag_accuracy = scope + tag_accuracy
tf.summary.scalar(tag_loss, domain_loss)
tf.summary.scalar(tag_accuracy, domain_accuracy)
return domain_loss
################################################################################
# DIFFERENCE LOSS
################################################################################
def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
"""Adds the difference loss between the private and shared representations.
Args:
private_samples: a tensor of shape [num_samples, num_features].
shared_samples: a tensor of shape [num_samples, num_features].
weight: the weight of the incoherence loss.
name: the name of the tf summary.
"""
private_samples -= tf.reduce_mean(private_samples, 0)
shared_samples -= tf.reduce_mean(shared_samples, 0)
private_samples = tf.nn.l2_normalize(private_samples, 1)
shared_samples = tf.nn.l2_normalize(shared_samples, 1)
correlation_matrix = tf.matmul(
private_samples, shared_samples, transpose_a=True)
cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
cost = tf.where(cost > 0, cost, 0, name='value')
tf.summary.scalar('losses/Difference Loss {}'.format(name),
cost)
assert_op = tf.Assert(tf.is_finite(cost), [cost])
with tf.control_dependencies([assert_op]):
tf.losses.add_loss(cost)
################################################################################
# TASK LOSS
################################################################################
def log_quaternion_loss_batch(predictions, labels, params):
"""A helper function to compute the error between quaternions.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size [batch_size], denoting the error between the quaternions.
"""
use_logging = params['use_logging']
assertions = []
if use_logging:
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
1e-4)),
['The l2 norm of each prediction quaternion vector should be 1.']))
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
['The l2 norm of each label quaternion vector should be 1.']))
with tf.control_dependencies(assertions):
product = tf.multiply(predictions, labels)
internal_dot_products = tf.reduce_sum(product, [1])
if use_logging:
internal_dot_products = tf.Print(
internal_dot_products,
[internal_dot_products, tf.shape(internal_dot_products)],
'internal_dot_products:')
logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
return logcost
def log_quaternion_loss(predictions, labels, params):
"""A helper function to compute the mean error between batches of quaternions.
The caller is expected to add the loss to the graph.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size 1, denoting the mean error between batches of quaternions.
"""
use_logging = params['use_logging']
logcost = log_quaternion_loss_batch(predictions, labels, params)
logcost = tf.reduce_sum(logcost, [0])
batch_size = params['batch_size']
logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
if use_logging:
logcost = tf.Print(
logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
return logcost
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN losses."""
from functools import partial
import numpy as np
import tensorflow as tf
import losses
import utils
def MaximumMeanDiscrepancySlow(x, y, sigmas):
num_samples = x.get_shape().as_list()[0]
def AverageGaussianKernel(x, y, sigmas):
result = 0
for sigma in sigmas:
dist = tf.reduce_sum(tf.square(x - y))
result += tf.exp((-1.0 / (2.0 * sigma)) * dist)
return result / num_samples**2
total = 0
for i in range(num_samples):
for j in range(num_samples):
total += AverageGaussianKernel(x[i, :], x[j, :], sigmas)
total += AverageGaussianKernel(y[i, :], y[j, :], sigmas)
total += -2 * AverageGaussianKernel(x[i, :], y[j, :], sigmas)
return total
class LogQuaternionLossTest(tf.test.TestCase):
def test_log_quaternion_loss_batch(self):
with self.test_session():
predictions = tf.random_uniform((10, 4), seed=1)
predictions = tf.nn.l2_normalize(predictions, 1)
labels = tf.random_uniform((10, 4), seed=1)
labels = tf.nn.l2_normalize(labels, 1)
params = {'batch_size': 10, 'use_logging': False}
x = losses.log_quaternion_loss_batch(predictions, labels, params)
self.assertTrue(((10,) == tf.shape(x).eval()).all())
class MaximumMeanDiscrepancyTest(tf.test.TestCase):
def test_mmd_name(self):
with self.test_session():
x = tf.random_uniform((2, 3), seed=1)
kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
loss = losses.maximum_mean_discrepancy(x, x, kernel)
self.assertEquals(loss.op.name, 'MaximumMeanDiscrepancy/value')
def test_mmd_is_zero_when_inputs_are_same(self):
with self.test_session():
x = tf.random_uniform((2, 3), seed=1)
kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
self.assertEquals(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
def test_fast_mmd_is_similar_to_slow_mmd(self):
with self.test_session():
x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
y = tf.constant(np.random.rand(2, 3), tf.float32)
cost_old = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()
kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
cost_new = losses.maximum_mean_discrepancy(x, y, kernel).eval()
self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
def test_multiple_sigmas(self):
with self.test_session():
x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
y = tf.constant(np.random.rand(2, 3), tf.float32)
sigmas = tf.constant([2., 5., 10, 20, 30])
kernel = partial(utils.gaussian_kernel_matrix, sigmas=sigmas)
cost_old = MaximumMeanDiscrepancySlow(x, y, [2., 5., 10, 20, 30]).eval()
cost_new = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
def test_mmd_is_zero_when_distributions_are_same(self):
with self.test_session():
x = tf.random_uniform((1000, 10), seed=1)
y = tf.random_uniform((1000, 10), seed=3)
kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
loss = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
self.assertAlmostEqual(0, loss, delta=1e-4)
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains different architectures for the different DSN parts.
We define here the modules that can be used in the different parts of the DSN
model.
- shared encoder (dsn_cropped_linemod, dann_xxxx)
- private encoder (default_encoder)
- decoder (large_decoder, gtsrb_decoder, small_decoder)
"""
import tensorflow as tf
#from models.domain_adaptation.domain_separation
import utils
slim = tf.contrib.slim
def default_batch_norm_params(is_training=False):
"""Returns default batch normalization parameters for DSNs.
Args:
is_training: whether or not the model is training.
Returns:
a dictionary that maps batch norm parameter names (strings) to values.
"""
return {
# Decay for the moving averages.
'decay': 0.5,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
'is_training': is_training
}
################################################################################
# PRIVATE ENCODERS
################################################################################
def default_encoder(images, code_size, batch_norm_params=None,
weight_decay=0.0):
"""Encodes the given images to codes of the given size.
Args:
images: a tensor of size [batch_size, height, width, 1].
code_size: the number of hidden units in the code layer of the classifier.
batch_norm_params: a dictionary that maps batch norm parameter names to
values.
weight_decay: the value for the weight decay coefficient.
Returns:
end_points: the code of the input.
"""
end_points = {}
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
with slim.arg_scope([slim.conv2d], kernel_size=[5, 5], padding='SAME'):
net = slim.conv2d(images, 32, scope='conv1')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
net = slim.conv2d(net, 64, scope='conv2')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
net = slim.flatten(net)
end_points['flatten'] = net
net = slim.fully_connected(net, code_size, scope='fc1')
end_points['fc3'] = net
return end_points
################################################################################
# DECODERS
################################################################################
def large_decoder(codes,
height,
width,
channels,
batch_norm_params=None,
weight_decay=0.0):
"""Decodes the codes to a fixed output size.
Args:
codes: a tensor of size [batch_size, code_size].
height: the height of the output images.
width: the width of the output images.
channels: the number of the output channels.
batch_norm_params: a dictionary that maps batch norm parameter names to
values.
weight_decay: the value for the weight decay coefficient.
Returns:
recons: the reconstruction tensor of shape [batch_size, height, width, 3].
"""
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
net = slim.fully_connected(codes, 600, scope='fc1')
batch_size = net.get_shape().as_list()[0]
net = tf.reshape(net, [batch_size, 10, 10, 6])
net = slim.conv2d(net, 32, [5, 5], scope='conv1_1')
net = tf.image.resize_nearest_neighbor(net, (16, 16))
net = slim.conv2d(net, 32, [5, 5], scope='conv2_1')
net = tf.image.resize_nearest_neighbor(net, (32, 32))
net = slim.conv2d(net, 32, [5, 5], scope='conv3_2')
output_size = [height, width]
net = tf.image.resize_nearest_neighbor(net, output_size)
with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
net = slim.conv2d(net, channels, activation_fn=None, scope='conv4_1')
return net
def gtsrb_decoder(codes,
height,
width,
channels,
batch_norm_params=None,
weight_decay=0.0):
"""Decodes the codes to a fixed output size. This decoder is specific to GTSRB
Args:
codes: a tensor of size [batch_size, 100].
height: the height of the output images.
width: the width of the output images.
channels: the number of the output channels.
batch_norm_params: a dictionary that maps batch norm parameter names to
values.
weight_decay: the value for the weight decay coefficient.
Returns:
recons: the reconstruction tensor of shape [batch_size, height, width, 3].
Raises:
ValueError: When the input code size is not 100.
"""
batch_size, code_size = codes.get_shape().as_list()
if code_size != 100:
raise ValueError('The code size used as an input to the GTSRB decoder is '
'expected to be 100.')
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
net = codes
net = tf.reshape(net, [batch_size, 10, 10, 1])
net = slim.conv2d(net, 32, [3, 3], scope='conv1_1')
# First upsampling 20x20
net = tf.image.resize_nearest_neighbor(net, [20, 20])
net = slim.conv2d(net, 32, [3, 3], scope='conv2_1')
output_size = [height, width]
# Final upsampling 40 x 40
net = tf.image.resize_nearest_neighbor(net, output_size)
with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
net = slim.conv2d(net, 16, scope='conv3_1')
net = slim.conv2d(net, channels, activation_fn=None, scope='conv3_2')
return net
def small_decoder(codes,
height,
width,
channels,
batch_norm_params=None,
weight_decay=0.0):
"""Decodes the codes to a fixed output size.
Args:
codes: a tensor of size [batch_size, code_size].
height: the height of the output images.
width: the width of the output images.
channels: the number of the output channels.
batch_norm_params: a dictionary that maps batch norm parameter names to
values.
weight_decay: the value for the weight decay coefficient.
Returns:
recons: the reconstruction tensor of shape [batch_size, height, width, 3].
"""
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
net = slim.fully_connected(codes, 300, scope='fc1')
batch_size = net.get_shape().as_list()[0]
net = tf.reshape(net, [batch_size, 10, 10, 3])
net = slim.conv2d(net, 16, [3, 3], scope='conv1_1')
net = slim.conv2d(net, 16, [3, 3], scope='conv1_2')
output_size = [height, width]
net = tf.image.resize_nearest_neighbor(net, output_size)
with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
net = slim.conv2d(net, 16, scope='conv2_1')
net = slim.conv2d(net, channels, activation_fn=None, scope='conv2_2')
return net
################################################################################
# SHARED ENCODERS
################################################################################
def dann_mnist(images,
weight_decay=0.0,
prefix='model',
num_classes=10,
**kwargs):
"""Creates a convolution MNIST model.
Note that this model implements the architecture for MNIST proposed in:
Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
JMLR 2015
Args:
images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
weight_decay: the value for the weight decay coefficient.
prefix: name of the model to use when prefixing tags.
num_classes: the number of output classes to use.
**kwargs: Placeholder for keyword arguments used by other shared encoders.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
end_points = {}
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,):
with slim.arg_scope([slim.conv2d], padding='SAME'):
end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
end_points['pool1'] = slim.max_pool2d(
end_points['conv1'], [2, 2], 2, scope='pool1')
end_points['conv2'] = slim.conv2d(
end_points['pool1'], 48, [5, 5], scope='conv2')
end_points['pool2'] = slim.max_pool2d(
end_points['conv2'], [2, 2], 2, scope='pool2')
end_points['fc3'] = slim.fully_connected(
slim.flatten(end_points['pool2']), 100, scope='fc3')
end_points['fc4'] = slim.fully_connected(
slim.flatten(end_points['fc3']), 100, scope='fc4')
logits = slim.fully_connected(
end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, end_points
def dann_svhn(images,
weight_decay=0.0,
prefix='model',
num_classes=10,
**kwargs):
"""Creates the convolutional SVHN model.
Note that this model implements the architecture for MNIST proposed in:
Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
JMLR 2015
Args:
images: the SVHN digits, a tensor of size [batch_size, 32, 32, 3].
weight_decay: the value for the weight decay coefficient.
prefix: name of the model to use when prefixing tags.
num_classes: the number of output classes to use.
**kwargs: Placeholder for keyword arguments used by other shared encoders.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
end_points = {}
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,):
with slim.arg_scope([slim.conv2d], padding='SAME'):
end_points['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
end_points['pool1'] = slim.max_pool2d(
end_points['conv1'], [3, 3], 2, scope='pool1')
end_points['conv2'] = slim.conv2d(
end_points['pool1'], 64, [5, 5], scope='conv2')
end_points['pool2'] = slim.max_pool2d(
end_points['conv2'], [3, 3], 2, scope='pool2')
end_points['conv3'] = slim.conv2d(
end_points['pool2'], 128, [5, 5], scope='conv3')
end_points['fc3'] = slim.fully_connected(
slim.flatten(end_points['conv3']), 3072, scope='fc3')
end_points['fc4'] = slim.fully_connected(
slim.flatten(end_points['fc3']), 2048, scope='fc4')
logits = slim.fully_connected(
end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, end_points
def dann_gtsrb(images,
weight_decay=0.0,
prefix='model',
num_classes=43,
**kwargs):
"""Creates the convolutional GTSRB model.
Note that this model implements the architecture for MNIST proposed in:
Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
JMLR 2015
Args:
images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
weight_decay: the value for the weight decay coefficient.
prefix: name of the model to use when prefixing tags.
num_classes: the number of output classes to use.
**kwargs: Placeholder for keyword arguments used by other shared encoders.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
end_points = {}
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,):
with slim.arg_scope([slim.conv2d], padding='SAME'):
end_points['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
end_points['pool1'] = slim.max_pool2d(
end_points['conv1'], [2, 2], 2, scope='pool1')
end_points['conv2'] = slim.conv2d(
end_points['pool1'], 144, [3, 3], scope='conv2')
end_points['pool2'] = slim.max_pool2d(
end_points['conv2'], [2, 2], 2, scope='pool2')
end_points['conv3'] = slim.conv2d(
end_points['pool2'], 256, [5, 5], scope='conv3')
end_points['pool3'] = slim.max_pool2d(
end_points['conv3'], [2, 2], 2, scope='pool3')
end_points['fc3'] = slim.fully_connected(
slim.flatten(end_points['pool3']), 512, scope='fc3')
logits = slim.fully_connected(
end_points['fc3'], num_classes, activation_fn=None, scope='fc4')
return logits, end_points
def dsn_cropped_linemod(images,
weight_decay=0.0,
prefix='model',
num_classes=11,
batch_norm_params=None,
is_training=False):
"""Creates the convolutional pose estimation model for Cropped Linemod.
Args:
images: the Cropped Linemod samples, a tensor of size
[batch_size, 64, 64, 4].
weight_decay: the value for the weight decay coefficient.
prefix: name of the model to use when prefixing tags.
num_classes: the number of output classes to use.
batch_norm_params: a dictionary that maps batch norm parameter names to
values.
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
end_points = {}
tf.summary.image('{}/input_images'.format(prefix), images)
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm if batch_norm_params else None,
normalizer_params=batch_norm_params):
with slim.arg_scope([slim.conv2d], padding='SAME'):
end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
end_points['pool1'] = slim.max_pool2d(
end_points['conv1'], [2, 2], 2, scope='pool1')
end_points['conv2'] = slim.conv2d(
end_points['pool1'], 64, [5, 5], scope='conv2')
end_points['pool2'] = slim.max_pool2d(
end_points['conv2'], [2, 2], 2, scope='pool2')
net = slim.flatten(end_points['pool2'])
end_points['fc3'] = slim.fully_connected(net, 128, scope='fc3')
net = slim.dropout(
end_points['fc3'], 0.5, is_training=is_training, scope='dropout')
with tf.variable_scope('quaternion_prediction'):
predicted_quaternion = slim.fully_connected(
net, 4, activation_fn=tf.nn.tanh)
predicted_quaternion = tf.nn.l2_normalize(predicted_quaternion, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc4')
end_points['quaternion_pred'] = predicted_quaternion
return logits, end_points
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN components."""
import numpy as np
import tensorflow as tf
#from models.domain_adaptation.domain_separation
import models
class SharedEncodersTest(tf.test.TestCase):
def _testSharedEncoder(self,
input_shape=[5, 28, 28, 1],
model=models.dann_mnist,
is_training=True):
images = tf.to_float(np.random.rand(*input_shape))
with self.test_session() as sess:
logits, _ = model(images)
sess.run(tf.global_variables_initializer())
logits_np = sess.run(logits)
return logits_np
def testBuildGRLMnistModel(self):
logits = self._testSharedEncoder(model=getattr(models,
'dann_mnist'))
self.assertEqual(logits.shape, (5, 10))
self.assertTrue(np.any(logits))
def testBuildGRLSvhnModel(self):
logits = self._testSharedEncoder(model=getattr(models,
'dann_svhn'))
self.assertEqual(logits.shape, (5, 10))
self.assertTrue(np.any(logits))
def testBuildGRLGtsrbModel(self):
logits = self._testSharedEncoder([5, 40, 40, 3],
getattr(models, 'dann_gtsrb'))
self.assertEqual(logits.shape, (5, 43))
self.assertTrue(np.any(logits))
def testBuildPoseModel(self):
logits = self._testSharedEncoder([5, 64, 64, 4],
getattr(models, 'dsn_cropped_linemod'))
self.assertEqual(logits.shape, (5, 11))
self.assertTrue(np.any(logits))
def testBuildPoseModelWithBatchNorm(self):
images = tf.to_float(np.random.rand(10, 64, 64, 4))
with self.test_session() as sess:
logits, _ = getattr(models, 'dsn_cropped_linemod')(
images, batch_norm_params=models.default_batch_norm_params(True))
sess.run(tf.global_variables_initializer())
logits_np = sess.run(logits)
self.assertEqual(logits_np.shape, (10, 11))
self.assertTrue(np.any(logits_np))
class EncoderTest(tf.test.TestCase):
def _testEncoder(self, batch_norm_params=None, channels=1):
images = tf.to_float(np.random.rand(10, 28, 28, channels))
with self.test_session() as sess:
end_points = models.default_encoder(
images, 128, batch_norm_params=batch_norm_params)
sess.run(tf.global_variables_initializer())
private_code = sess.run(end_points['fc3'])
self.assertEqual(private_code.shape, (10, 128))
self.assertTrue(np.any(private_code))
self.assertTrue(np.all(np.isfinite(private_code)))
def testEncoder(self):
self._testEncoder()
def testEncoderMultiChannel(self):
self._testEncoder(None, 4)
def testEncoderIsTrainingBatchNorm(self):
self._testEncoder(models.default_batch_norm_params(True))
def testEncoderBatchNorm(self):
self._testEncoder(models.default_batch_norm_params(False))
class DecoderTest(tf.test.TestCase):
def _testDecoder(self,
height=64,
width=64,
channels=4,
batch_norm_params=None,
decoder=models.small_decoder):
codes = tf.to_float(np.random.rand(32, 100))
with self.test_session() as sess:
output = decoder(
codes,
height=height,
width=width,
channels=channels,
batch_norm_params=batch_norm_params)
sess.run(tf.global_variables_initializer())
output_np = sess.run(output)
self.assertEqual(output_np.shape, (32, height, width, channels))
self.assertTrue(np.any(output_np))
self.assertTrue(np.all(np.isfinite(output_np)))
def testSmallDecoder(self):
self._testDecoder(28, 28, 4, None, getattr(models, 'small_decoder'))
def testSmallDecoderThreeChannels(self):
self._testDecoder(28, 28, 3)
def testSmallDecoderBatchNorm(self):
self._testDecoder(28, 28, 4, models.default_batch_norm_params(False))
def testSmallDecoderIsTrainingBatchNorm(self):
self._testDecoder(28, 28, 4, models.default_batch_norm_params(True))
def testLargeDecoder(self):
self._testDecoder(32, 32, 4, None, getattr(models, 'large_decoder'))
def testLargeDecoderThreeChannels(self):
self._testDecoder(32, 32, 3, None, getattr(models, 'large_decoder'))
def testLargeDecoderBatchNorm(self):
self._testDecoder(32, 32, 4,
models.default_batch_norm_params(False),
getattr(models, 'large_decoder'))
def testLargeDecoderIsTrainingBatchNorm(self):
self._testDecoder(32, 32, 4,
models.default_batch_norm_params(True),
getattr(models, 'large_decoder'))
def testGtsrbDecoder(self):
self._testDecoder(40, 40, 3, None, getattr(models, 'large_decoder'))
def testGtsrbDecoderBatchNorm(self):
self._testDecoder(40, 40, 4,
models.default_batch_norm_params(False),
getattr(models, 'gtsrb_decoder'))
def testGtsrbDecoderIsTrainingBatchNorm(self):
self._testDecoder(40, 40, 4,
models.default_batch_norm_params(True),
getattr(models, 'gtsrb_decoder'))
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Auxiliary functions for domain adaptation related losses.
"""
import math
import tensorflow as tf
def create_summaries(end_points, prefix='', max_images=3, use_op_name=False):
"""Creates a tf summary per endpoint.
If the endpoint is a 4 dimensional tensor it displays it as an image
otherwise if it is a two dimensional one it creates a histogram summary.
Args:
end_points: a dictionary of name, tf tensor pairs.
prefix: an optional string to prefix the summary with.
max_images: the maximum number of images to display per summary.
use_op_name: Use the op name as opposed to the shorter end_points key.
"""
for layer_name in end_points:
if use_op_name:
name = end_points[layer_name].op.name
else:
name = layer_name
if len(end_points[layer_name].get_shape().as_list()) == 4:
# if it's an actual image do not attempt to reshape it
if end_points[layer_name].get_shape().as_list()[-1] == 1 or end_points[
layer_name].get_shape().as_list()[-1] == 3:
visualization_image = end_points[layer_name]
else:
visualization_image = reshape_feature_maps(end_points[layer_name])
tf.summary.image(
'{}/{}'.format(prefix, name),
visualization_image,
max_outputs=max_images)
elif len(end_points[layer_name].get_shape().as_list()) == 3:
images = tf.expand_dims(end_points[layer_name], 3)
tf.summary.image(
'{}/{}'.format(prefix, name),
images,
max_outputs=max_images)
elif len(end_points[layer_name].get_shape().as_list()) == 2:
tf.summary.histogram('{}/{}'.format(prefix, name), end_points[layer_name])
def reshape_feature_maps(features_tensor):
"""Reshape activations for tf.summary.image visualization.
Arguments:
features_tensor: a tensor of activations with a square number of feature
maps, eg 4, 9, 16, etc.
Returns:
A composite image with all the feature maps that can be passed as an
argument to tf.summary.image.
"""
assert len(features_tensor.get_shape().as_list()) == 4
num_filters = features_tensor.get_shape().as_list()[-1]
assert num_filters > 0
num_filters_sqrt = math.sqrt(num_filters)
assert num_filters_sqrt.is_integer(
), 'Number of filters should be a square number but got {}'.format(
num_filters)
num_filters_sqrt = int(num_filters_sqrt)
conv_summary = tf.unstack(features_tensor, axis=3)
conv_one_row = tf.concat(axis=2, values=conv_summary[0:num_filters_sqrt])
ind = 1
conv_final = conv_one_row
for ind in range(1, num_filters_sqrt):
conv_one_row = tf.concat(axis=2,
values=conv_summary[
ind * num_filters_sqrt + 0:ind * num_filters_sqrt + num_filters_sqrt])
conv_final = tf.concat(
axis=1, values=[tf.squeeze(conv_final), tf.squeeze(conv_one_row)])
conv_final = tf.expand_dims(conv_final, -1)
return conv_final
def accuracy(predictions, labels):
"""Calculates the classificaton accuracy.
Args:
predictions: the predicted values, a tensor whose size matches 'labels'.
labels: the ground truth values, a tensor of any size.
Returns:
a tensor whose value on evaluation returns the total accuracy.
"""
return tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))
def compute_upsample_values(input_tensor, upsample_height, upsample_width):
"""Compute values for an upsampling op (ops.BatchCropAndResize).
Args:
input_tensor: image tensor with shape [batch, height, width, in_channels]
upsample_height: integer
upsample_width: integer
Returns:
grid_centers: tensor with shape [batch, 1]
crop_sizes: tensor with shape [batch, 1]
output_height: integer
output_width: integer
"""
batch, input_height, input_width, _ = input_tensor.shape
height_half = input_height / 2.
width_half = input_width / 2.
grid_centers = tf.constant(batch * [[height_half, width_half]])
crop_sizes = tf.constant(batch * [[input_height, input_width]])
output_height = input_height * upsample_height
output_width = input_width * upsample_width
return grid_centers, tf.to_float(crop_sizes), output_height, output_width
def compute_pairwise_distances(x, y):
"""Computes the squared pairwise Euclidean distances between x and y.
Args:
x: a tensor of shape [num_x_samples, num_features]
y: a tensor of shape [num_y_samples, num_features]
Returns:
a distance matrix of dimensions [num_x_samples, num_y_samples].
Raises:
ValueError: if the inputs do no matched the specified dimensions.
"""
if not len(x.get_shape()) == len(y.get_shape()) == 2:
raise ValueError('Both inputs should be matrices.')
if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
raise ValueError('The number of features should be the same.')
norm = lambda x: tf.reduce_sum(tf.square(x), 1)
# By making the `inner' dimensions of the two matrices equal to 1 using
# broadcasting then we are essentially substracting every pair of rows
# of x and y.
# x will be num_samples x num_features x 1,
# and y will be 1 x num_features x num_samples (after broadcasting).
# After the substraction we will get a
# num_x_samples x num_features x num_y_samples matrix.
# The resulting dist will be of shape num_y_samples x num_x_samples.
# and thus we need to transpose it again.
return tf.transpose(norm(tf.expand_dims(x, 2) - tf.transpose(y)))
def gaussian_kernel_matrix(x, y, sigmas):
r"""Computes a Guassian Radial Basis Kernel between the samples of x and y.
We create a sum of multiple gaussian kernels each having a width sigma_i.
Args:
x: a tensor of shape [num_samples, num_features]
y: a tensor of shape [num_samples, num_features]
sigmas: a tensor of floats which denote the widths of each of the
gaussians in the kernel.
Returns:
A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
"""
beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))
dist = compute_pairwise_distances(x, y)
s = tf.matmul(beta, tf.reshape(dist, (1, -1)))
return tf.reshape(tf.reduce_sum(tf.exp(-s), 0), tf.shape(dist))
# Description:
# Contains code for domain-adaptation style transfer.
package(
default_visibility = [
":internal",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//domain_adaptation/...",
],
)
py_library(
name = "pixelda_preprocess",
srcs = ["pixelda_preprocess.py"],
deps = [
],
)
py_test(
name = "pixelda_preprocess_test",
srcs = ["pixelda_preprocess_test.py"],
deps = [
":pixelda_preprocess",
],
)
py_library(
name = "pixelda_model",
srcs = [
"pixelda_model.py",
"pixelda_task_towers.py",
"hparams.py",
],
deps = [
],
)
py_library(
name = "pixelda_utils",
srcs = ["pixelda_utils.py"],
deps = [
],
)
py_library(
name = "pixelda_losses",
srcs = ["pixelda_losses.py"],
deps = [
],
)
py_binary(
name = "pixelda_train",
srcs = ["pixelda_train.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
py_binary(
name = "pixelda_eval",
srcs = ["pixelda_eval.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
licenses(["notice"]) # Apache 2.0
py_binary(
name = "baseline_train",
srcs = ["baseline_train.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
py_binary(
name = "baseline_eval",
srcs = ["baseline_eval.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
The best baselines are obtainable via the following configuration:
## MNIST => MNIST_M
Accuracy:
MNIST-Train: 99.9
MNIST_M-Train: 63.9
MNIST_M-Valid: 63.9
MNIST_M-Test: 63.6
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 105,000
## MNIST => USPS
Accuracy:
MNIST-Train: 100.0
USPS-Train: 82.8
USPS-Valid: 82.8
USPS-Test: 78.9
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 22,000
## MNIST_M => MNIST
Accuracy:
MNIST_M-Train: 100
MNIST-Train: 98.5
MNIST-Valid: 98.5
MNIST-Test: 98.1
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 604,400
## MNIST_M => MNIST_M
Accuracy:
MNIST_M-Train: 100.0
MNIST_M-Valid: 96.6
MNIST_M-Test: 96.4
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 139,400
## USPS => USPS
Accuracy:
USPS-Train: 100.0
USPS-Valid: 100.0
USPS-Test: 96.5
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 67,000
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evals the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_string(
'checkpoint_dir', None, 'The location of the checkpoint files.')
flags.DEFINE_string(
'eval_dir', None, 'The directory where evaluation logs are written.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = 0.0
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
if not tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.MakeDirs(FLAGS.eval_dir)
with tf.Graph().as_default():
dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.split_name,
FLAGS.dataset_dir)
num_classes = dataset.num_classes
num_samples = dataset.num_samples
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=False)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
#####################
# Define the losses #
#####################
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
predictions = tf.reshape(tf.argmax(logits, 1), shape=[-1])
class_labels = tf.argmax(labels['classes'], 1)
metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
'Mean_Loss':
tf.contrib.metrics.streaming_mean(total_loss),
'Accuracy':
tf.contrib.metrics.streaming_accuracy(predictions,
tf.reshape(
class_labels,
shape=[-1])),
'Recall_at_5':
tf.contrib.metrics.streaming_recall_at_k(logits, class_labels, 5),
})
tf.summary.histogram('outputs/Predictions', predictions)
tf.summary.histogram('outputs/Ground_Truth', class_labels)
for name, value in metrics_to_values.iteritems():
tf.summary.scalar(name, value)
num_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_dir,
num_evals=num_batches,
eval_op=metrics_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('task', 0, 'The task ID.')
flags.DEFINE_integer('num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then '
'the parameters are handled locally by the worker.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_float('learning_rate', 0.001, 'The initial learning rate.')
flags.DEFINE_integer(
'learning_rate_decay_steps', 20000,
'The frequency, in steps, at which the learning rate is decayed.')
flags.DEFINE_float('learning_rate_decay_factor',
0.95,
'The factor with which the learning rate is decayed.')
flags.DEFINE_float('adam_beta1', 0.5, 'The beta1 value for the AdamOptimizer')
flags.DEFINE_float('weight_decay', 1e-5,
'The L2 coefficient on the model weights.')
flags.DEFINE_string(
'logdir', None, 'The location of the logs and checkpoints.')
flags.DEFINE_integer('save_interval_secs', 600,
'How often, in seconds, we save the model to disk.')
flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_float(
'moving_average_decay', 0.9999,
'The amount of decay to use for moving averages.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = FLAGS.weight_decay
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
with tf.Graph().as_default():
with tf.device(
tf.train.replica_device_setter(FLAGS.num_ps_tasks, merge_devices=True)):
dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
FLAGS.split_name, FLAGS.dataset_dir)
num_classes = dataset.num_classes
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=True)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# preprocess_fn=preprocess_fn)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
# Define the losses
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
tf.summary.scalar('losses/Total_Loss', total_loss)
# Setup the moving averages
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, slim.get_or_create_global_step())
tf.add_to_collection(
tf.GraphKeys.UPDATE_OPS,
variable_averages.apply(moving_average_variables))
# Specify the optimization scheme:
learning_rate = tf.train.exponential_decay(
FLAGS.learning_rate,
slim.get_or_create_global_step(),
FLAGS.learning_rate_decay_steps,
FLAGS.learning_rate_decay_factor,
staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.adam_beta1)
train_op = slim.learning.create_train_op(total_loss, optimizer)
slim.learning.train(
train_op,
FLAGS.logdir,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define model HParams."""
import tensorflow as tf
def create_hparams(hparam_string=None):
"""Create model hyperparameters. Parse nondefault from given string."""
hparams = tf.contrib.training.HParams(
# The name of the architecture to use.
arch='resnet',
lrelu_leakiness=0.2,
batch_norm_decay=0.9,
weight_decay=1e-5,
normal_init_std=0.02,
generator_kernel_size=3,
discriminator_kernel_size=3,
# Stop training after this many examples are processed
# If none, train indefinitely
num_training_examples=0,
# Apply data augmentation to datasets
# Applies only in training job
augment_source_images=False,
augment_target_images=False,
# Discriminator
# Number of filters in first layer of discriminator
num_discriminator_filters=64,
discriminator_conv_block_size=1, # How many convs to have at each size
discriminator_filter_factor=2.0, # Multiply # filters by this each layer
# Add gaussian noise with this stddev to every hidden layer of D
discriminator_noise_stddev=0.2, # lmetz: Start seeing results at >= 0.1
# If true, add this gaussian noise to input images to D as well
discriminator_image_noise=False,
discriminator_first_stride=1, # Stride in first conv of discriminator
discriminator_do_pooling=False, # If true, replace stride 2 with avg pool
discriminator_dropout_keep_prob=0.9, # keep probability for dropout
# DCGAN Generator
# Number of filters in generator decoder last layer (repeatedly halved
# from 1st layer)
num_decoder_filters=64,
# Number of filters in generator encoder 1st layer (repeatedly doubled
# after 1st layer)
num_encoder_filters=64,
# This is the shape to which the noise vector is projected (if we're
# transferring from noise).
# Write this way instead of [4, 4, 64] for hparam search flexibility
projection_shape_size=4,
projection_shape_channels=64,
# Indicates the method by which we enlarge the spatial representation
# of an image. Possible values include:
# - resize_conv: Performs a nearest neighbor resize followed by a conv.
# - conv2d_transpose: Performs a conv2d_transpose.
upsample_method='resize_conv',
# Visualization
summary_steps=500, # Output image summary every N steps
###################################
# Task Classifier Hyperparameters #
###################################
# Which task-specific prediction tower to use. Possible choices are:
# none: No task tower.
# doubling_pose_estimator: classifier + quaternion regressor.
# [conv + pool]* + FC
# Classifiers used in DSN paper:
# gtsrb: Classifier used for GTSRB
# svhn: Classifier used for SVHN
# mnist: Classifier used for MNIST
# pose_mini: Classifier + regressor used for pose_mini
task_tower='doubling_pose_estimator',
weight_decay_task_classifier=1e-5,
source_task_loss_weight=1.0,
transferred_task_loss_weight=1.0,
# Number of private layers in doubling_pose_estimator task tower
num_private_layers=2,
# The weight for the log quaternion loss we use for source and transferred
# samples of the cropped_linemod dataset.
# In the DSN work, 1/8 of the classifier weight worked well for our log
# quaternion loss
source_pose_weight=0.125 * 2.0,
transferred_pose_weight=0.125 * 1.0,
# If set to True, the style transfer network also attempts to change its
# weights to maximize the performance of the task tower. If set to False,
# then the style transfer network only attempts to change its weights to
# make the transferred images more likely according to the domain
# classifier.
task_tower_in_g_step=True,
task_loss_in_g_weight=1.0, # Weight of task loss in G
#########################################
# 'simple` generator arch model hparams #
#########################################
simple_num_conv_layers=1,
simple_conv_filters=8,
#########################
# Resnet Hyperparameters#
#########################
resnet_blocks=6, # Number of resnet blocks
resnet_filters=64, # Number of filters per conv in resnet blocks
# If true, add original input back to result of convolutions inside the
# resnet arch. If false, it turns into a simple stack of conv/relu/BN
# layers.
resnet_residuals=True,
#######################################
# The residual / interpretable model. #
#######################################
res_int_blocks=2, # The number of residual blocks.
res_int_convs=2, # The number of conv calls inside each block.
res_int_filters=64, # The number of filters used by each convolution.
####################
# Latent variables #
####################
# if true, then generate random noise and project to input for generator
noise_channel=True,
# The number of dimensions in the input noise vector.
noise_dims=10,
# If true, then one hot encode source image class and project as an
# additional channel for the input to generator. This gives the generator
# access to the class, which may help generation performance.
condition_on_source_class=False,
########################
# Loss Hyperparameters #
########################
domain_loss_weight=1.0,
style_transfer_loss_weight=1.0,
########################################################################
# Encourages the transferred images to be similar to the source images #
# using a configurable metric. #
########################################################################
# The weight of the loss function encouraging the source and transferred
# images to be similar. If set to 0, then the loss function is not used.
transferred_similarity_loss_weight=0.0,
# The type of loss used to encourage transferred and source image
# similarity. Valid values include:
# mpse: Mean Pairwise Squared Error
# mse: Mean Squared Error
# hinged_mse: Computes the mean squared error using squared differences
# greater than hparams.transferred_similarity_max_diff
# hinged_mae: Computes the mean absolute error using absolute
# differences greater than hparams.transferred_similarity_max_diff.
transferred_similarity_loss='mpse',
# The maximum allowable difference between the source and target images.
# This value is used, in effect, to produce a hinge loss. Note that the
# range of values should be between 0 and 1.
transferred_similarity_max_diff=0.4,
################################
# Optimization Hyperparameters #
################################
learning_rate=0.001,
batch_size=32,
lr_decay_steps=20000,
lr_decay_rate=0.95,
# Recomendation from the DCGAN paper:
adam_beta1=0.5,
clip_gradient_norm=5.0,
# The number of times we run the discriminator train_op in a row.
discriminator_steps=1,
# The number of times we run the generator train_op in a row.
generator_steps=1)
if hparam_string:
tf.logging.info('Parsing command line hparams: %s', hparam_string)
hparams.parse(hparam_string)
tf.logging.info('Final parsed hparams: %s', hparams.values())
return hparams
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evaluates the PIXELDA model.
-- Compiles the model for CPU.
$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Compile the model for GPU.
$ bazel build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Runs the training.
$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \
--source_dataset=mnist \
--target_dataset=mnist_m \
--dataset_dir=/tmp/datasets/ \
--alsologtostderr
-- Visualize the results.
$ bash learning/brain/tensorboard/tensorboard.sh \
--port 2222 --logdir=/tmp/pixelda/
"""
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/pixelda/',
'Directory where the results are saved to.')
flags.DEFINE_integer('eval_interval_secs', 60,
'The frequency, in seconds, with which evaluation is run.')
flags.DEFINE_string('target_split_name', 'test',
'The name of the train/test split.')
flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.'
' Defaults to train.')
flags.DEFINE_string('source_dataset', 'mnist',
'The name of the source dataset.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string(
'dataset_dir',
'', # None,
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def run_eval(run_dir, checkpoint_dir, hparams):
"""Runs the eval loop.
Args:
run_dir: The directory where eval specific logs are placed
checkpoint_dir: The directory where the checkpoints are stored
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for checkpoint_path in slim.evaluation.checkpoints_iterator(
checkpoint_dir, FLAGS.eval_interval_secs):
with tf.Graph().as_default():
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name=FLAGS.target_split_name,
dataset_dir=FLAGS.dataset_dir)
target_images, target_labels = dataset_factory.provide_batch(
FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
target_labels['class'] = tf.argmax(target_labels['classes'], 1)
del target_labels['classes']
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name=FLAGS.source_split_name,
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Input and output datasets must have same number of classes')
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=False,
num_classes=num_target_classes)
#######################
# Metrics & Summaries #
#######################
names_to_values, names_to_updates = create_metrics(end_points,
source_labels,
target_labels, hparams)
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summarize_images(target_images, 'Target')
for name, value in names_to_values.iteritems():
tf.summary.scalar(name, value)
# Use the entire split by default
num_examples = target_dataset.num_samples
num_batches = math.ceil(num_examples / float(hparams.batch_size))
global_step = slim.get_or_create_global_step()
result = slim.evaluation.evaluate_once(
master=FLAGS.master,
checkpoint_path=checkpoint_path,
logdir=run_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
final_op=names_to_values)
def to_degrees(log_quaternion_loss):
"""Converts a log quaternion distance to an angle.
Args:
log_quaternion_loss: The log quaternion distance between two
unit quaternions (or a batch of pairs of quaternions).
Returns:
The angle in degrees of the implied angle-axis representation.
"""
return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
def create_metrics(end_points, source_labels, target_labels, hparams):
"""Create metrics for the model.
Args:
end_points: A dictionary of end point name to tensor
source_labels: Labels for source images. batch_size x 1
target_labels: Labels for target images. batch_size x 1
hparams: The hyperparameters struct.
Returns:
Tuple of (names_to_values, names_to_updates), dictionaries that map a metric
name to its value and update op, respectively
"""
###########################################
# Evaluate the Domain Prediction Accuracy #
###########################################
batch_size = hparams.batch_size
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
('eval/Domain_Accuracy-Transferred'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points[
'transferred_domain_logits']))),
tf.zeros(batch_size, dtype=tf.int32)),
('eval/Domain_Accuracy-Target'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points['target_domain_logits']))),
tf.ones(batch_size, dtype=tf.int32))
})
################################
# Evaluate the task classifier #
################################
if 'source_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['source_task_logits'], 1),
source_labels['class'])
if 'transferred_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['transferred_task_logits'], 1),
source_labels['class'])
if 'target_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['target_task_logits'], 1),
target_labels['class'])
##########################################################################
# Pose data-specific losses.
##########################################################################
if 'quaternion' in source_labels.keys():
params = {}
params['use_logging'] = False
params['batch_size'] = batch_size
angle_loss_source = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'source_quaternion'], source_labels['quaternion'], params))
angle_loss_transferred = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'transferred_quaternion'], source_labels['quaternion'], params))
angle_loss_target = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'target_quaternion'], target_labels['quaternion'], params))
metric_name = 'eval/Angle_Loss-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_source)
metric_name = 'eval/Angle_Loss-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_transferred)
metric_name = 'eval/Angle_Loss-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_target)
return names_to_values, names_to_updates
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_eval(
run_dir=FLAGS.eval_dir,
checkpoint_dir=FLAGS.checkpoint_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
def add_domain_classifier_losses(end_points, hparams):
"""Adds losses related to the domain-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
hparams: The hyperparameters struct.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
if hparams.domain_loss_weight == 0:
tf.logging.info(
'Domain classifier loss weight is 0, so not creating losses.')
return 0
# The domain prediction loss is minimized with respect to the domain
# classifier features only. Its aim is to predict the domain of the images.
# Note: 1 = 'real image' label, 0 = 'fake image' label
transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
logits=end_points['transferred_domain_logits'])
tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)
target_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
logits=end_points['target_domain_logits'])
tf.summary.scalar('Domain_loss_target', target_domain_loss)
# Compute the total domain loss:
total_domain_loss = transferred_domain_loss + target_domain_loss
total_domain_loss *= hparams.domain_loss_weight
tf.summary.scalar('Domain_loss_total', total_domain_loss)
return total_domain_loss
def log_quaternion_loss_batch(predictions, labels, params):
"""A helper function to compute the error between quaternions.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size [batch_size], denoting the error between the quaternions.
"""
use_logging = params['use_logging']
assertions = []
if use_logging:
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
1e-4)),
['The l2 norm of each prediction quaternion vector should be 1.']))
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
['The l2 norm of each label quaternion vector should be 1.']))
with tf.control_dependencies(assertions):
product = tf.multiply(predictions, labels)
internal_dot_products = tf.reduce_sum(product, [1])
if use_logging:
internal_dot_products = tf.Print(internal_dot_products, [
internal_dot_products,
tf.shape(internal_dot_products)
], 'internal_dot_products:')
logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
return logcost
def log_quaternion_loss(predictions, labels, params):
"""A helper function to compute the mean error between batches of quaternions.
The caller is expected to add the loss to the graph.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size 1, denoting the mean error between batches of quaternions.
"""
use_logging = params['use_logging']
logcost = log_quaternion_loss_batch(predictions, labels, params)
logcost = tf.reduce_sum(logcost, [0])
batch_size = params['batch_size']
logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
if use_logging:
logcost = tf.Print(
logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
return logcost
def _quaternion_loss(labels, predictions, weight, batch_size, domain,
add_summaries):
"""Creates a Quaternion Loss.
Args:
labels: The true quaternions.
predictions: The predicted quaternions.
weight: A scalar weight.
batch_size: The size of the batches.
domain: The name of the domain from which the labels were taken.
add_summaries: Whether or not to add summaries for the losses.
Returns:
A `Tensor` representing the loss.
"""
assert domain in ['Source', 'Transferred']
params = {'use_logging': False, 'batch_size': batch_size}
loss = weight * log_quaternion_loss(labels, predictions, params)
if add_summaries:
assert_op = tf.Assert(tf.is_finite(loss), [loss])
with tf.control_dependencies([assert_op]):
tf.summary.histogram(
'Log_Quaternion_Loss_%s' % domain, loss, collections='losses')
tf.summary.scalar(
'Task_Quaternion_Loss_%s' % domain, loss, collections='losses')
return loss
def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
add_summaries=False):
"""Adds losses related to the task-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
add_summaries: Whether or not to add the summaries.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
# TODO(ddohan): Make sure the l2 regularization is added to the loss
one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
total_loss = 0
if 'source_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['source_task_logits'],
weights=hparams.source_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Source', loss)
total_loss += loss
if 'transferred_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['transferred_task_logits'],
weights=hparams.transferred_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
total_loss += loss
#########################
# Pose specific losses. #
#########################
if 'quaternion' in source_labels:
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['source_quaternion'],
hparams.source_pose_weight,
hparams.batch_size,
'Source',
add_summaries)
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['transferred_quaternion'],
hparams.transferred_pose_weight,
hparams.batch_size,
'Transferred',
add_summaries)
if add_summaries:
tf.summary.scalar('Task_Loss_Total', total_loss)
return total_loss
def _transferred_similarity_loss(reconstructions,
source_images,
weight=1.0,
method='mse',
max_diff=0.4,
name='similarity'):
"""Computes a loss encouraging similarity between source and transferred.
Args:
reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
source_images: A `Tensor` of shape [batch_size, height, width, channels].
weight: Multiple similarity loss by this weight before returning
method: One of:
mpse = Mean Pairwise Squared Error
mse = Mean Squared Error
hinged_mse = Computes the mean squared error using squared differences
greater than hparams.transferred_similarity_max_diff
hinged_mae = Computes the mean absolute error using absolute
differences greater than hparams.transferred_similarity_max_diff.
max_diff: Maximum unpenalized difference for hinged losses
name: Identifying name to use for creating summaries
Returns:
A `Tensor` representing the transferred similarity loss.
Raises:
ValueError: if `method` is not recognized.
"""
if weight == 0:
return 0
source_channels = source_images.shape.as_list()[-1]
reconstruction_channels = reconstructions.shape.as_list()[-1]
# Convert grayscale source to RGB if target is RGB
if source_channels == 1 and reconstruction_channels != 1:
source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
if reconstruction_channels == 1 and source_channels != 1:
reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
if method == 'mpse':
reconstruction_similarity_loss_fn = (
tf.contrib.losses.mean_pairwise_squared_error)
elif method == 'masked_mpse':
def masked_mpse(predictions, labels, weight):
"""Masked mpse assuming we have a depth to create a mask from."""
assert labels.shape.as_list()[-1] == 4
mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
mask = tf.tile(mask, [1, 1, 1, 4])
predictions *= mask
labels *= mask
tf.image_summary('masked_pred', predictions)
tf.image_summary('masked_label', labels)
return tf.contrib.losses.mean_pairwise_squared_error(
predictions, labels, weight)
reconstruction_similarity_loss_fn = masked_mpse
elif method == 'mse':
reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
elif method == 'hinged_mse':
def hinged_mse(predictions, labels, weight):
diffs = tf.square(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mse
elif method == 'hinged_mae':
def hinged_mae(predictions, labels, weight):
diffs = tf.abs(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mae
else:
raise ValueError('Unknown reconstruction loss %s' % method)
reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
reconstructions, source_images, weight)
name = '%s_Similarity_(%s)' % (name, method)
tf.summary.scalar(name, reconstruction_similarity_loss)
return reconstruction_similarity_loss
def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
"""Configures the loss function which runs during the g-step.
Args:
source_images: A `Tensor` of shape [batch_size, height, width, channels].
source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
are 'class' and 'quaternion'.
end_points: A map of the network end points.
hparams: The hyperparameters struct.
num_classes: Number of classes for classifier loss
Returns:
A `Tensor` representing a loss function.
Raises:
ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
hparams.transferred_similarity_loss is invalid.
"""
generator_loss = 0
################################################################
# Adds a loss which encourages the discriminator probabilities #
# to be high (near one).
################################################################
# As per the GAN paper, maximize the log probs, instead of minimizing
# log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
# the same thing.
style_transfer_loss = tf.losses.sigmoid_cross_entropy(
logits=end_points['transferred_domain_logits'],
multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
weights=hparams.style_transfer_loss_weight)
tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
generator_loss += style_transfer_loss
# Optimizes the style transfer network to produce transferred images similar
# to the source images.
generator_loss += _transferred_similarity_loss(
end_points['transferred_images'],
source_images,
weight=hparams.transferred_similarity_loss_weight,
method=hparams.transferred_similarity_loss,
name='transferred_similarity')
# Optimizes the style transfer network to maximize classification accuracy.
if source_labels is not None and hparams.task_tower_in_g_step:
generator_loss += _add_task_specific_losses(
end_points, source_labels, num_classes,
hparams) * hparams.task_loss_in_g_weight
return generator_loss
def d_step_loss(end_points, source_labels, num_classes, hparams):
"""Configures the losses during the D-Step.
Note that during the D-step, the model optimizes both the domain (binary)
classifier and the task classifier.
Args:
end_points: A map of the network end points.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
Returns:
A `Tensor` representing the value of the D-step loss.
"""
domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)
task_classifier_loss = 0
if source_labels is not None:
task_classifier_loss = _add_task_specific_losses(
end_points, source_labels, num_classes, hparams, add_summaries=True)
return domain_classifier_loss + task_classifier_loss
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the Domain Adaptation via Style Transfer (PixelDA) model components.
A number of details in the implementation make reference to one of the following
works:
- "Unsupervised Representation Learning with Deep Convolutional
Generative Adversarial Networks""
https://arxiv.org/abs/1511.06434
This paper makes several architecture recommendations:
1. Use strided convs in discriminator, fractional-strided convs in generator
2. batchnorm everywhere
3. remove fully connected layers for deep models
4. ReLu for all layers in generator, except tanh on output
5. LeakyReLu for everything in discriminator
"""
import functools
import math
# Dependency imports
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
def create_model(hparams,
target_images,
source_images=None,
source_labels=None,
is_training=False,
noise=None,
num_classes=None):
"""Create a GAN model.
Arguments:
hparams: HParam object specifying model params
target_images: A `Tensor` of size [batch_size, height, width, channels]. It
is assumed that the images are [-1, 1] normalized.
source_images: A `Tensor` of size [batch_size, height, width, channels]. It
is assumed that the images are [-1, 1] normalized.
source_labels: A `Tensor` of size [batch_size] of categorical labels between
[0, num_classes]
is_training: whether model is currently training
noise: If None, model generates its own noise. Otherwise use provided.
num_classes: Number of classes for classification
Returns:
end_points dict with model outputs
Raises:
ValueError: unknown hparams.arch setting
"""
if num_classes is None and hparams.arch in ['resnet', 'simple']:
raise ValueError('Num classes must be provided to create task classifier')
if target_images.dtype != tf.float32:
raise ValueError('target_images must be tf.float32 and [-1, 1] normalized.')
if source_images is not None and source_images.dtype != tf.float32:
raise ValueError('source_images must be tf.float32 and [-1, 1] normalized.')
###########################
# Create latent variables #
###########################
latent_vars = dict()
if hparams.noise_channel:
noise_shape = [hparams.batch_size, hparams.noise_dims]
if noise is not None:
assert noise.shape.as_list() == noise_shape
tf.logging.info('Using provided noise')
else:
tf.logging.info('Using random noise')
noise = tf.random_uniform(
shape=noise_shape,
minval=-1,
maxval=1,
dtype=tf.float32,
name='random_noise')
latent_vars['noise'] = noise
####################
# Create generator #
####################
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose, slim.fully_connected],
normalizer_params=batch_norm_params(is_training,
hparams.batch_norm_decay),
weights_initializer=tf.random_normal_initializer(
stddev=hparams.normal_init_std),
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if hparams.arch == 'dcgan':
end_points = dcgan(
target_images, latent_vars, hparams, scope='generator')
elif hparams.arch == 'resnet':
end_points = resnet_generator(
source_images,
target_images.shape.as_list()[1:4],
hparams=hparams,
latent_vars=latent_vars)
elif hparams.arch == 'residual_interpretation':
end_points = residual_interpretation_generator(
source_images, is_training=is_training, hparams=hparams)
elif hparams.arch == 'simple':
end_points = simple_generator(
source_images,
target_images,
is_training=is_training,
hparams=hparams,
latent_vars=latent_vars)
elif hparams.arch == 'identity':
# Pass through unmodified, besides changing # channels
# Used to calculate baseline numbers
# Also set `generator_steps=0` for baseline
if hparams.generator_steps:
raise ValueError('Must set generator_steps=0 for identity arch. Is %s'
% hparams.generator_steps)
transferred_images = source_images
source_channels = source_images.shape.as_list()[-1]
target_channels = target_images.shape.as_list()[-1]
if source_channels == 1 and target_channels == 3:
transferred_images = tf.tile(source_images, [1, 1, 1, 3])
if source_channels == 3 and target_channels == 1:
transferred_images = tf.image.rgb_to_grayscale(source_images)
end_points = {'transferred_images': transferred_images}
else:
raise ValueError('Unknown architecture: %s' % hparams.arch)
#####################
# Domain Classifier #
#####################
if hparams.arch in [
'dcgan', 'resnet', 'residual_interpretation', 'simple', 'identity',
]:
# Add a discriminator for these architectures
end_points['transferred_domain_logits'] = predict_domain(
end_points['transferred_images'],
hparams,
is_training=is_training,
reuse=False)
end_points['target_domain_logits'] = predict_domain(
target_images,
hparams,
is_training=is_training,
reuse=True)
###################
# Task Classifier #
###################
if hparams.task_tower != 'none' and hparams.arch in [
'resnet', 'residual_interpretation', 'simple', 'identity',
]:
with tf.variable_scope('discriminator'):
with tf.variable_scope('task_tower'):
end_points['source_task_logits'], end_points[
'source_quaternion'] = pixelda_task_towers.add_task_specific_model(
source_images,
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=False,
private_scope='source_task_classifier',
reuse_shared=False)
end_points['transferred_task_logits'], end_points[
'transferred_quaternion'] = (
pixelda_task_towers.add_task_specific_model(
end_points['transferred_images'],
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=False,
private_scope='transferred_task_classifier',
reuse_shared=True))
end_points['target_task_logits'], end_points[
'target_quaternion'] = pixelda_task_towers.add_task_specific_model(
target_images,
hparams,
num_classes=num_classes,
is_training=is_training,
reuse_private=True,
private_scope='transferred_task_classifier',
reuse_shared=True)
# Remove any endpoints with None values
return dict((k, v) for k, v in end_points.iteritems() if v is not None)
def batch_norm_params(is_training, batch_norm_decay):
return {
'is_training': is_training,
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
}
def lrelu(x, leakiness=0.2):
"""Relu, with optional leaky support."""
return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
def upsample(net, num_filters, scale=2, method='resize_conv', scope=None):
"""Performs spatial upsampling of the given features.
Args:
net: A `Tensor` of shape [batch_size, height, width, filters].
num_filters: The number of output filters.
scale: The scale of the upsampling. Must be a positive integer greater or
equal to two.
method: The method by which the features are upsampled. Valid options
include 'resize_conv' and 'conv2d_transpose'.
scope: An optional variable scope.
Returns:
A new set of features of shape
[batch_size, height*scale, width*scale, num_filters].
Raises:
ValueError: if `method` is not valid or
"""
if scale < 2:
raise ValueError('scale must be greater or equal to two.')
with tf.variable_scope(scope, 'upsample', [net]):
if method == 'resize_conv':
net = tf.image.resize_nearest_neighbor(
net, [net.shape.as_list()[1] * scale,
net.shape.as_list()[2] * scale],
align_corners=True,
name='resize')
return slim.conv2d(net, num_filters, stride=1, scope='conv')
elif method == 'conv2d_transpose':
return slim.conv2d_transpose(net, num_filters, scope='deconv')
else:
raise ValueError('Upsample method [%s] was not recognized.' % method)
def project_latent_vars(hparams, proj_shape, latent_vars, combine_method='sum'):
"""Generate noise and project to input volume size.
Args:
hparams: The hyperparameter HParams struct.
proj_shape: Shape to project noise (not including batch size).
latent_vars: dictionary of `'key': Tensor of shape [batch_size, N]`
combine_method: How to combine the projected values.
sum = project to volume then sum
concat = concatenate along last dimension (i.e. channel)
Returns:
If combine_method=sum, a `Tensor` of size `hparams.projection_shape`
If combine_method=concat and there are N latent vars, a `Tensor` of size
`hparams.projection_shape`, with the last channel multiplied by N
Raises:
ValueError: combine_method is not one of sum/concat
"""
values = []
for var in latent_vars:
with tf.variable_scope(var):
# Project & reshape noise to a HxWxC input
projected = slim.fully_connected(
latent_vars[var],
np.prod(proj_shape),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm)
values.append(tf.reshape(projected, [hparams.batch_size] + proj_shape))
if combine_method == 'sum':
result = values[0]
for value in values[1:]:
result += value
elif combine_method == 'concat':
# Concatenate along last axis
result = tf.concat(values, len(proj_shape))
else:
raise ValueError('Unknown combine_method %s' % combine_method)
tf.logging.info('Latent variables projected to size %s volume', result.shape)
return result
def resnet_block(net, hparams):
"""Create a resnet block."""
net_in = net
net = slim.conv2d(
net,
hparams.resnet_filters,
stride=1,
normalizer_fn=slim.batch_norm,
activation_fn=tf.nn.relu)
net = slim.conv2d(
net,
hparams.resnet_filters,
stride=1,
normalizer_fn=slim.batch_norm,
activation_fn=None)
if hparams.resnet_residuals:
net += net_in
return net
def resnet_stack(images, output_shape, hparams, scope=None):
"""Create a resnet style transfer block.
Args:
images: [batch-size, height, width, channels] image tensor to feed as input
output_shape: output image shape in form [height, width, channels]
hparams: hparams objects
scope: Variable scope
Returns:
Images after processing with resnet blocks.
"""
end_points = {}
if hparams.noise_channel:
# separate the noise for visualization
end_points['noise'] = images[:, :, :, -1]
assert images.shape.as_list()[1:3] == output_shape[0:2]
with tf.variable_scope(scope, 'resnet_style_transfer', [images]):
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=slim.batch_norm,
kernel_size=[hparams.generator_kernel_size] * 2,
stride=1):
net = slim.conv2d(
images,
hparams.resnet_filters,
normalizer_fn=None,
activation_fn=tf.nn.relu)
for block in range(hparams.resnet_blocks):
net = resnet_block(net, hparams)
end_points['resnet_block_{}'.format(block)] = net
net = slim.conv2d(
net,
output_shape[-1],
kernel_size=[1, 1],
normalizer_fn=None,
activation_fn=tf.nn.tanh,
scope='conv_out')
end_points['transferred_images'] = net
return net, end_points
def predict_domain(images,
hparams,
is_training=False,
reuse=False,
scope='discriminator'):
"""Creates a discriminator for a GAN.
Args:
images: A `Tensor` of size [batch_size, height, width, channels]. It is
assumed that the images are centered between -1 and 1.
hparams: hparam object with params for discriminator
is_training: Specifies whether or not we're training or testing.
reuse: Whether to reuse variable scope
scope: An optional variable_scope.
Returns:
[batch size, 1] - logit output of discriminator.
"""
with tf.variable_scope(scope, 'discriminator', [images], reuse=reuse):
lrelu_partial = functools.partial(lrelu, leakiness=hparams.lrelu_leakiness)
with slim.arg_scope(
[slim.conv2d],
kernel_size=[hparams.discriminator_kernel_size] * 2,
activation_fn=lrelu_partial,
stride=2,
normalizer_fn=slim.batch_norm):
def add_noise(hidden, scope_num=None):
if scope_num:
hidden = slim.dropout(
hidden,
hparams.discriminator_dropout_keep_prob,
is_training=is_training,
scope='dropout_%s' % scope_num)
if hparams.discriminator_noise_stddev == 0:
return hidden
return hidden + tf.random_normal(
hidden.shape.as_list(),
mean=0.0,
stddev=hparams.discriminator_noise_stddev)
# As per the recommendation of the DCGAN paper, we don't use batch norm
# on the discriminator input (https://arxiv.org/pdf/1511.06434v2.pdf).
if hparams.discriminator_image_noise:
images = add_noise(images)
net = slim.conv2d(
images,
hparams.num_discriminator_filters,
normalizer_fn=None,
stride=hparams.discriminator_first_stride,
scope='conv1_stride%s' % hparams.discriminator_first_stride)
net = add_noise(net, 1)
block_id = 2
# Repeatedly stack
# discriminator_conv_block_size-1 conv layers with stride 1
# followed by a stride 2 layer
# Add (optional) noise at every point
while net.shape.as_list()[1] > hparams.projection_shape_size:
num_filters = int(hparams.num_discriminator_filters *
(hparams.discriminator_filter_factor**(block_id - 1)))
for conv_id in range(1, hparams.discriminator_conv_block_size):
net = slim.conv2d(
net,
num_filters,
stride=1,
scope='conv_%s_%s' % (block_id, conv_id))
if hparams.discriminator_do_pooling:
net = slim.conv2d(
net, num_filters, scope='conv_%s_prepool' % block_id)
net = slim.avg_pool2d(
net, kernel_size=[2, 2], stride=2, scope='pool_%s' % block_id)
else:
net = slim.conv2d(
net, num_filters, scope='conv_%s_stride2' % block_id)
net = add_noise(net, block_id)
block_id += 1
net = slim.flatten(net)
net = slim.fully_connected(
net,
1,
# Models with BN here generally produce noise
normalizer_fn=None,
activation_fn=None,
scope='fc_logit_out') # Returns logits!
return net
def dcgan_generator(images, output_shape, hparams, scope=None):
"""Transforms the visual style of the input images.
Args:
images: A `Tensor` of shape [batch_size, height, width, channels].
output_shape: A list or tuple of 3 elements: the output height, width and
number of channels.
hparams: hparams object with generator parameters
scope: Scope to place generator inside
Returns:
A `Tensor` of shape [batch_size, height, width, output_channels] which
represents the result of style transfer.
Raises:
ValueError: If `output_shape` is not a list or tuple or if it doesn't have
three elements or if `output_shape` or `images` arent square.
"""
if not isinstance(output_shape, (tuple, list)):
raise ValueError('output_shape must be a tuple or list.')
elif len(output_shape) != 3:
raise ValueError('output_shape must have three elements.')
if output_shape[0] != output_shape[1]:
raise ValueError('output_shape must be square')
if images.shape.as_list()[1] != images.shape.as_list()[2]:
raise ValueError('images height and width must match.')
outdim = output_shape[0]
indim = images.shape.as_list()[1]
num_iterations = int(math.ceil(math.log(float(outdim) / float(indim), 2.0)))
with slim.arg_scope(
[slim.conv2d, slim.conv2d_transpose],
kernel_size=[hparams.generator_kernel_size] * 2,
stride=2):
with tf.variable_scope(scope or 'generator'):
net = images
# Repeatedly halve # filters until = hparams.decode_filters in last layer
for i in range(num_iterations):
num_filters = hparams.num_decoder_filters * 2**(num_iterations - i - 1)
net = slim.conv2d_transpose(net, num_filters, scope='deconv_%s' % i)
# Crop down to desired size (e.g. 32x32 -> 28x28)
dif = net.shape.as_list()[1] - outdim
low = dif / 2
high = net.shape.as_list()[1] - low
net = net[:, low:high, low:high, :]
# No batch norm on generator output
net = slim.conv2d(
net,
output_shape[2],
kernel_size=[1, 1],
stride=1,
normalizer_fn=None,
activation_fn=tf.tanh,
scope='conv_out')
return net
def dcgan(target_images, latent_vars, hparams, scope='dcgan'):
"""Creates the PixelDA model.
Args:
target_images: A `Tensor` of shape [batch_size, height, width, 3]
sampled from the image domain to which we want to transfer.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
hparams: The hyperparameter map.
scope: Surround generator component with this scope
Returns:
A dictionary of model outputs.
"""
proj_shape = [
hparams.projection_shape_size, hparams.projection_shape_size,
hparams.projection_shape_channels
]
source_volume = project_latent_vars(
hparams, proj_shape, latent_vars, combine_method='concat')
###################################################
# Transfer the source images to the target style. #
###################################################
with tf.variable_scope(scope, 'generator', [target_images]):
transferred_images = dcgan_generator(
source_volume,
output_shape=target_images.shape.as_list()[1:4],
hparams=hparams)
assert transferred_images.shape.as_list() == target_images.shape.as_list()
return {'transferred_images': transferred_images}
def resnet_generator(images, output_shape, hparams, latent_vars=None):
"""Creates a ResNet-based generator.
Args:
images: A `Tensor` of shape [batch_size, height, width, num_channels]
sampled from the image domain from which we want to transfer
output_shape: A length-3 array indicating the height, width and channels of
the output.
hparams: The hyperparameter map.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
Returns:
A dictionary of model outputs.
"""
with tf.variable_scope('generator'):
if latent_vars:
noise_channel = project_latent_vars(
hparams,
proj_shape=images.shape.as_list()[1:3] + [1],
latent_vars=latent_vars,
combine_method='concat')
images = tf.concat([images, noise_channel], 3)
transferred_images, end_points = resnet_stack(
images,
output_shape=output_shape,
hparams=hparams,
scope='resnet_stack')
end_points['transferred_images'] = transferred_images
return end_points
def residual_interpretation_block(images, hparams, scope):
"""Learns a residual image which is added to the incoming image.
Args:
images: A `Tensor` of size [batch_size, height, width, 3]
hparams: The hyperparameters struct.
scope: The name of the variable op scope.
Returns:
The updated images.
"""
with tf.variable_scope(scope):
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=None,
kernel_size=[hparams.generator_kernel_size] * 2):
net = images
for _ in range(hparams.res_int_convs):
net = slim.conv2d(
net, hparams.res_int_filters, activation_fn=tf.nn.relu)
net = slim.conv2d(net, 3, activation_fn=tf.nn.tanh)
# Add the residual
images += net
# Clip the output
images = tf.maximum(images, -1.0)
images = tf.minimum(images, 1.0)
return images
def residual_interpretation_generator(images,
is_training,
hparams,
latent_vars=None):
"""Creates a generator producing purely residual transformations.
A residual generator differs from the resnet generator in that each 'block' of
the residual generator produces a residual image. Consequently, the 'progress'
of the model generation process can be directly observed at inference time,
making it easier to diagnose and understand.
Args:
images: A `Tensor` of shape [batch_size, height, width, num_channels]
sampled from the image domain from which we want to transfer. It is
assumed that the images are centered between -1 and 1.
is_training: whether or not the model is training.
hparams: The hyperparameter map.
latent_vars: dictionary of 'key': Tensor of shape [batch_size, N]
Returns:
A dictionary of model outputs.
"""
end_points = {}
with tf.variable_scope('generator'):
if latent_vars:
projected_latent = project_latent_vars(
hparams,
proj_shape=images.shape.as_list()[1:3] + [images.shape.as_list()[-1]],
latent_vars=latent_vars,
combine_method='sum')
images += projected_latent
with tf.variable_scope(None, 'residual_style_transfer', [images]):
for i in range(hparams.res_int_blocks):
images = residual_interpretation_block(images, hparams,
'residual_%d' % i)
end_points['transferred_images_%d' % i] = images
end_points['transferred_images'] = images
return end_points
def simple_generator(source_images, target_images, is_training, hparams,
latent_vars):
"""Simple generator architecture (stack of convs) for trying small models."""
end_points = {}
with tf.variable_scope('generator'):
feed_source_images = source_images
if latent_vars:
projected_latent = project_latent_vars(
hparams,
proj_shape=source_images.shape.as_list()[1:3] + [1],
latent_vars=latent_vars,
combine_method='concat')
feed_source_images = tf.concat([source_images, projected_latent], 3)
end_points = {}
###################################################
# Transfer the source images to the target style. #
###################################################
with slim.arg_scope(
[slim.conv2d],
normalizer_fn=slim.batch_norm,
stride=1,
kernel_size=[hparams.generator_kernel_size] * 2):
net = feed_source_images
# N convolutions
for i in range(1, hparams.simple_num_conv_layers):
normalizer_fn = None
if i != 0:
normalizer_fn = slim.batch_norm
net = slim.conv2d(
net,
hparams.simple_conv_filters,
normalizer_fn=normalizer_fn,
activation_fn=tf.nn.relu)
# Project back to right # image channels
net = slim.conv2d(
net,
target_images.shape.as_list()[-1],
kernel_size=[1, 1],
stride=1,
normalizer_fn=None,
activation_fn=tf.tanh,
scope='conv_out')
transferred_images = net
assert transferred_images.shape.as_list() == target_images.shape.as_list()
end_points['transferred_images'] = transferred_images
return end_points
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions for preprocessing the inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
def preprocess_classification(image, labels, is_training=False):
"""Preprocesses the image and labels for classification purposes.
Preprocessing includes shifting the images to be 0-centered between -1 and 1.
This is not only a popular method of preprocessing (inception) but is also
the mechanism used by DSNs.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
is_training: Whether or not we're training the model.
Returns:
The preprocessed image and labels.
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
image -= 0.5
image *= 2
return image, labels
def preprocess_style_transfer(image,
labels,
augment=False,
size=None,
is_training=False):
"""Preprocesses the image and labels for style transfer purposes.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
augment: Whether to apply data augmentation to inputs
size: The height and width to which images should be resized. If left as
`None`, then no resizing is performed
is_training: Whether or not we're training the model
Returns:
The preprocessed image and labels. Scaled to [-1, 1]
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
if augment and is_training:
image = image_augmentation(image)
if size:
image = resize_image(image, size)
image -= 0.5
image *= 2
return image, labels
def image_augmentation(image):
"""Performs data augmentation by randomly permuting the inputs.
Args:
image: A float `Tensor` of size [height, width, channels] with values
in range[0,1].
Returns:
The mutated batch of images
"""
# Apply photometric data augmentation (contrast etc.)
num_channels = image.shape_as_list()[-1]
if num_channels == 4:
# Only augment image part
image, depth = image[:, :, 0:3], image[:, :, 3:4]
elif num_channels == 1:
image = tf.image.grayscale_to_rgb(image)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.clip_by_value(image, 0, 1.0)
if num_channels == 4:
image = tf.concat(2, [image, depth])
elif num_channels == 1:
image = tf.image.rgb_to_grayscale(image)
return image
def resize_image(image, size=None):
"""Resize image to target size.
Args:
image: A `Tensor` of size [height, width, 3].
size: (height, width) to resize image to.
Returns:
resized image
"""
if size is None:
raise ValueError('Must specify size')
if image.shape_as_list()[:2] == size:
# Don't resize if not necessary
return image
image = tf.expand_dims(image, 0)
image = tf.image.resize_images(image, size)
image = tf.squeeze(image, 0)
return image
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for domain_adaptation.pixel_domain_adaptation.pixelda_preprocess."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
class PixelDAPreprocessTest(tf.test.TestCase):
def assert_preprocess_classification_is_centered(self, dtype, is_training):
tf.set_random_seed(0)
if dtype == tf.uint8:
image = tf.random_uniform((100, 200, 3), maxval=255, dtype=tf.int64)
image = tf.cast(image, tf.uint8)
else:
image = tf.random_uniform((100, 200, 3), maxval=1.0, dtype=dtype)
labels = {}
image, labels = pixelda_preprocess.preprocess_classification(
image, labels, is_training=is_training)
with self.test_session() as sess:
np_image = sess.run(image)
self.assertTrue(np_image.min() <= -0.95)
self.assertTrue(np_image.min() >= -1.0)
self.assertTrue(np_image.max() >= 0.95)
self.assertTrue(np_image.max() <= 1.0)
def testPreprocessClassificationZeroCentersUint8DuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=True)
def testPreprocessClassificationZeroCentersUint8DuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=False)
def testPreprocessClassificationZeroCentersFloatDuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=True)
def testPreprocessClassificationZeroCentersFloatDuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task towers for PixelDA model."""
import tensorflow as tf
slim = tf.contrib.slim
def add_task_specific_model(images,
hparams,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope=None):
"""Create a classifier for the given images.
The classifier is composed of a few 'private' layers followed by a few
'shared' layers. This lets us account for different image 'style', while
sharing the last few layers as 'content' layers.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
hparams: model hparams
num_classes: The number of output classes.
is_training: whether model is training
reuse_private: Whether or not to reuse the private weights, which are the
first few layers in the classifier
private_scope: The name of the variable_scope for the private (unshared)
components of the classifier.
reuse_shared: Whether or not to reuse the shared weights, which are the last
few layers in the classifier
shared_scope: The name of the variable_scope for the shared components of
the classifier.
Returns:
The logits, a `Tensor` of shape [batch_size, num_classes].
Raises:
ValueError: If hparams.task_classifier is an unknown value
"""
model = hparams.task_tower
# Make sure the classifier name shows up in graph
shared_scope = shared_scope or (model + '_shared')
kwargs = {
'num_classes': num_classes,
'is_training': is_training,
'reuse_private': reuse_private,
'reuse_shared': reuse_shared,
}
if private_scope:
kwargs['private_scope'] = private_scope
if shared_scope:
kwargs['shared_scope'] = shared_scope
quaternion_pred = None
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay_task_classifier)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if model == 'doubling_pose_estimator':
logits, quaternion_pred = doubling_cnn_class_and_quaternion(
images, num_private_layers=hparams.num_private_layers, **kwargs)
elif model == 'mnist':
logits, _ = mnist_classifier(images, **kwargs)
elif model == 'svhn':
logits, _ = svhn_classifier(images, **kwargs)
elif model == 'gtsrb':
logits, _ = gtsrb_classifier(images, **kwargs)
elif model == 'pose_mini':
logits, quaternion_pred = pose_mini_tower(images, **kwargs)
else:
raise ValueError('Unknown task classifier %s' % model)
return logits, quaternion_pred
#####################################
# Classifiers used in the DSN paper #
#####################################
def mnist_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope='mnist',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional MNIST model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = conv_mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 48, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool2']), 100, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 100, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def svhn_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional SVHN model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [3, 3], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 64, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [3, 3], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 128, [5, 5], scope='conv3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['conv3']), 3072, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 2048, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def gtsrb_classifier(images,
is_training=False,
num_classes=43,
reuse_private=False,
private_scope='gtsrb',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional GTSRB model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
reuse_private: Whether or not to reuse the private components of the model.
private_scope: The name of the private scope.
reuse_shared: Whether or not to reuse the shared components of the model.
shared_scope: The name of the shared scope.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 144, [3, 3], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 256, [5, 5], scope='conv3')
net['pool3'] = slim.max_pool2d(net['conv3'], [2, 2], 2, scope='pool3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool3']), 512, scope='fc3')
logits = slim.fully_connected(
net['fc3'], num_classes, activation_fn=None, scope='fc4')
return logits, net
#########################
# pose_mini task towers #
#########################
def pose_mini_tower(images,
num_classes=11,
is_training=False,
reuse_private=False,
private_scope='pose_mini',
reuse_shared=False,
shared_scope='task_model'):
"""Task tower for the pose_mini dataset."""
with tf.variable_scope(private_scope, reuse=reuse_private):
net = slim.conv2d(images, 32, [5, 5], scope='conv1')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net = slim.conv2d(net, 64, [5, 5], scope='conv2')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
net = slim.flatten(net)
net = slim.fully_connected(net, 128, scope='fc3')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
with tf.variable_scope('quaternion_prediction'):
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc4')
return logits, quaternion_pred
def doubling_cnn_class_and_quaternion(images,
num_private_layers=1,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope='doubling_cnn',
reuse_shared=False,
shared_scope='task_model'):
"""Alternate conv, pool while doubling filter count."""
net = images
depth = 32
layer_id = 1
with tf.variable_scope(private_scope, reuse=reuse_private):
while num_private_layers > 0 and net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
num_private_layers -= 1
with tf.variable_scope(shared_scope, reuse=reuse_shared):
while net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
net = slim.flatten(net)
net = slim.fully_connected(net, 100, scope='fc1')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc_logits')
return logits, quaternion_pred
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment