"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "d93322e85351853841d897f8018c01b9165b0285"
Commit dff0f0c1 authored by Alexander Gorban's avatar Alexander Gorban
Browse files

Merge branch 'master' of github.com:tensorflow/models

parents da341f70 36203f09
licenses(["notice"]) # Apache 2.0
py_binary(
name = "baseline_train",
srcs = ["baseline_train.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
py_binary(
name = "baseline_eval",
srcs = ["baseline_eval.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
The best baselines are obtained with the following configurations:
## MNIST => MNIST_M
Accuracy:
MNIST-Train: 99.9
MNIST_M-Train: 63.9
MNIST_M-Valid: 63.9
MNIST_M-Test: 63.6
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 105,000
## MNIST => USPS
Accuracy:
MNIST-Train: 100.0
USPS-Train: 82.8
USPS-Valid: 82.8
USPS-Test: 78.9
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 22,000
## MNIST_M => MNIST
Accuracy:
MNIST_M-Train: 100.0
MNIST-Train: 98.5
MNIST-Valid: 98.5
MNIST-Test: 98.1
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 604,400
## MNIST_M => MNIST_M
Accuracy:
MNIST_M-Train: 100.0
MNIST_M-Valid: 96.6
MNIST_M-Test: 96.4
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 139,400
## USPS => USPS
Accuracy:
USPS-Train: 100.0
USPS-Valid: 100.0
USPS-Test: 96.5
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 67,000
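These numbers can be reproduced with the baseline scripts in this package. A sample invocation for the MNIST => MNIST_M row is sketched below; the dataset and log paths are placeholders, and since the scripts expose no step-count flag, training is stopped manually at the listed number of steps:
$ python baseline_train.py \
--dataset_name=mnist \
--split_name=train \
--dataset_dir=/tmp/datasets/ \
--learning_rate=0.0001 \
--weight_decay=0.0 \
--logdir=/tmp/baselines/mnist_to_mnist_m
$ python baseline_eval.py \
--dataset_name=mnist_m \
--split_name=test \
--dataset_dir=/tmp/datasets/ \
--checkpoint_dir=/tmp/baselines/mnist_to_mnist_m \
--eval_dir=/tmp/baselines/mnist_to_mnist_m/eval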
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evals the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow server.')
flags.DEFINE_string(
'checkpoint_dir', None, 'The location of the checkpoint files.')
flags.DEFINE_string(
'eval_dir', None, 'The directory where evaluation logs are written.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = 0.0
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
if not tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.MakeDirs(FLAGS.eval_dir)
with tf.Graph().as_default():
dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.split_name,
FLAGS.dataset_dir)
num_classes = dataset.num_classes
num_samples = dataset.num_samples
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=False)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=False)
#####################
# Define the losses #
#####################
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
predictions = tf.reshape(tf.argmax(logits, 1), shape=[-1])
class_labels = tf.argmax(labels['classes'], 1)
metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
'Mean_Loss':
tf.contrib.metrics.streaming_mean(total_loss),
'Accuracy':
tf.contrib.metrics.streaming_accuracy(predictions,
tf.reshape(
class_labels,
shape=[-1])),
'Recall_at_5':
tf.contrib.metrics.streaming_recall_at_k(logits, class_labels, 5),
})
tf.summary.histogram('outputs/Predictions', predictions)
tf.summary.histogram('outputs/Ground_Truth', class_labels)
for name, value in metrics_to_values.items():
tf.summary.scalar(name, value)
num_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_dir,
num_evals=num_batches,
eval_op=list(metrics_to_updates.values()),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow server.')
flags.DEFINE_integer('task', 0, 'The task ID.')
flags.DEFINE_integer('num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then '
'the parameters are handled locally by the worker.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_float('learning_rate', 0.001, 'The initial learning rate.')
flags.DEFINE_integer(
'learning_rate_decay_steps', 20000,
'The frequency, in steps, at which the learning rate is decayed.')
flags.DEFINE_float('learning_rate_decay_factor',
0.95,
'The factor with which the learning rate is decayed.')
flags.DEFINE_float('adam_beta1', 0.5, 'The beta1 value for the AdamOptimizer')
flags.DEFINE_float('weight_decay', 1e-5,
'The L2 coefficient on the model weights.')
flags.DEFINE_string(
'logdir', None, 'The location of the logs and checkpoints.')
flags.DEFINE_integer('save_interval_secs', 600,
'How often, in seconds, we save the model to disk.')
flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_float(
'moving_average_decay', 0.9999,
'The amount of decay to use for moving averages.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = FLAGS.weight_decay
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
with tf.Graph().as_default():
with tf.device(
tf.train.replica_device_setter(FLAGS.num_ps_tasks, merge_devices=True)):
dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
FLAGS.split_name, FLAGS.dataset_dir)
num_classes = dataset.num_classes
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=True)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# preprocess_fn=preprocess_fn)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
# Define the losses
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
tf.summary.scalar('losses/Total_Loss', total_loss)
# Setup the moving averages
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, slim.get_or_create_global_step())
tf.add_to_collection(
tf.GraphKeys.UPDATE_OPS,
variable_averages.apply(moving_average_variables))
# Specify the optimization scheme:
learning_rate = tf.train.exponential_decay(
FLAGS.learning_rate,
slim.get_or_create_global_step(),
FLAGS.learning_rate_decay_steps,
FLAGS.learning_rate_decay_factor,
staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.adam_beta1)
train_op = slim.learning.create_train_op(total_loss, optimizer)
slim.learning.train(
train_op,
FLAGS.logdir,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define model HParams."""
import tensorflow as tf
def create_hparams(hparam_string=None):
"""Create model hyperparameters. Parse nondefault from given string."""
hparams = tf.contrib.training.HParams(
# The name of the architecture to use.
arch='resnet',
lrelu_leakiness=0.2,
batch_norm_decay=0.9,
weight_decay=1e-5,
normal_init_std=0.02,
generator_kernel_size=3,
discriminator_kernel_size=3,
# Stop training after this many examples are processed
# If zero, train indefinitely
num_training_examples=0,
# Apply data augmentation to datasets
# Applies only in training job
augment_source_images=False,
augment_target_images=False,
# Discriminator
# Number of filters in first layer of discriminator
num_discriminator_filters=64,
discriminator_conv_block_size=1, # How many convs to have at each size
discriminator_filter_factor=2.0, # Multiply # filters by this each layer
# Add gaussian noise with this stddev to every hidden layer of D
discriminator_noise_stddev=0.2, # lmetz: Start seeing results at >= 0.1
# If true, add this gaussian noise to input images to D as well
discriminator_image_noise=False,
discriminator_first_stride=1, # Stride in first conv of discriminator
discriminator_do_pooling=False, # If true, replace stride 2 with avg pool
discriminator_dropout_keep_prob=0.9, # keep probability for dropout
# DCGAN Generator
# Number of filters in generator decoder last layer (repeatedly halved
# from 1st layer)
num_decoder_filters=64,
# Number of filters in generator encoder 1st layer (repeatedly doubled
# after 1st layer)
num_encoder_filters=64,
# This is the shape to which the noise vector is projected (if we're
# transferring from noise).
# Write this way instead of [4, 4, 64] for hparam search flexibility
projection_shape_size=4,
projection_shape_channels=64,
# Indicates the method by which we enlarge the spatial representation
# of an image. Possible values include:
# - resize_conv: Performs a nearest neighbor resize followed by a conv.
# - conv2d_transpose: Performs a conv2d_transpose.
upsample_method='resize_conv',
# Visualization
summary_steps=500, # Output image summary every N steps
###################################
# Task Classifier Hyperparameters #
###################################
# Which task-specific prediction tower to use. Possible choices are:
# none: No task tower.
# doubling_pose_estimator: classifier + quaternion regressor.
# [conv + pool]* + FC
# Classifiers used in DSN paper:
# gtsrb: Classifier used for GTSRB
# svhn: Classifier used for SVHN
# mnist: Classifier used for MNIST
# pose_mini: Classifier + regressor used for pose_mini
task_tower='doubling_pose_estimator',
weight_decay_task_classifier=1e-5,
source_task_loss_weight=1.0,
transferred_task_loss_weight=1.0,
# Number of private layers in doubling_pose_estimator task tower
num_private_layers=2,
# The weight for the log quaternion loss we use for source and transferred
# samples of the cropped_linemod dataset.
# In the DSN work, 1/8 of the classifier weight worked well for our log
# quaternion loss
source_pose_weight=0.125 * 2.0,
transferred_pose_weight=0.125 * 1.0,
# If set to True, the style transfer network also attempts to change its
# weights to maximize the performance of the task tower. If set to False,
# then the style transfer network only attempts to change its weights to
# make the transferred images more likely according to the domain
# classifier.
task_tower_in_g_step=True,
task_loss_in_g_weight=1.0, # Weight of task loss in G
#########################################
# 'simple' generator arch model hparams #
#########################################
simple_num_conv_layers=1,
simple_conv_filters=8,
#########################
# Resnet Hyperparameters #
#########################
resnet_blocks=6, # Number of resnet blocks
resnet_filters=64, # Number of filters per conv in resnet blocks
# If true, add original input back to result of convolutions inside the
# resnet arch. If false, it turns into a simple stack of conv/relu/BN
# layers.
resnet_residuals=True,
#######################################
# The residual / interpretable model. #
#######################################
res_int_blocks=2, # The number of residual blocks.
res_int_convs=2, # The number of conv calls inside each block.
res_int_filters=64, # The number of filters used by each convolution.
####################
# Latent variables #
####################
# if true, then generate random noise and project to input for generator
noise_channel=True,
# The number of dimensions in the input noise vector.
noise_dims=10,
# If true, then one hot encode source image class and project as an
# additional channel for the input to generator. This gives the generator
# access to the class, which may help generation performance.
condition_on_source_class=False,
########################
# Loss Hyperparameters #
########################
domain_loss_weight=1.0,
style_transfer_loss_weight=1.0,
########################################################################
# Encourages the transferred images to be similar to the source images #
# using a configurable metric. #
########################################################################
# The weight of the loss function encouraging the source and transferred
# images to be similar. If set to 0, then the loss function is not used.
transferred_similarity_loss_weight=0.0,
# The type of loss used to encourage transferred and source image
# similarity. Valid values include:
# mpse: Mean Pairwise Squared Error
# mse: Mean Squared Error
# hinged_mse: Computes the mean squared error using squared differences
# greater than hparams.transferred_similarity_max_diff
# hinged_mae: Computes the mean absolute error using absolute
# differences greater than hparams.transferred_similarity_max_diff.
transferred_similarity_loss='mpse',
# The maximum allowable difference between the source and target images.
# This value is used, in effect, to produce a hinge loss. Note that the
# range of values should be between 0 and 1.
transferred_similarity_max_diff=0.4,
################################
# Optimization Hyperparameters #
################################
learning_rate=0.001,
batch_size=32,
lr_decay_steps=20000,
lr_decay_rate=0.95,
# Recommendation from the DCGAN paper:
adam_beta1=0.5,
clip_gradient_norm=5.0,
# The number of times we run the discriminator train_op in a row.
discriminator_steps=1,
# The number of times we run the generator train_op in a row.
generator_steps=1)
if hparam_string:
tf.logging.info('Parsing command line hparams: %s', hparam_string)
hparams.parse(hparam_string)
tf.logging.info('Final parsed hparams: %s', hparams.values())
return hparams
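# Example usage (a sketch, not part of the original file): non-default values
# can be supplied as the comma-separated string accepted by the --hparams
# flags in the train/eval binaries, e.g.:
#
#   hparams = create_hparams('arch=resnet,resnet_blocks=8,learning_rate=0.0002')
#   assert hparams.resnet_blocks == 8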
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evaluates the PIXELDA model.
-- Compile the model for CPU.
$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Compile the model for GPU.
$ bazel build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Run the evaluation.
$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \
--source_dataset=mnist \
--target_dataset=mnist_m \
--dataset_dir=/tmp/datasets/ \
--alsologtostderr
-- Visualize the results.
$ bash learning/brain/tensorboard/tensorboard.sh \
--port 2222 --logdir=/tmp/pixelda/
"""
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/pixelda/',
'Directory where the results are saved to.')
flags.DEFINE_integer('eval_interval_secs', 60,
'The frequency, in seconds, with which evaluation is run.')
flags.DEFINE_string('target_split_name', 'test',
'The name of the train/test split.')
flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.'
' Defaults to train.')
flags.DEFINE_string('source_dataset', 'mnist',
'The name of the source dataset.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string(
'dataset_dir',
'', # None,
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def run_eval(run_dir, checkpoint_dir, hparams):
"""Runs the eval loop.
Args:
run_dir: The directory where eval specific logs are placed
checkpoint_dir: The directory where the checkpoints are stored
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for checkpoint_path in slim.evaluation.checkpoints_iterator(
checkpoint_dir, FLAGS.eval_interval_secs):
with tf.Graph().as_default():
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name=FLAGS.target_split_name,
dataset_dir=FLAGS.dataset_dir)
target_images, target_labels = dataset_factory.provide_batch(
FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
target_labels['class'] = tf.argmax(target_labels['classes'], 1)
del target_labels['classes']
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name=FLAGS.source_split_name,
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Input and output datasets must have same number of classes')
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=False,
num_classes=num_target_classes)
#######################
# Metrics & Summaries #
#######################
names_to_values, names_to_updates = create_metrics(end_points,
source_labels,
target_labels, hparams)
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summarize_images(target_images, 'Target')
for name, value in names_to_values.items():
tf.summary.scalar(name, value)
# Use the entire split by default
num_examples = target_dataset.num_samples
num_batches = int(math.ceil(num_examples / float(hparams.batch_size)))
global_step = slim.get_or_create_global_step()
result = slim.evaluation.evaluate_once(
master=FLAGS.master,
checkpoint_path=checkpoint_path,
logdir=run_dir,
num_evals=num_batches,
eval_op=list(names_to_updates.values()),
final_op=names_to_values)
def to_degrees(log_quaternion_loss):
"""Converts a log quaternion distance to an angle.
Args:
log_quaternion_loss: The log quaternion distance between two
unit quaternions (or a batch of pairs of quaternions).
Returns:
The angle in degrees of the implied angle-axis representation.
"""
return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
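# Sanity check (a sketch, not part of the original file): for identical unit
# quaternions, log_quaternion_loss_batch yields log(1e-4 + 1 - 1) = log(1e-4),
# and to_degrees(log(1e-4)) = acos(1 - 1e-4) * 2 * 180 / math.pi, roughly 1.6
# degrees: the small residual contributed by the 1e-4 epsilon inside the log.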
def create_metrics(end_points, source_labels, target_labels, hparams):
"""Create metrics for the model.
Args:
end_points: A dictionary of end point name to tensor
source_labels: Labels for source images. batch_size x 1
target_labels: Labels for target images. batch_size x 1
hparams: The hyperparameters struct.
Returns:
Tuple of (names_to_values, names_to_updates), dictionaries that map a metric
name to its value and update op, respectively
"""
###########################################
# Evaluate the Domain Prediction Accuracy #
###########################################
batch_size = hparams.batch_size
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
('eval/Domain_Accuracy-Transferred'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points[
'transferred_domain_logits']))),
tf.zeros(batch_size, dtype=tf.int32)),
('eval/Domain_Accuracy-Target'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points['target_domain_logits']))),
tf.ones(batch_size, dtype=tf.int32))
})
################################
# Evaluate the task classifier #
################################
if 'source_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['source_task_logits'], 1),
source_labels['class'])
if 'transferred_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['transferred_task_logits'], 1),
source_labels['class'])
if 'target_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['target_task_logits'], 1),
target_labels['class'])
##########################################################################
# Pose data-specific losses.
##########################################################################
if source_labels is not None and 'quaternion' in source_labels:
params = {}
params['use_logging'] = False
params['batch_size'] = batch_size
angle_loss_source = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'source_quaternion'], source_labels['quaternion'], params))
angle_loss_transferred = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'transferred_quaternion'], source_labels['quaternion'], params))
angle_loss_target = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'target_quaternion'], target_labels['quaternion'], params))
metric_name = 'eval/Angle_Loss-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_source)
metric_name = 'eval/Angle_Loss-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_transferred)
metric_name = 'eval/Angle_Loss-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_target)
return names_to_values, names_to_updates
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_eval(
run_dir=FLAGS.eval_dir,
checkpoint_dir=FLAGS.checkpoint_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
def add_domain_classifier_losses(end_points, hparams):
"""Adds losses related to the domain-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
hparams: The hyperparameters struct.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
if hparams.domain_loss_weight == 0:
tf.logging.info(
'Domain classifier loss weight is 0, so not creating losses.')
return 0
# The domain prediction loss is minimized with respect to the domain
# classifier features only. Its aim is to predict the domain of the images.
# Note: 1 = 'real image' label, 0 = 'fake image' label
transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
logits=end_points['transferred_domain_logits'])
tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)
target_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
logits=end_points['target_domain_logits'])
tf.summary.scalar('Domain_loss_target', target_domain_loss)
# Compute the total domain loss:
total_domain_loss = transferred_domain_loss + target_domain_loss
total_domain_loss *= hparams.domain_loss_weight
tf.summary.scalar('Domain_loss_total', total_domain_loss)
return total_domain_loss
def log_quaternion_loss_batch(predictions, labels, params):
"""A helper function to compute the error between quaternions.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size [batch_size], denoting the error between the quaternions.
"""
use_logging = params['use_logging']
assertions = []
if use_logging:
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
1e-4)),
['The l2 norm of each prediction quaternion vector should be 1.']))
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
['The l2 norm of each label quaternion vector should be 1.']))
with tf.control_dependencies(assertions):
product = tf.multiply(predictions, labels)
internal_dot_products = tf.reduce_sum(product, [1])
if use_logging:
internal_dot_products = tf.Print(internal_dot_products, [
internal_dot_products,
tf.shape(internal_dot_products)
], 'internal_dot_products:')
logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
return logcost
def log_quaternion_loss(predictions, labels, params):
"""A helper function to compute the mean error between batches of quaternions.
The caller is expected to add the loss to the graph.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size 1, denoting the mean error between batches of quaternions.
"""
use_logging = params['use_logging']
logcost = log_quaternion_loss_batch(predictions, labels, params)
logcost = tf.reduce_sum(logcost, [0])
batch_size = params['batch_size']
logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
if use_logging:
logcost = tf.Print(
logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
return logcost
def _quaternion_loss(labels, predictions, weight, batch_size, domain,
add_summaries):
"""Creates a Quaternion Loss.
Args:
labels: The true quaternions.
predictions: The predicted quaternions.
weight: A scalar weight.
batch_size: The size of the batches.
domain: The name of the domain from which the labels were taken.
add_summaries: Whether or not to add summaries for the losses.
Returns:
A `Tensor` representing the loss.
"""
assert domain in ['Source', 'Transferred']
params = {'use_logging': False, 'batch_size': batch_size}
loss = weight * log_quaternion_loss(labels, predictions, params)
if add_summaries:
assert_op = tf.Assert(tf.is_finite(loss), [loss])
with tf.control_dependencies([assert_op]):
tf.summary.histogram(
'Log_Quaternion_Loss_%s' % domain, loss, collections=['losses'])
tf.summary.scalar(
'Task_Quaternion_Loss_%s' % domain, loss, collections=['losses'])
return loss
def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
add_summaries=False):
"""Adds losses related to the task-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
add_summaries: Whether or not to add the summaries.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
# TODO(ddohan): Make sure the l2 regularization is added to the loss
one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
total_loss = 0
if 'source_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['source_task_logits'],
weights=hparams.source_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Source', loss)
total_loss += loss
if 'transferred_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['transferred_task_logits'],
weights=hparams.transferred_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
total_loss += loss
#########################
# Pose specific losses. #
#########################
if 'quaternion' in source_labels:
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['source_quaternion'],
hparams.source_pose_weight,
hparams.batch_size,
'Source',
add_summaries)
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['transferred_quaternion'],
hparams.transferred_pose_weight,
hparams.batch_size,
'Transferred',
add_summaries)
if add_summaries:
tf.summary.scalar('Task_Loss_Total', total_loss)
return total_loss
def _transferred_similarity_loss(reconstructions,
source_images,
weight=1.0,
method='mse',
max_diff=0.4,
name='similarity'):
"""Computes a loss encouraging similarity between source and transferred.
Args:
reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
source_images: A `Tensor` of shape [batch_size, height, width, channels].
weight: Multiple similarity loss by this weight before returning
method: One of:
mpse = Mean Pairwise Squared Error
mse = Mean Squared Error
hinged_mse = Computes the mean squared error using squared differences
greater than hparams.transferred_similarity_max_diff
hinged_mae = Computes the mean absolute error using absolute
differences greater than hparams.transferred_similarity_max_diff.
max_diff: Maximum unpenalized difference for hinged losses
name: Identifying name to use for creating summaries
Returns:
A `Tensor` representing the transferred similarity loss.
Raises:
ValueError: if `method` is not recognized.
"""
if weight == 0:
return 0
source_channels = source_images.shape.as_list()[-1]
reconstruction_channels = reconstructions.shape.as_list()[-1]
# Convert grayscale source to RGB if target is RGB
if source_channels == 1 and reconstruction_channels != 1:
source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
if reconstruction_channels == 1 and source_channels != 1:
reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
if method == 'mpse':
reconstruction_similarity_loss_fn = (
tf.contrib.losses.mean_pairwise_squared_error)
elif method == 'masked_mpse':
def masked_mpse(predictions, labels, weight):
"""Masked mpse assuming we have a depth to create a mask from."""
assert labels.shape.as_list()[-1] == 4
mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
mask = tf.tile(mask, [1, 1, 1, 4])
predictions *= mask
labels *= mask
tf.summary.image('masked_pred', predictions)
tf.summary.image('masked_label', labels)
return tf.contrib.losses.mean_pairwise_squared_error(
predictions, labels, weight)
reconstruction_similarity_loss_fn = masked_mpse
elif method == 'mse':
reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
elif method == 'hinged_mse':
def hinged_mse(predictions, labels, weight):
diffs = tf.square(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mse
elif method == 'hinged_mae':
def hinged_mae(predictions, labels, weight):
diffs = tf.abs(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mae
else:
raise ValueError('Unknown reconstruction loss %s' % method)
reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
reconstructions, source_images, weight)
name = '%s_Similarity_(%s)' % (name, method)
tf.summary.scalar(name, reconstruction_similarity_loss)
return reconstruction_similarity_loss
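# Illustration (a sketch, not part of the original file): with the default
# max_diff=0.4, the hinged variants ignore differences below the threshold
# and penalize only the excess. For squared differences [0.01, 0.25, 0.81]:
#
#   import numpy as np
#   np.maximum(0.0, np.array([0.01, 0.25, 0.81]) - 0.4)  # -> [0.0, 0.0, 0.41]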
def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
"""Configures the loss function which runs during the g-step.
Args:
source_images: A `Tensor` of shape [batch_size, height, width, channels].
source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
are 'class' and 'quaternion'.
end_points: A map of the network end points.
hparams: The hyperparameters struct.
num_classes: Number of classes for classifier loss
Returns:
A `Tensor` representing a loss function.
Raises:
ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
hparams.transferred_similarity_loss is invalid.
"""
generator_loss = 0
################################################################
# Adds a loss which encourages the discriminator probabilities #
# to be high (near one).
################################################################
# As per the GAN paper, maximize the log probs, instead of minimizing
# log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
# the same thing.
style_transfer_loss = tf.losses.sigmoid_cross_entropy(
logits=end_points['transferred_domain_logits'],
multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
weights=hparams.style_transfer_loss_weight)
tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
generator_loss += style_transfer_loss
# Optimizes the style transfer network to produce transferred images similar
# to the source images.
generator_loss += _transferred_similarity_loss(
end_points['transferred_images'],
source_images,
weight=hparams.transferred_similarity_loss_weight,
method=hparams.transferred_similarity_loss,
name='transferred_similarity')
# Optimizes the style transfer network to maximize classification accuracy.
if source_labels is not None and hparams.task_tower_in_g_step:
generator_loss += _add_task_specific_losses(
end_points, source_labels, num_classes,
hparams) * hparams.task_loss_in_g_weight
return generator_loss
def d_step_loss(end_points, source_labels, num_classes, hparams):
"""Configures the losses during the D-Step.
Note that during the D-step, the model optimizes both the domain (binary)
classifier and the task classifier.
Args:
end_points: A map of the network end points.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
Returns:
A `Tensor` representing the value of the D-step loss.
"""
domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)
task_classifier_loss = 0
if source_labels is not None:
task_classifier_loss = _add_task_specific_losses(
end_points, source_labels, num_classes, hparams, add_summaries=True)
return domain_classifier_loss + task_classifier_loss
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions for preprocessing the inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
def preprocess_classification(image, labels, is_training=False):
"""Preprocesses the image and labels for classification purposes.
Preprocessing includes shifting the images to be 0-centered between -1 and 1.
This is not only a popular method of preprocessing (inception) but is also
the mechanism used by DSNs.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
is_training: Whether or not we're training the model.
Returns:
The preprocessed image and labels.
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
image -= 0.5
image *= 2
return image, labels
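# Worked example (a sketch, not part of the original file): a uint8 pixel of
# 255 becomes 1.0 after convert_image_dtype, then (1.0 - 0.5) * 2 = 1.0, while
# a pixel of 0 maps to (0.0 - 0.5) * 2 = -1.0, giving the [-1, 1] output range.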
def preprocess_style_transfer(image,
labels,
augment=False,
size=None,
is_training=False):
"""Preprocesses the image and labels for style transfer purposes.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
augment: Whether to apply data augmentation to inputs
size: The height and width to which images should be resized. If left as
`None`, then no resizing is performed
is_training: Whether or not we're training the model
Returns:
The preprocessed image and labels. Scaled to [-1, 1]
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
if augment and is_training:
image = image_augmentation(image)
if size:
image = resize_image(image, size)
image -= 0.5
image *= 2
return image, labels
def image_augmentation(image):
"""Performs data augmentation by randomly permuting the inputs.
Args:
image: A float `Tensor` of size [height, width, channels] with values
in range[0,1].
Returns:
The augmented image.
"""
# Apply photometric data augmentation (contrast etc.)
num_channels = image.shape.as_list()[-1]
if num_channels == 4:
# Only augment image part
image, depth = image[:, :, 0:3], image[:, :, 3:4]
elif num_channels == 1:
image = tf.image.grayscale_to_rgb(image)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.clip_by_value(image, 0, 1.0)
if num_channels == 4:
image = tf.concat([image, depth], 2)
elif num_channels == 1:
image = tf.image.rgb_to_grayscale(image)
return image
def resize_image(image, size=None):
"""Resize image to target size.
Args:
image: A `Tensor` of size [height, width, 3].
size: (height, width) to resize image to.
Returns:
resized image
"""
if size is None:
raise ValueError('Must specify size')
if image.shape.as_list()[:2] == list(size):
# Don't resize if not necessary
return image
image = tf.expand_dims(image, 0)
image = tf.image.resize_images(image, size)
image = tf.squeeze(image, 0)
return image
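# Example usage (a sketch, not part of the original file; the size and flag
# values are illustrative):
#
#   image, labels = preprocess_style_transfer(
#       image, labels, augment=True, size=(32, 32), is_training=True)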
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for domain_adaptation.pixel_domain_adaptation.pixelda_preprocess."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
class PixelDAPreprocessTest(tf.test.TestCase):
def assert_preprocess_classification_is_centered(self, dtype, is_training):
tf.set_random_seed(0)
if dtype == tf.uint8:
image = tf.random_uniform((100, 200, 3), maxval=255, dtype=tf.int64)
image = tf.cast(image, tf.uint8)
else:
image = tf.random_uniform((100, 200, 3), maxval=1.0, dtype=dtype)
labels = {}
image, labels = pixelda_preprocess.preprocess_classification(
image, labels, is_training=is_training)
with self.test_session() as sess:
np_image = sess.run(image)
self.assertTrue(np_image.min() <= -0.95)
self.assertTrue(np_image.min() >= -1.0)
self.assertTrue(np_image.max() >= 0.95)
self.assertTrue(np_image.max() <= 1.0)
def testPreprocessClassificationZeroCentersUint8DuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=True)
def testPreprocessClassificationZeroCentersUint8DuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=False)
def testPreprocessClassificationZeroCentersFloatDuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=True)
def testPreprocessClassificationZeroCentersFloatDuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task towers for PixelDA model."""
import tensorflow as tf
slim = tf.contrib.slim
def add_task_specific_model(images,
hparams,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope=None):
"""Create a classifier for the given images.
The classifier is composed of a few 'private' layers followed by a few
'shared' layers. This lets us account for different image 'style', while
sharing the last few layers as 'content' layers.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
hparams: model hparams
num_classes: The number of output classes.
is_training: whether model is training
reuse_private: Whether or not to reuse the private weights, which are the
first few layers in the classifier
private_scope: The name of the variable_scope for the private (unshared)
components of the classifier.
reuse_shared: Whether or not to reuse the shared weights, which are the last
few layers in the classifier
shared_scope: The name of the variable_scope for the shared components of
the classifier.
Returns:
A tuple (logits, quaternion_pred), where logits is a `Tensor` of shape
[batch_size, num_classes] and quaternion_pred is a quaternion prediction
`Tensor`, or None for the purely classification towers.
Raises:
ValueError: If hparams.task_tower is an unknown value
"""
model = hparams.task_tower
# Make sure the classifier name shows up in graph
shared_scope = shared_scope or (model + '_shared')
kwargs = {
'num_classes': num_classes,
'is_training': is_training,
'reuse_private': reuse_private,
'reuse_shared': reuse_shared,
}
if private_scope:
kwargs['private_scope'] = private_scope
if shared_scope:
kwargs['shared_scope'] = shared_scope
quaternion_pred = None
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay_task_classifier)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if model == 'doubling_pose_estimator':
logits, quaternion_pred = doubling_cnn_class_and_quaternion(
images, num_private_layers=hparams.num_private_layers, **kwargs)
elif model == 'mnist':
logits, _ = mnist_classifier(images, **kwargs)
elif model == 'svhn':
logits, _ = svhn_classifier(images, **kwargs)
elif model == 'gtsrb':
logits, _ = gtsrb_classifier(images, **kwargs)
elif model == 'pose_mini':
logits, quaternion_pred = pose_mini_tower(images, **kwargs)
else:
raise ValueError('Unknown task classifier %s' % model)
return logits, quaternion_pred
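# Example usage (a sketch, not part of the original file; `mnist_images` is a
# hypothetical [batch_size, 28, 28, 1] tensor, and the hparams fields follow
# the defaults in hparams.py):
#
#   hparams = tf.contrib.training.HParams(
#       task_tower='mnist', weight_decay_task_classifier=1e-5,
#       num_private_layers=2)
#   logits, _ = add_task_specific_model(mnist_images, hparams, num_classes=10)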
#####################################
# Classifiers used in the DSN paper #
#####################################
def mnist_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope='mnist',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional MNIST model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = mnist_classifier(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 48, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool2']), 100, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 100, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def svhn_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional SVHN model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = svhn_classifier(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [3, 3], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 64, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [3, 3], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 128, [5, 5], scope='conv3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['conv3']), 3072, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 2048, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def gtsrb_classifier(images,
is_training=False,
num_classes=43,
reuse_private=False,
private_scope='gtsrb',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional GTSRB model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = gtsrb_classifier(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
reuse_private: Whether or not to reuse the private components of the model.
private_scope: The name of the private scope.
reuse_shared: Whether or not to reuse the shared components of the model.
shared_scope: The name of the shared scope.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 144, [3, 3], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 256, [5, 5], scope='conv3')
net['pool3'] = slim.max_pool2d(net['conv3'], [2, 2], 2, scope='pool3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool3']), 512, scope='fc3')
logits = slim.fully_connected(
net['fc3'], num_classes, activation_fn=None, scope='fc4')
return logits, net
#########################
# pose_mini task towers #
#########################
def pose_mini_tower(images,
num_classes=11,
is_training=False,
reuse_private=False,
private_scope='pose_mini',
reuse_shared=False,
shared_scope='task_model'):
"""Task tower for the pose_mini dataset."""
with tf.variable_scope(private_scope, reuse=reuse_private):
net = slim.conv2d(images, 32, [5, 5], scope='conv1')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net = slim.conv2d(net, 64, [5, 5], scope='conv2')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
net = slim.flatten(net)
net = slim.fully_connected(net, 128, scope='fc3')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
with tf.variable_scope('quaternion_prediction'):
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc4')
return logits, quaternion_pred
def doubling_cnn_class_and_quaternion(images,
num_private_layers=1,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope='doubling_cnn',
reuse_shared=False,
shared_scope='task_model'):
"""Alternate conv, pool while doubling filter count."""
net = images
depth = 32
layer_id = 1
with tf.variable_scope(private_scope, reuse=reuse_private):
while num_private_layers > 0 and net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
num_private_layers -= 1
with tf.variable_scope(shared_scope, reuse=reuse_shared):
while net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
net = slim.flatten(net)
net = slim.fully_connected(net, 100, scope='fc1')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc_logits')
return logits, quaternion_pred
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the PixelDA model."""
from functools import partial
import os
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_string('train_log_dir', '/tmp/pixelda/',
'Directory where to write event logs.')
flags.DEFINE_integer(
    'save_summaries_steps', 500,
    'The frequency with which summaries are saved, in steps.')
flags.DEFINE_integer('save_interval_secs', 300,
'The frequency with which the model is saved, in seconds.')
flags.DEFINE_boolean('summarize_gradients', False,
'Whether to summarize model gradients')
flags.DEFINE_integer(
'print_loss_steps', 100,
'The frequency with which the losses are printed, in steps.')
flags.DEFINE_string('source_dataset', 'mnist', 'The name of the source dataset.'
' If hparams="arch=dcgan", this flag is ignored.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string('source_split_name', 'train',
'Name of the train split for the source.')
flags.DEFINE_string('target_split_name', 'train',
'Name of the train split for the target.')
flags.DEFINE_string('dataset_dir', '',
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma-separated hyperparameter values.')
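# A hedged example of the flag's format (the keys shown are ones referenced in
# this file; the values are illustrative):
#   --hparams="arch=dcgan,learning_rate=0.001,batch_size=32"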
def _get_vars_and_update_ops(hparams, scope):
"""Returns the variables and update ops for a particular variable scope.
Args:
hparams: The hyperparameters struct.
scope: The variable scope.
Returns:
A tuple consisting of trainable variables and update ops.
"""
is_trainable = lambda x: x in tf.trainable_variables()
  var_list = list(filter(is_trainable, slim.get_model_variables(scope)))
global_step = slim.get_or_create_global_step()
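  # The UPDATE_OPS collection holds ops that must run alongside the train op,
  # such as batch-norm moving-average updates registered under this scope.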
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
tf.logging.info('All variables for scope: %s',
slim.get_model_variables(scope))
tf.logging.info('Trainable variables for scope: %s', var_list)
return var_list, update_ops
def _train(discriminator_train_op,
generator_train_op,
logdir,
master='',
is_chief=True,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=100,
hparams=None):
"""Runs the training loop.
Args:
discriminator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the discriminator.
generator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the generator.
logdir: The directory where the graph and checkpoints are saved.
master: The URL of the master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
    scaffold: A `tf.train.Scaffold` instance.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
training loop.
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
inside the training loop for the chief trainer only.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
using a default checkpoint saver. If `save_checkpoint_secs` is set to
`None`, then the default checkpoint saver isn't used.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If
`save_summaries_steps` is set to `None`, then the default summary saver
isn't used.
hparams: The hparams struct.
Returns:
    The value of the loss function after training.
  Raises:
    ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
      `save_summaries_steps` is not `None`.
"""
global_step = slim.get_or_create_global_step()
scaffold = scaffold or tf.train.Scaffold()
hooks = hooks or []
if is_chief:
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold, checkpoint_dir=logdir, master=master)
if chief_only_hooks:
hooks.extend(chief_only_hooks)
hooks.append(tf.train.StepCounterHook(output_dir=logdir))
if save_summaries_steps:
if logdir is None:
      raise ValueError(
          'logdir cannot be None when save_summaries_steps is set')
hooks.append(
tf.train.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
output_dir=logdir))
if save_checkpoint_secs:
if logdir is None:
      raise ValueError(
          'logdir cannot be None when save_checkpoint_secs is set')
hooks.append(
tf.train.CheckpointSaverHook(
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
else:
session_creator = tf.train.WorkerSessionCreator(
scaffold=scaffold, master=master)
with tf.train.MonitoredSession(
session_creator=session_creator, hooks=hooks) as session:
loss = None
while not session.should_stop():
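      # Each pass through this loop performs the alternating GAN schedule:
      # several discriminator updates, then several generator updates.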
      # Run the discriminator train op `hparams.discriminator_steps` times.
for _ in range(hparams.discriminator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run(
[discriminator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Discriminator Loss = %.2f', np_global_step,
loss)
      # Run the generator train op `hparams.generator_steps` times.
for _ in range(hparams.generator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run([generator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Generator Loss = %.2f', np_global_step,
loss)
return loss
def run_training(run_dir, checkpoint_dir, hparams):
"""Runs the training loop.
Args:
    run_dir: The directory where training-specific logs are placed.
checkpoint_dir: The directory where the checkpoints and log files are
stored.
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for path in [run_dir, checkpoint_dir]:
if not tf.gfile.Exists(path):
tf.gfile.MakeDirs(path)
# Serialize hparams to log dir
hparams_filename = os.path.join(checkpoint_dir, 'hparams.json')
with tf.gfile.FastGFile(hparams_filename, 'w') as f:
f.write(hparams.to_json())
with tf.Graph().as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
target_images, _ = dataset_factory.provide_batch(
FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
        # The data provider supplies one-hot labels, but we expect
        # categorical (integer) labels.
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
              'Source and target datasets must have the same number of '
              'classes; got %d and %d.' % (num_source_classes,
                                           num_target_classes))
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=True,
num_classes=num_target_classes)
#################################
# Get the variables to optimize #
#################################
generator_vars, generator_update_ops = _get_vars_and_update_ops(
hparams, 'generator')
discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops(
hparams, 'discriminator')
########################
# Configure the losses #
########################
generator_loss = pixelda_losses.g_step_loss(
source_images,
source_labels,
end_points,
hparams,
num_classes=num_target_classes)
discriminator_loss = pixelda_losses.d_step_loss(
end_points, source_labels, num_target_classes, hparams)
###########################
# Create the training ops #
###########################
learning_rate = hparams.learning_rate
if hparams.lr_decay_steps:
learning_rate = tf.train.exponential_decay(
learning_rate,
slim.get_or_create_global_step(),
decay_steps=hparams.lr_decay_steps,
decay_rate=hparams.lr_decay_rate,
staircase=True)
tf.summary.scalar('Learning_rate', learning_rate)
if hparams.discriminator_steps == 0:
discriminator_train_op = tf.no_op()
else:
discriminator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
discriminator_train_op = slim.learning.create_train_op(
discriminator_loss,
discriminator_optimizer,
update_ops=discriminator_update_ops,
variables_to_train=discriminator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
if hparams.generator_steps == 0:
generator_train_op = tf.no_op()
else:
generator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
generator_train_op = slim.learning.create_train_op(
generator_loss,
generator_optimizer,
update_ops=generator_update_ops,
variables_to_train=generator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
#############
# Summaries #
#############
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summaries_color_distributions(end_points['transferred_images'],
'Transferred')
pixelda_utils.summaries_color_distributions(target_images, 'Target')
if source_images is not None:
pixelda_utils.summarize_transferred(source_images,
end_points['transferred_images'])
pixelda_utils.summaries_color_distributions(source_images, 'Source')
pixelda_utils.summaries_color_distributions(
tf.abs(source_images - end_points['transferred_images']),
'Abs(Source_minus_Transferred)')
number_of_steps = None
if hparams.num_training_examples:
        # Control training duration by the number of examples seen, not steps.
        number_of_steps = hparams.num_training_examples // hparams.batch_size
hooks = [tf.train.StepCounterHook(),]
chief_only_hooks = [
tf.train.CheckpointSaverHook(
saver=tf.train.Saver(),
checkpoint_dir=run_dir,
save_secs=FLAGS.save_interval_secs)
]
if number_of_steps:
hooks.append(tf.train.StopAtStepHook(last_step=number_of_steps))
_train(
discriminator_train_op,
generator_train_op,
logdir=run_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=None,
save_summaries_steps=FLAGS.save_summaries_steps,
hparams=hparams)
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_training(
run_dir=FLAGS.train_log_dir,
checkpoint_dir=FLAGS.train_log_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
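# A hedged example invocation (the flag values are illustrative assumptions;
# the flags themselves are defined above):
#
#   python pixelda_train.py \
#     --train_log_dir=/tmp/pixelda/ \
#     --dataset_dir=/tmp/datasets \
#     --source_dataset=mnist \
#     --target_dataset=mnist_m \
#     --hparams="arch=dcgan"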
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for PixelDA model."""
import math
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
def remove_depth(images):
"""Takes a batch of images and remove depth channel if present."""
if images.shape.as_list()[-1] == 4:
return images[:, :, :, 0:3]
return images
def image_grid(images, max_grid_size=4):
"""Given images and N, return first N^2 images as an NxN image grid.
Args:
images: a `Tensor` of size [batch_size, height, width, channels]
max_grid_size: Maximum image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
batch_size = images.shape.as_list()[0]
grid_size = min(int(math.sqrt(batch_size)), max_grid_size)
assert images.shape.as_list()[0] >= grid_size * grid_size
# If we have a depth channel
if images.shape.as_list()[-1] == 4:
images = images[:grid_size * grid_size, :, :, 0:3]
depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
    split = tf.split(images, grid_size, 0)
depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
    depth_split = tf.split(depth, grid_size, 0)
grid = tf.concat(split + depth_split, 1)
return tf.expand_dims(grid, 0)
else:
images = images[:grid_size * grid_size, :, :, :]
images = tf.reshape(
images, [-1, images.shape.as_list()[2],
images.shape.as_list()[3]])
split = tf.split(images, grid_size, 0)
grid = tf.concat(split, 1)
return tf.expand_dims(grid, 0)
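# For example, a [16, 28, 28, 1] batch yields a 4x4 grid of shape
# [1, 112, 112, 1]: the 16 images are stacked vertically, split into 4 strips
# of 4 images each, and the strips are concatenated side by side.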
def source_and_output_image_grid(output_images,
source_images=None,
max_grid_size=4):
"""Create NxN image grid for output, concatenate source grid if given.
Makes grid out of output_images and, if provided, source_images, and
concatenates them.
Args:
output_images: [batch_size, h, w, c] tensor of images
    source_images: Optional [batch_size, h, w, c] tensor of images.
max_grid_size: Image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
output_grid = image_grid(output_images, max_grid_size=max_grid_size)
if source_images is not None:
source_grid = image_grid(source_images, max_grid_size=max_grid_size)
# Make sure they have the same # of channels before concat
# Assumes either 1 or 3 channels
if output_grid.shape.as_list()[-1] != source_grid.shape.as_list()[-1]:
if output_grid.shape.as_list()[-1] == 1:
output_grid = tf.tile(output_grid, [1, 1, 1, 3])
if source_grid.shape.as_list()[-1] == 1:
source_grid = tf.tile(source_grid, [1, 1, 1, 3])
output_grid = tf.concat([output_grid, source_grid], 1)
return output_grid
def summarize_model(end_points):
"""Summarizes the given model via its end_points.
Args:
end_points: A dictionary of end_point names to `Tensor`.
"""
tf.summary.histogram('domain_logits_transferred',
tf.sigmoid(end_points['transferred_domain_logits']))
tf.summary.histogram('domain_logits_target',
tf.sigmoid(end_points['target_domain_logits']))
def summarize_transferred_grid(transferred_images,
source_images=None,
name='Transferred'):
"""Produces a visual grid summarization of the image transferrence.
Args:
transferred_images: A `Tensor` of size [batch_size, height, width, c].
source_images: A `Tensor` of size [batch_size, height, width, c].
    name: Name to use in the summary.
"""
if source_images is not None:
grid = source_and_output_image_grid(transferred_images, source_images)
else:
grid = image_grid(transferred_images)
tf.summary.image('%s_Images_Grid' % name, grid, max_outputs=1)
def summarize_transferred(source_images,
transferred_images,
max_images=20,
name='Transferred'):
"""Produces a visual summary of the image transferrence.
This summary displays the source image, transferred image, and a grayscale
difference image which highlights the differences between input and output.
Args:
source_images: A `Tensor` of size [batch_size, height, width, channels].
transferred_images: A `Tensor` of size [batch_size, height, width, channels]
max_images: The number of images to show.
name: Name to use in summary name
Raises:
ValueError: If number of channels in source and target are incompatible
"""
source_channels = source_images.shape.as_list()[-1]
transferred_channels = transferred_images.shape.as_list()[-1]
if source_channels < transferred_channels:
if source_channels != 1:
raise ValueError(
'Source must be 1 channel or same # of channels as target')
source_images = tf.tile(source_images, [1, 1, 1, transferred_channels])
if transferred_channels < source_channels:
if transferred_channels != 1:
raise ValueError(
'Target must be 1 channel or same # of channels as source')
transferred_images = tf.tile(transferred_images, [1, 1, 1, source_channels])
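  # Build the grayscale difference image: the per-pixel maximum absolute
  # difference across channels, tiled back to the display channel count so it
  # can be concatenated with the color images.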
diffs = tf.abs(source_images - transferred_images)
diffs = tf.reduce_max(diffs, reduction_indices=[3], keep_dims=True)
diffs = tf.tile(diffs, [1, 1, 1, max(source_channels, transferred_channels)])
transition_images = tf.concat([
source_images,
transferred_images,
diffs,
], 2)
tf.summary.image(
'%s_difference' % name, transition_images, max_outputs=max_images)
def summaries_color_distributions(images, name):
"""Produces a histogram of the color distributions of the images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
tf.summary.histogram('color_values/%s' % name, images)
def summarize_images(images, name):
"""Produces a visual summary of the given images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
grid = image_grid(images)
tf.summary.image('%s_Images' % name, grid, max_outputs=1)
Example output:

```
Captions for image COCO_val2014_000000224477.jpg:
  0) a man riding a wave on top of a surfboard . (p=0.040413)
  1) a person riding a surf board on a wave (p=0.017452)
  ...
```