Commit aed6922f authored by David Dohan's avatar David Dohan
Browse files

Open sourcing PixelDA code

parent 2a5f2a95
# Domain Separation Networks
## Introduction
This is the code used for two domain adaptation papers.
The `domain_separation` directory contains code for the "Domain Separation
Networks" paper by Bousmalis K., Trigeorgis G., et al. which was presented at
NIPS 2016. The paper can be found here: https://arxiv.org/abs/1608.06019.
## Introduction
This code is the code used for the "Domain Separation Networks" paper
by Bousmalis K., Trigeorgis G., et al. which was presented at NIPS 2016. The
paper can be found here: https://arxiv.org/abs/1608.06019.
The `pixel_domain_adaptation` directory contains the code used for the
"Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial
Networks" paper by Bousmalis K., et al. (presented at CVPR 2017). The paper can
be found here: https://arxiv.org/abs/1612.05424. PixelDA aims to perform domain
adaptation by transfering the visual style of the target domain (which has few
or no labels) to a source domain (which has many labels). This is accomplished
using a Generative Adversarial Network (GAN).
## Contact
This code was open-sourced by [Konstantinos Bousmalis](https://github.com/bousmalis) (konstantinos@google.com).
The domain separation code was open-sourced
by [Konstantinos Bousmalis](https://github.com/bousmalis)
(konstantinos@google.com), while the pixel level domain adaptation code was
open-sourced by [David Dohan](https://github.com/dmrd) (ddohan@google.com).
## Installation
You will need to have the following installed on your machine before trying out the DSN code.
......@@ -16,26 +26,70 @@ You will need to have the following installed on your machine before trying out
* Bazel: https://bazel.build/
## Important Note
Although we are making the code available, you are only able to use the MNIST
provider for now. We will soon provide a script to download and convert MNIST-M
as well. Check back here in a few weeks or wait for a relevant announcement from
[@bousmalis](https://twitter.com/bousmalis).
We are working to open source the pose estimation dataset. For now, the MNIST to
MNIST-M dataset is available. Check back here in a few weeks or wait for a
relevant announcement from [@bousmalis](https://twitter.com/bousmalis).
## Running the code for adapting MNIST to MNIST-M
In order to run the MNIST to MNIST-M experiments with DANNs and/or DANNs with
domain separation (DSNs) you will need to set the directory you used to download
MNIST and MNIST-M:
## Initial setup
In order to run the MNIST to MNIST-M experiments, you will need to set the
data directory:
```
$ export DSN_DATA_DIR=/your/dir
```
Add models and models/slim to your `$PYTHONPATH`:
Add models and models/slim to your `$PYTHONPATH` (assumes $PWD is /models):
```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
```
## Getting the datasets
You can fetch the MNIST data by running
```
$ bazel run slim:download_and_convert_data -- --dataset_dir $DSN_DATA_DIR --dataset_name=mnist
```
The MNIST-M dataset is available online [here](http://bit.ly/2nrlUAJ). Once it is downloaded and extracted into your data directory, create TFRecord files by running:
```
$ bazel run domain_adaptation/datasets:download_and_convert_mnist_m -- --dataset_dir $DSN_DATA_DIR
```
# Running PixelDA from MNIST to MNIST-M
You can run PixelDA as follows (using Tensorboard to examine the results):
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_train -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m
```
And evaluation as:
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_eval -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m --target_split_name test
```
The MNIST-M results in the paper were run with the following hparams flag:
```
--hparams arch=resnet,domain_loss_weight=0.135603587834,num_training_examples=16000000,style_transfer_loss_weight=0.0113173311334,task_loss_in_g_weight=0.0100959947002,task_tower=mnist,task_tower_in_g_step=true
```
### A note on terminology/language of the code:
The components of the network can be grouped into two parts
which correspond to elements which are jointly optimized: The generator
component and the discriminator component.
The generator component takes either an image or noise vector and produces an
output image.
The discriminator component takes the generated images and the target images
and attempts to discriminate between them.
## Running DSN code for adapting MNIST to MNIST-M
Then you need to build the binaries with Bazel:
```
......
......@@ -26,10 +26,20 @@ py_library(
],
)
py_binary(
name = "download_and_convert_mnist_m",
srcs = ["download_and_convert_mnist_m.py"],
deps = [
"//slim:dataset_utils",
],
)
py_binary(
name = "mnist_m",
srcs = ["mnist_m.py"],
deps = [
"//slim:dataset_utils",
],
)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns image/label pairs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from slim.datasets import mnist
......
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
This module downloads the MNIST-M data, uncompresses it, reads the files
that make up the MNIST-M data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.
The script should take about a minute to run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
# Dependency imports
import numpy as np
from six.moves import urllib
import tensorflow as tf
from slim.datasets import dataset_utils
tf.app.flags.DEFINE_string(
'dataset_dir', None,
'The directory where the output TFRecords and temporary files are saved.')
FLAGS = tf.app.flags.FLAGS
_IMAGE_SIZE = 32
_NUM_CHANNELS = 3
# The number of images in the training set.
_NUM_TRAIN_SAMPLES = 59001
# The number of images to be kept from the training set for the validation set.
_NUM_VALIDATION = 1000
# The number of images in the test set.
_NUM_TEST_SAMPLES = 9001
# Seed for repeatability.
_RANDOM_SEED = 0
# The names of the classes.
_CLASS_NAMES = [
'zero',
'one',
'two',
'three',
'four',
'five',
'size',
'seven',
'eight',
'nine',
]
class ImageReader(object):
"""Helper class that provides TensorFlow image coding utilities."""
def __init__(self):
# Initializes function that decodes RGB PNG data.
self._decode_png_data = tf.placeholder(dtype=tf.string)
self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)
def read_image_dims(self, sess, image_data):
image = self.decode_png(sess, image_data)
return image.shape[0], image.shape[1]
def decode_png(self, sess, image_data):
image = sess.run(
self._decode_png, feed_dict={self._decode_png_data: image_data})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
"""Converts the given filenames to a TFRecord dataset.
Args:
split_name: The name of the dataset, either 'train' or 'valid'.
filenames: A list of absolute paths to png images.
filename_to_class_id: A dictionary from filenames (strings) to class ids
(integers).
dataset_dir: The directory where the converted datasets are stored.
"""
print('Converting the {} split.'.format(split_name))
# Train and validation splits are both in the train directory.
if split_name in ['train', 'valid']:
png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train')
elif split_name == 'test':
png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test')
with tf.Graph().as_default():
image_reader = ImageReader()
with tf.Session('') as sess:
output_filename = _get_output_filename(dataset_dir, split_name)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
for filename in filenames:
# Read the filename:
image_data = tf.gfile.FastGFile(
os.path.join(png_directory, filename), 'r').read()
height, width = image_reader.read_image_dims(sess, image_data)
class_id = filename_to_class_id[filename]
example = dataset_utils.image_to_tfexample(image_data, 'png', height,
width, class_id)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n')
sys.stdout.flush()
def _extract_labels(label_filename):
"""Extract the labels into a dict of filenames to int labels.
Args:
labels_filename: The filename of the MNIST-M labels.
Returns:
A dictionary of filenames to int labels.
"""
print('Extracting labels from: ', label_filename)
label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
label_lines = [line.rstrip('\n').split() for line in label_file]
labels = {}
for line in label_lines:
assert len(line) == 2
labels[line[0]] = int(line[1])
return labels
def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename.
Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split.
Returns:
An absolute file path.
"""
return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)
def _get_filenames(dataset_dir):
"""Returns a list of filenames and inferred class names.
Args:
dataset_dir: A directory containing a set PNG encoded MNIST-M images.
Returns:
A list of image file paths, relative to `dataset_dir`.
"""
photo_filenames = []
for filename in os.listdir(dataset_dir):
photo_filenames.append(filename)
return photo_filenames
def run(dataset_dir):
"""Runs the download and conversion operation.
Args:
dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
train_filename = _get_output_filename(dataset_dir, 'train')
testing_filename = _get_output_filename(dataset_dir, 'test')
if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
print('Dataset files already exist. Exiting without re-creating them.')
return
# TODO(konstantinos): Add download and cleanup functionality
train_validation_filenames = _get_filenames(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train'))
test_filenames = _get_filenames(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test'))
# Divide into train and validation:
random.seed(_RANDOM_SEED)
random.shuffle(train_validation_filenames)
train_filenames = train_validation_filenames[_NUM_VALIDATION:]
validation_filenames = train_validation_filenames[:_NUM_VALIDATION]
train_validation_filenames_to_class_ids = _extract_labels(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train_labels.txt'))
test_filenames_to_class_ids = _extract_labels(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test_labels.txt'))
# Convert the train, validation, and test sets.
_convert_dataset('train', train_filenames,
train_validation_filenames_to_class_ids, dataset_dir)
_convert_dataset('valid', validation_filenames,
train_validation_filenames_to_class_ids, dataset_dir)
_convert_dataset('test', test_filenames, test_filenames_to_class_ids,
dataset_dir)
# Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
print('\nFinished converting the MNIST-M dataset!')
def main(_):
assert FLAGS.dataset_dir
run(FLAGS.dataset_dir)
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the MNIST-M dataset.
The dataset scripts used to create the dataset can be found at:
tensorflow_models/domain_adaptation_/datasets/download_and_convert_mnist_m_dataset.py
"""
from __future__ import absolute_import
......@@ -20,6 +23,7 @@ from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import tensorflow as tf
from slim.datasets import dataset_utils
......
# Description:
# Contains code for domain-adaptation style transfer.
package(
default_visibility = [
":internal",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//domain_adaptation/...",
],
)
py_library(
name = "pixelda_preprocess",
srcs = ["pixelda_preprocess.py"],
deps = [
],
)
py_test(
name = "pixelda_preprocess_test",
srcs = ["pixelda_preprocess_test.py"],
deps = [
":pixelda_preprocess",
],
)
py_library(
name = "pixelda_model",
srcs = [
"pixelda_model.py",
"pixelda_task_towers.py",
"hparams.py",
],
deps = [
],
)
py_library(
name = "pixelda_utils",
srcs = ["pixelda_utils.py"],
deps = [
],
)
py_library(
name = "pixelda_losses",
srcs = ["pixelda_losses.py"],
deps = [
],
)
py_binary(
name = "pixelda_train",
srcs = ["pixelda_train.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
py_binary(
name = "pixelda_eval",
srcs = ["pixelda_eval.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
licenses(["notice"]) # Apache 2.0
py_binary(
name = "baseline_train",
srcs = ["baseline_train.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
py_binary(
name = "baseline_eval",
srcs = ["baseline_eval.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
The best baselines are obtainable via the following configuration:
## MNIST => MNIST_M
Accuracy:
MNIST-Train: 99.9
MNIST_M-Train: 63.9
MNIST_M-Valid: 63.9
MNIST_M-Test: 63.6
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 105,000
## MNIST => USPS
Accuracy:
MNIST-Train: 100.0
USPS-Train: 82.8
USPS-Valid: 82.8
USPS-Test: 78.9
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 22,000
## MNIST_M => MNIST
Accuracy:
MNIST_M-Train: 100
MNIST-Train: 98.5
MNIST-Valid: 98.5
MNIST-Test: 98.1
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 604,400
## MNIST_M => MNIST_M
Accuracy:
MNIST_M-Train: 100.0
MNIST_M-Valid: 96.6
MNIST_M-Test: 96.4
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 139,400
## USPS => USPS
Accuracy:
USPS-Train: 100.0
USPS-Valid: 100.0
USPS-Test: 96.5
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 67,000
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evals the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_string(
'checkpoint_dir', None, 'The location of the checkpoint files.')
flags.DEFINE_string(
'eval_dir', None, 'The directory where evaluation logs are written.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = 0.0
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
if not tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.MakeDirs(FLAGS.eval_dir)
with tf.Graph().as_default():
dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.split_name,
FLAGS.dataset_dir)
num_classes = dataset.num_classes
num_samples = dataset.num_samples
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=False)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
#####################
# Define the losses #
#####################
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
predictions = tf.reshape(tf.argmax(logits, 1), shape=[-1])
class_labels = tf.argmax(labels['classes'], 1)
metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
'Mean_Loss':
tf.contrib.metrics.streaming_mean(total_loss),
'Accuracy':
tf.contrib.metrics.streaming_accuracy(predictions,
tf.reshape(
class_labels,
shape=[-1])),
'Recall_at_5':
tf.contrib.metrics.streaming_recall_at_k(logits, class_labels, 5),
})
tf.summary.histogram('outputs/Predictions', predictions)
tf.summary.histogram('outputs/Ground_Truth', class_labels)
for name, value in metrics_to_values.iteritems():
tf.summary.scalar(name, value)
num_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_dir,
num_evals=num_batches,
eval_op=metrics_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('task', 0, 'The task ID.')
flags.DEFINE_integer('num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then '
'the parameters are handled locally by the worker.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_float('learning_rate', 0.001, 'The initial learning rate.')
flags.DEFINE_integer(
'learning_rate_decay_steps', 20000,
'The frequency, in steps, at which the learning rate is decayed.')
flags.DEFINE_float('learning_rate_decay_factor',
0.95,
'The factor with which the learning rate is decayed.')
flags.DEFINE_float('adam_beta1', 0.5, 'The beta1 value for the AdamOptimizer')
flags.DEFINE_float('weight_decay', 1e-5,
'The L2 coefficient on the model weights.')
flags.DEFINE_string(
'logdir', None, 'The location of the logs and checkpoints.')
flags.DEFINE_integer('save_interval_secs', 600,
'How often, in seconds, we save the model to disk.')
flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_float(
'moving_average_decay', 0.9999,
'The amount of decay to use for moving averages.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = FLAGS.weight_decay
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
with tf.Graph().as_default():
with tf.device(
tf.train.replica_device_setter(FLAGS.num_ps_tasks, merge_devices=True)):
dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
FLAGS.split_name, FLAGS.dataset_dir)
num_classes = dataset.num_classes
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=True)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# preprocess_fn=preprocess_fn)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
# Define the losses
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
tf.summary.scalar('losses/Total_Loss', total_loss)
# Setup the moving averages
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, slim.get_or_create_global_step())
tf.add_to_collection(
tf.GraphKeys.UPDATE_OPS,
variable_averages.apply(moving_average_variables))
# Specify the optimization scheme:
learning_rate = tf.train.exponential_decay(
FLAGS.learning_rate,
slim.get_or_create_global_step(),
FLAGS.learning_rate_decay_steps,
FLAGS.learning_rate_decay_factor,
staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.adam_beta1)
train_op = slim.learning.create_train_op(total_loss, optimizer)
slim.learning.train(
train_op,
FLAGS.logdir,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define model HParams."""
import tensorflow as tf
def create_hparams(hparam_string=None):
"""Create model hyperparameters. Parse nondefault from given string."""
hparams = tf.contrib.training.HParams(
# The name of the architecture to use.
arch='resnet',
lrelu_leakiness=0.2,
batch_norm_decay=0.9,
weight_decay=1e-5,
normal_init_std=0.02,
generator_kernel_size=3,
discriminator_kernel_size=3,
# Stop training after this many examples are processed
# If none, train indefinitely
num_training_examples=0,
# Apply data augmentation to datasets
# Applies only in training job
augment_source_images=False,
augment_target_images=False,
# Discriminator
# Number of filters in first layer of discriminator
num_discriminator_filters=64,
discriminator_conv_block_size=1, # How many convs to have at each size
discriminator_filter_factor=2.0, # Multiply # filters by this each layer
# Add gaussian noise with this stddev to every hidden layer of D
discriminator_noise_stddev=0.2, # lmetz: Start seeing results at >= 0.1
# If true, add this gaussian noise to input images to D as well
discriminator_image_noise=False,
discriminator_first_stride=1, # Stride in first conv of discriminator
discriminator_do_pooling=False, # If true, replace stride 2 with avg pool
discriminator_dropout_keep_prob=0.9, # keep probability for dropout
# DCGAN Generator
# Number of filters in generator decoder last layer (repeatedly halved
# from 1st layer)
num_decoder_filters=64,
# Number of filters in generator encoder 1st layer (repeatedly doubled
# after 1st layer)
num_encoder_filters=64,
# This is the shape to which the noise vector is projected (if we're
# transferring from noise).
# Write this way instead of [4, 4, 64] for hparam search flexibility
projection_shape_size=4,
projection_shape_channels=64,
# Indicates the method by which we enlarge the spatial representation
# of an image. Possible values include:
# - resize_conv: Performs a nearest neighbor resize followed by a conv.
# - conv2d_transpose: Performs a conv2d_transpose.
upsample_method='resize_conv',
# Visualization
summary_steps=500, # Output image summary every N steps
###################################
# Task Classifier Hyperparameters #
###################################
# Which task-specific prediction tower to use. Possible choices are:
# none: No task tower.
# doubling_pose_estimator: classifier + quaternion regressor.
# [conv + pool]* + FC
# Classifiers used in DSN paper:
# gtsrb: Classifier used for GTSRB
# svhn: Classifier used for SVHN
# mnist: Classifier used for MNIST
# pose_mini: Classifier + regressor used for pose_mini
task_tower='doubling_pose_estimator',
weight_decay_task_classifier=1e-5,
source_task_loss_weight=1.0,
transferred_task_loss_weight=1.0,
# Number of private layers in doubling_pose_estimator task tower
num_private_layers=2,
# The weight for the log quaternion loss we use for source and transferred
# samples of the cropped_linemod dataset.
# In the DSN work, 1/8 of the classifier weight worked well for our log
# quaternion loss
source_pose_weight=0.125 * 2.0,
transferred_pose_weight=0.125 * 1.0,
# If set to True, the style transfer network also attempts to change its
# weights to maximize the performance of the task tower. If set to False,
# then the style transfer network only attempts to change its weights to
# make the transferred images more likely according to the domain
# classifier.
task_tower_in_g_step=True,
task_loss_in_g_weight=1.0, # Weight of task loss in G
#########################################
# 'simple` generator arch model hparams #
#########################################
simple_num_conv_layers=1,
simple_conv_filters=8,
#########################
# Resnet Hyperparameters#
#########################
resnet_blocks=6, # Number of resnet blocks
resnet_filters=64, # Number of filters per conv in resnet blocks
# If true, add original input back to result of convolutions inside the
# resnet arch. If false, it turns into a simple stack of conv/relu/BN
# layers.
resnet_residuals=True,
#######################################
# The residual / interpretable model. #
#######################################
res_int_blocks=2, # The number of residual blocks.
res_int_convs=2, # The number of conv calls inside each block.
res_int_filters=64, # The number of filters used by each convolution.
####################
# Latent variables #
####################
# if true, then generate random noise and project to input for generator
noise_channel=True,
# The number of dimensions in the input noise vector.
noise_dims=10,
# If true, then one hot encode source image class and project as an
# additional channel for the input to generator. This gives the generator
# access to the class, which may help generation performance.
condition_on_source_class=False,
########################
# Loss Hyperparameters #
########################
domain_loss_weight=1.0,
style_transfer_loss_weight=1.0,
########################################################################
# Encourages the transferred images to be similar to the source images #
# using a configurable metric. #
########################################################################
# The weight of the loss function encouraging the source and transferred
# images to be similar. If set to 0, then the loss function is not used.
transferred_similarity_loss_weight=0.0,
# The type of loss used to encourage transferred and source image
# similarity. Valid values include:
# mpse: Mean Pairwise Squared Error
# mse: Mean Squared Error
# hinged_mse: Computes the mean squared error using squared differences
# greater than hparams.transferred_similarity_max_diff
# hinged_mae: Computes the mean absolute error using absolute
# differences greater than hparams.transferred_similarity_max_diff.
transferred_similarity_loss='mpse',
# The maximum allowable difference between the source and target images.
# This value is used, in effect, to produce a hinge loss. Note that the
# range of values should be between 0 and 1.
transferred_similarity_max_diff=0.4,
################################
# Optimization Hyperparameters #
################################
learning_rate=0.001,
batch_size=32,
lr_decay_steps=20000,
lr_decay_rate=0.95,
# Recomendation from the DCGAN paper:
adam_beta1=0.5,
clip_gradient_norm=5.0,
# The number of times we run the discriminator train_op in a row.
discriminator_steps=1,
# The number of times we run the generator train_op in a row.
generator_steps=1)
if hparam_string:
tf.logging.info('Parsing command line hparams: %s', hparam_string)
hparams.parse(hparam_string)
tf.logging.info('Final parsed hparams: %s', hparams.values())
return hparams
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evaluates the PIXELDA model.
-- Compiles the model for CPU.
$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Compile the model for GPU.
$ bazel build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Runs the training.
$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \
--source_dataset=mnist \
--target_dataset=mnist_m \
--dataset_dir=/tmp/datasets/ \
--alsologtostderr
-- Visualize the results.
$ bash learning/brain/tensorboard/tensorboard.sh \
--port 2222 --logdir=/tmp/pixelda/
"""
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/pixelda/',
'Directory where the results are saved to.')
flags.DEFINE_integer('eval_interval_secs', 60,
'The frequency, in seconds, with which evaluation is run.')
flags.DEFINE_string('target_split_name', 'test',
'The name of the train/test split.')
flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.'
' Defaults to train.')
flags.DEFINE_string('source_dataset', 'mnist',
'The name of the source dataset.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string(
'dataset_dir',
'', # None,
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def run_eval(run_dir, checkpoint_dir, hparams):
"""Runs the eval loop.
Args:
run_dir: The directory where eval specific logs are placed
checkpoint_dir: The directory where the checkpoints are stored
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for checkpoint_path in slim.evaluation.checkpoints_iterator(
checkpoint_dir, FLAGS.eval_interval_secs):
with tf.Graph().as_default():
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name=FLAGS.target_split_name,
dataset_dir=FLAGS.dataset_dir)
target_images, target_labels = dataset_factory.provide_batch(
FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
target_labels['class'] = tf.argmax(target_labels['classes'], 1)
del target_labels['classes']
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name=FLAGS.source_split_name,
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Input and output datasets must have same number of classes')
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=False,
num_classes=num_target_classes)
#######################
# Metrics & Summaries #
#######################
names_to_values, names_to_updates = create_metrics(end_points,
source_labels,
target_labels, hparams)
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summarize_images(target_images, 'Target')
for name, value in names_to_values.iteritems():
tf.summary.scalar(name, value)
# Use the entire split by default
num_examples = target_dataset.num_samples
num_batches = math.ceil(num_examples / float(hparams.batch_size))
global_step = slim.get_or_create_global_step()
result = slim.evaluation.evaluate_once(
master=FLAGS.master,
checkpoint_path=checkpoint_path,
logdir=run_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
final_op=names_to_values)
def to_degrees(log_quaternion_loss):
"""Converts a log quaternion distance to an angle.
Args:
log_quaternion_loss: The log quaternion distance between two
unit quaternions (or a batch of pairs of quaternions).
Returns:
The angle in degrees of the implied angle-axis representation.
"""
return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
def create_metrics(end_points, source_labels, target_labels, hparams):
"""Create metrics for the model.
Args:
end_points: A dictionary of end point name to tensor
source_labels: Labels for source images. batch_size x 1
target_labels: Labels for target images. batch_size x 1
hparams: The hyperparameters struct.
Returns:
Tuple of (names_to_values, names_to_updates), dictionaries that map a metric
name to its value and update op, respectively
"""
###########################################
# Evaluate the Domain Prediction Accuracy #
###########################################
batch_size = hparams.batch_size
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
('eval/Domain_Accuracy-Transferred'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points[
'transferred_domain_logits']))),
tf.zeros(batch_size, dtype=tf.int32)),
('eval/Domain_Accuracy-Target'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points['target_domain_logits']))),
tf.ones(batch_size, dtype=tf.int32))
})
################################
# Evaluate the task classifier #
################################
if 'source_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['source_task_logits'], 1),
source_labels['class'])
if 'transferred_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['transferred_task_logits'], 1),
source_labels['class'])
if 'target_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['target_task_logits'], 1),
target_labels['class'])
##########################################################################
# Pose data-specific losses.
##########################################################################
if 'quaternion' in source_labels.keys():
params = {}
params['use_logging'] = False
params['batch_size'] = batch_size
angle_loss_source = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'source_quaternion'], source_labels['quaternion'], params))
angle_loss_transferred = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'transferred_quaternion'], source_labels['quaternion'], params))
angle_loss_target = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'target_quaternion'], target_labels['quaternion'], params))
metric_name = 'eval/Angle_Loss-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_source)
metric_name = 'eval/Angle_Loss-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_transferred)
metric_name = 'eval/Angle_Loss-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_target)
return names_to_values, names_to_updates
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_eval(
run_dir=FLAGS.eval_dir,
checkpoint_dir=FLAGS.checkpoint_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
def add_domain_classifier_losses(end_points, hparams):
"""Adds losses related to the domain-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
hparams: The hyperparameters struct.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
if hparams.domain_loss_weight == 0:
tf.logging.info(
'Domain classifier loss weight is 0, so not creating losses.')
return 0
# The domain prediction loss is minimized with respect to the domain
# classifier features only. Its aim is to predict the domain of the images.
# Note: 1 = 'real image' label, 0 = 'fake image' label
transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
logits=end_points['transferred_domain_logits'])
tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)
target_domain_loss = tf.losses.sigmoid_cross_entropy(
multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
logits=end_points['target_domain_logits'])
tf.summary.scalar('Domain_loss_target', target_domain_loss)
# Compute the total domain loss:
total_domain_loss = transferred_domain_loss + target_domain_loss
total_domain_loss *= hparams.domain_loss_weight
tf.summary.scalar('Domain_loss_total', total_domain_loss)
return total_domain_loss
def log_quaternion_loss_batch(predictions, labels, params):
"""A helper function to compute the error between quaternions.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size [batch_size], denoting the error between the quaternions.
"""
use_logging = params['use_logging']
assertions = []
if use_logging:
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
1e-4)),
['The l2 norm of each prediction quaternion vector should be 1.']))
assertions.append(
tf.Assert(
tf.reduce_all(
tf.less(
tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
['The l2 norm of each label quaternion vector should be 1.']))
with tf.control_dependencies(assertions):
product = tf.multiply(predictions, labels)
internal_dot_products = tf.reduce_sum(product, [1])
if use_logging:
internal_dot_products = tf.Print(internal_dot_products, [
internal_dot_products,
tf.shape(internal_dot_products)
], 'internal_dot_products:')
logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
return logcost
def log_quaternion_loss(predictions, labels, params):
"""A helper function to compute the mean error between batches of quaternions.
The caller is expected to add the loss to the graph.
Args:
predictions: A Tensor of size [batch_size, 4].
labels: A Tensor of size [batch_size, 4].
params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
Returns:
A Tensor of size 1, denoting the mean error between batches of quaternions.
"""
use_logging = params['use_logging']
logcost = log_quaternion_loss_batch(predictions, labels, params)
logcost = tf.reduce_sum(logcost, [0])
batch_size = params['batch_size']
logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
if use_logging:
logcost = tf.Print(
logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
return logcost
def _quaternion_loss(labels, predictions, weight, batch_size, domain,
add_summaries):
"""Creates a Quaternion Loss.
Args:
labels: The true quaternions.
predictions: The predicted quaternions.
weight: A scalar weight.
batch_size: The size of the batches.
domain: The name of the domain from which the labels were taken.
add_summaries: Whether or not to add summaries for the losses.
Returns:
A `Tensor` representing the loss.
"""
assert domain in ['Source', 'Transferred']
params = {'use_logging': False, 'batch_size': batch_size}
loss = weight * log_quaternion_loss(labels, predictions, params)
if add_summaries:
assert_op = tf.Assert(tf.is_finite(loss), [loss])
with tf.control_dependencies([assert_op]):
tf.summary.histogram(
'Log_Quaternion_Loss_%s' % domain, loss, collections='losses')
tf.summary.scalar(
'Task_Quaternion_Loss_%s' % domain, loss, collections='losses')
return loss
def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
add_summaries=False):
"""Adds losses related to the task-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
add_summaries: Whether or not to add the summaries.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
# TODO(ddohan): Make sure the l2 regularization is added to the loss
one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
total_loss = 0
if 'source_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['source_task_logits'],
weights=hparams.source_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Source', loss)
total_loss += loss
if 'transferred_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['transferred_task_logits'],
weights=hparams.transferred_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
total_loss += loss
#########################
# Pose specific losses. #
#########################
if 'quaternion' in source_labels:
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['source_quaternion'],
hparams.source_pose_weight,
hparams.batch_size,
'Source',
add_summaries)
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['transferred_quaternion'],
hparams.transferred_pose_weight,
hparams.batch_size,
'Transferred',
add_summaries)
if add_summaries:
tf.summary.scalar('Task_Loss_Total', total_loss)
return total_loss
def _transferred_similarity_loss(reconstructions,
source_images,
weight=1.0,
method='mse',
max_diff=0.4,
name='similarity'):
"""Computes a loss encouraging similarity between source and transferred.
Args:
reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
source_images: A `Tensor` of shape [batch_size, height, width, channels].
weight: Multiple similarity loss by this weight before returning
method: One of:
mpse = Mean Pairwise Squared Error
mse = Mean Squared Error
hinged_mse = Computes the mean squared error using squared differences
greater than hparams.transferred_similarity_max_diff
hinged_mae = Computes the mean absolute error using absolute
differences greater than hparams.transferred_similarity_max_diff.
max_diff: Maximum unpenalized difference for hinged losses
name: Identifying name to use for creating summaries
Returns:
A `Tensor` representing the transferred similarity loss.
Raises:
ValueError: if `method` is not recognized.
"""
if weight == 0:
return 0
source_channels = source_images.shape.as_list()[-1]
reconstruction_channels = reconstructions.shape.as_list()[-1]
# Convert grayscale source to RGB if target is RGB
if source_channels == 1 and reconstruction_channels != 1:
source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
if reconstruction_channels == 1 and source_channels != 1:
reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
if method == 'mpse':
reconstruction_similarity_loss_fn = (
tf.contrib.losses.mean_pairwise_squared_error)
elif method == 'masked_mpse':
def masked_mpse(predictions, labels, weight):
"""Masked mpse assuming we have a depth to create a mask from."""
assert labels.shape.as_list()[-1] == 4
mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
mask = tf.tile(mask, [1, 1, 1, 4])
predictions *= mask
labels *= mask
tf.image_summary('masked_pred', predictions)
tf.image_summary('masked_label', labels)
return tf.contrib.losses.mean_pairwise_squared_error(
predictions, labels, weight)
reconstruction_similarity_loss_fn = masked_mpse
elif method == 'mse':
reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
elif method == 'hinged_mse':
def hinged_mse(predictions, labels, weight):
diffs = tf.square(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mse
elif method == 'hinged_mae':
def hinged_mae(predictions, labels, weight):
diffs = tf.abs(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mae
else:
raise ValueError('Unknown reconstruction loss %s' % method)
reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
reconstructions, source_images, weight)
name = '%s_Similarity_(%s)' % (name, method)
tf.summary.scalar(name, reconstruction_similarity_loss)
return reconstruction_similarity_loss
def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
"""Configures the loss function which runs during the g-step.
Args:
source_images: A `Tensor` of shape [batch_size, height, width, channels].
source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
are 'class' and 'quaternion'.
end_points: A map of the network end points.
hparams: The hyperparameters struct.
num_classes: Number of classes for classifier loss
Returns:
A `Tensor` representing a loss function.
Raises:
ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
hparams.transferred_similarity_loss is invalid.
"""
generator_loss = 0
################################################################
# Adds a loss which encourages the discriminator probabilities #
# to be high (near one).
################################################################
# As per the GAN paper, maximize the log probs, instead of minimizing
# log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
# the same thing.
style_transfer_loss = tf.losses.sigmoid_cross_entropy(
logits=end_points['transferred_domain_logits'],
multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
weights=hparams.style_transfer_loss_weight)
tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
generator_loss += style_transfer_loss
# Optimizes the style transfer network to produce transferred images similar
# to the source images.
generator_loss += _transferred_similarity_loss(
end_points['transferred_images'],
source_images,
weight=hparams.transferred_similarity_loss_weight,
method=hparams.transferred_similarity_loss,
name='transferred_similarity')
# Optimizes the style transfer network to maximize classification accuracy.
if source_labels is not None and hparams.task_tower_in_g_step:
generator_loss += _add_task_specific_losses(
end_points, source_labels, num_classes,
hparams) * hparams.task_loss_in_g_weight
return generator_loss
def d_step_loss(end_points, source_labels, num_classes, hparams):
"""Configures the losses during the D-Step.
Note that during the D-step, the model optimizes both the domain (binary)
classifier and the task classifier.
Args:
end_points: A map of the network end points.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
Returns:
A `Tensor` representing the value of the D-step loss.
"""
domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)
task_classifier_loss = 0
if source_labels is not None:
task_classifier_loss = _add_task_specific_losses(
end_points, source_labels, num_classes, hparams, add_summaries=True)
return domain_classifier_loss + task_classifier_loss
This diff is collapsed.
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions for preprocessing the inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
def preprocess_classification(image, labels, is_training=False):
"""Preprocesses the image and labels for classification purposes.
Preprocessing includes shifting the images to be 0-centered between -1 and 1.
This is not only a popular method of preprocessing (inception) but is also
the mechanism used by DSNs.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
is_training: Whether or not we're training the model.
Returns:
The preprocessed image and labels.
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
image -= 0.5
image *= 2
return image, labels
def preprocess_style_transfer(image,
labels,
augment=False,
size=None,
is_training=False):
"""Preprocesses the image and labels for style transfer purposes.
Args:
image: A `Tensor` of size [height, width, 3].
labels: A dictionary of labels.
augment: Whether to apply data augmentation to inputs
size: The height and width to which images should be resized. If left as
`None`, then no resizing is performed
is_training: Whether or not we're training the model
Returns:
The preprocessed image and labels. Scaled to [-1, 1]
"""
# If the image is uint8, this will scale it to 0-1.
image = tf.image.convert_image_dtype(image, tf.float32)
if augment and is_training:
image = image_augmentation(image)
if size:
image = resize_image(image, size)
image -= 0.5
image *= 2
return image, labels
def image_augmentation(image):
"""Performs data augmentation by randomly permuting the inputs.
Args:
image: A float `Tensor` of size [height, width, channels] with values
in range[0,1].
Returns:
The mutated batch of images
"""
# Apply photometric data augmentation (contrast etc.)
num_channels = image.shape_as_list()[-1]
if num_channels == 4:
# Only augment image part
image, depth = image[:, :, 0:3], image[:, :, 3:4]
elif num_channels == 1:
image = tf.image.grayscale_to_rgb(image)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.clip_by_value(image, 0, 1.0)
if num_channels == 4:
image = tf.concat(2, [image, depth])
elif num_channels == 1:
image = tf.image.rgb_to_grayscale(image)
return image
def resize_image(image, size=None):
"""Resize image to target size.
Args:
image: A `Tensor` of size [height, width, 3].
size: (height, width) to resize image to.
Returns:
resized image
"""
if size is None:
raise ValueError('Must specify size')
if image.shape_as_list()[:2] == size:
# Don't resize if not necessary
return image
image = tf.expand_dims(image, 0)
image = tf.image.resize_images(image, size)
image = tf.squeeze(image, 0)
return image
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for domain_adaptation.pixel_domain_adaptation.pixelda_preprocess."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
class PixelDAPreprocessTest(tf.test.TestCase):
def assert_preprocess_classification_is_centered(self, dtype, is_training):
tf.set_random_seed(0)
if dtype == tf.uint8:
image = tf.random_uniform((100, 200, 3), maxval=255, dtype=tf.int64)
image = tf.cast(image, tf.uint8)
else:
image = tf.random_uniform((100, 200, 3), maxval=1.0, dtype=dtype)
labels = {}
image, labels = pixelda_preprocess.preprocess_classification(
image, labels, is_training=is_training)
with self.test_session() as sess:
np_image = sess.run(image)
self.assertTrue(np_image.min() <= -0.95)
self.assertTrue(np_image.min() >= -1.0)
self.assertTrue(np_image.max() >= 0.95)
self.assertTrue(np_image.max() <= 1.0)
def testPreprocessClassificationZeroCentersUint8DuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=True)
def testPreprocessClassificationZeroCentersUint8DuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.uint8, is_training=False)
def testPreprocessClassificationZeroCentersFloatDuringTrain(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=True)
def testPreprocessClassificationZeroCentersFloatDuringTest(self):
self.assert_preprocess_classification_is_centered(
tf.float32, is_training=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task towers for PixelDA model."""
import tensorflow as tf
slim = tf.contrib.slim
def add_task_specific_model(images,
hparams,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope=None):
"""Create a classifier for the given images.
The classifier is composed of a few 'private' layers followed by a few
'shared' layers. This lets us account for different image 'style', while
sharing the last few layers as 'content' layers.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
hparams: model hparams
num_classes: The number of output classes.
is_training: whether model is training
reuse_private: Whether or not to reuse the private weights, which are the
first few layers in the classifier
private_scope: The name of the variable_scope for the private (unshared)
components of the classifier.
reuse_shared: Whether or not to reuse the shared weights, which are the last
few layers in the classifier
shared_scope: The name of the variable_scope for the shared components of
the classifier.
Returns:
The logits, a `Tensor` of shape [batch_size, num_classes].
Raises:
ValueError: If hparams.task_classifier is an unknown value
"""
model = hparams.task_tower
# Make sure the classifier name shows up in graph
shared_scope = shared_scope or (model + '_shared')
kwargs = {
'num_classes': num_classes,
'is_training': is_training,
'reuse_private': reuse_private,
'reuse_shared': reuse_shared,
}
if private_scope:
kwargs['private_scope'] = private_scope
if shared_scope:
kwargs['shared_scope'] = shared_scope
quaternion_pred = None
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=tf.contrib.layers.l2_regularizer(
hparams.weight_decay_task_classifier)):
with slim.arg_scope([slim.conv2d], padding='SAME'):
if model == 'doubling_pose_estimator':
logits, quaternion_pred = doubling_cnn_class_and_quaternion(
images, num_private_layers=hparams.num_private_layers, **kwargs)
elif model == 'mnist':
logits, _ = mnist_classifier(images, **kwargs)
elif model == 'svhn':
logits, _ = svhn_classifier(images, **kwargs)
elif model == 'gtsrb':
logits, _ = gtsrb_classifier(images, **kwargs)
elif model == 'pose_mini':
logits, quaternion_pred = pose_mini_tower(images, **kwargs)
else:
raise ValueError('Unknown task classifier %s' % model)
return logits, quaternion_pred
#####################################
# Classifiers used in the DSN paper #
#####################################
def mnist_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope='mnist',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional MNIST model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits, endpoints = conv_mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 48, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool2']), 100, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 100, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def svhn_classifier(images,
is_training=False,
num_classes=10,
reuse_private=False,
private_scope=None,
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional SVHN model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [3, 3], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 64, [5, 5], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [3, 3], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 128, [5, 5], scope='conv3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['conv3']), 3072, scope='fc3')
net['fc4'] = slim.fully_connected(
slim.flatten(net['fc3']), 2048, scope='fc4')
logits = slim.fully_connected(
net['fc4'], num_classes, activation_fn=None, scope='fc5')
return logits, net
def gtsrb_classifier(images,
is_training=False,
num_classes=43,
reuse_private=False,
private_scope='gtsrb',
reuse_shared=False,
shared_scope='task_model'):
"""Creates the convolutional GTSRB model from the gradient reversal paper.
Note that since the output is a set of 'logits', the values fall in the
interval of (-infinity, infinity). Consequently, to convert the outputs to a
probability distribution over the characters, one will need to convert them
using the softmax function:
logits = mnist.Mnist(images, is_training=False)
predictions = tf.nn.softmax(logits)
Args:
images: the SVHN digits, a tensor of size [batch_size, 40, 40, 3].
is_training: specifies whether or not we're currently training the model.
This variable will determine the behaviour of the dropout layer.
num_classes: the number of output classes to use.
reuse_private: Whether or not to reuse the private components of the model.
private_scope: The name of the private scope.
reuse_shared: Whether or not to reuse the shared components of the model.
shared_scope: The name of the shared scope.
Returns:
the output logits, a tensor of size [batch_size, num_classes].
a dictionary with key/values the layer names and tensors.
"""
net = {}
with tf.variable_scope(private_scope, reuse=reuse_private):
net['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
net['pool1'] = slim.max_pool2d(net['conv1'], [2, 2], 2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net['conv2'] = slim.conv2d(net['pool1'], 144, [3, 3], scope='conv2')
net['pool2'] = slim.max_pool2d(net['conv2'], [2, 2], 2, scope='pool2')
net['conv3'] = slim.conv2d(net['pool2'], 256, [5, 5], scope='conv3')
net['pool3'] = slim.max_pool2d(net['conv3'], [2, 2], 2, scope='pool3')
net['fc3'] = slim.fully_connected(
slim.flatten(net['pool3']), 512, scope='fc3')
logits = slim.fully_connected(
net['fc3'], num_classes, activation_fn=None, scope='fc4')
return logits, net
#########################
# pose_mini task towers #
#########################
def pose_mini_tower(images,
num_classes=11,
is_training=False,
reuse_private=False,
private_scope='pose_mini',
reuse_shared=False,
shared_scope='task_model'):
"""Task tower for the pose_mini dataset."""
with tf.variable_scope(private_scope, reuse=reuse_private):
net = slim.conv2d(images, 32, [5, 5], scope='conv1')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
with tf.variable_scope(shared_scope, reuse=reuse_shared):
net = slim.conv2d(net, 64, [5, 5], scope='conv2')
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
net = slim.flatten(net)
net = slim.fully_connected(net, 128, scope='fc3')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
with tf.variable_scope('quaternion_prediction'):
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc4')
return logits, quaternion_pred
def doubling_cnn_class_and_quaternion(images,
num_private_layers=1,
num_classes=10,
is_training=False,
reuse_private=False,
private_scope='doubling_cnn',
reuse_shared=False,
shared_scope='task_model'):
"""Alternate conv, pool while doubling filter count."""
net = images
depth = 32
layer_id = 1
with tf.variable_scope(private_scope, reuse=reuse_private):
while num_private_layers > 0 and net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
num_private_layers -= 1
with tf.variable_scope(shared_scope, reuse=reuse_shared):
while net.shape.as_list()[1] > 5:
net = slim.conv2d(net, depth, [3, 3], scope='conv%s' % layer_id)
net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool%s' % layer_id)
depth *= 2
layer_id += 1
net = slim.flatten(net)
net = slim.fully_connected(net, 100, scope='fc1')
net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout')
quaternion_pred = slim.fully_connected(
net, 4, activation_fn=tf.tanh, scope='fc_q')
quaternion_pred = tf.nn.l2_normalize(quaternion_pred, 1)
logits = slim.fully_connected(
net, num_classes, activation_fn=None, scope='fc_logits')
return logits, quaternion_pred
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the PixelDA model."""
from functools import partial
import os
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_string('train_log_dir', '/tmp/pixelda/',
'Directory where to write event logs.')
flags.DEFINE_integer(
'save_summaries_steps', 500,
'The frequency with which summaries are saved, in seconds.')
flags.DEFINE_integer('save_interval_secs', 300,
'The frequency with which the model is saved, in seconds.')
flags.DEFINE_boolean('summarize_gradients', False,
'Whether to summarize model gradients')
flags.DEFINE_integer(
'print_loss_steps', 100,
'The frequency with which the losses are printed, in steps.')
flags.DEFINE_string('source_dataset', 'mnist', 'The name of the source dataset.'
' If hparams="arch=dcgan", this flag is ignored.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string('source_split_name', 'train',
'Name of the train split for the source.')
flags.DEFINE_string('target_split_name', 'train',
'Name of the train split for the target.')
flags.DEFINE_string('dataset_dir', '',
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def _get_vars_and_update_ops(hparams, scope):
"""Returns the variables and update ops for a particular variable scope.
Args:
hparams: The hyperparameters struct.
scope: The variable scope.
Returns:
A tuple consisting of trainable variables and update ops.
"""
is_trainable = lambda x: x in tf.trainable_variables()
var_list = filter(is_trainable, slim.get_model_variables(scope))
global_step = slim.get_or_create_global_step()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
tf.logging.info('All variables for scope: %s',
slim.get_model_variables(scope))
tf.logging.info('Trainable variables for scope: %s', var_list)
return var_list, update_ops
def _train(discriminator_train_op,
generator_train_op,
logdir,
master='',
is_chief=True,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=100,
hparams=None):
"""Runs the training loop.
Args:
discriminator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the discriminator.
generator_train_op: A `Tensor` that, when executed, will apply the
gradients and return the loss value for the generator.
logdir: The directory where the graph and checkpoints are saved.
master: The URL of the master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
scaffold: An tf.train.Scaffold instance.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
training loop.
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
inside the training loop for the chief trainer only.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
using a default checkpoint saver. If `save_checkpoint_secs` is set to
`None`, then the default checkpoint saver isn't used.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If
`save_summaries_steps` is set to `None`, then the default summary saver
isn't used.
hparams: The hparams struct.
Returns:
the value of the loss function after training.
Raises:
ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
`save_summaries_steps` are `None.
"""
global_step = slim.get_or_create_global_step()
scaffold = scaffold or tf.train.Scaffold()
hooks = hooks or []
if is_chief:
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold, checkpoint_dir=logdir, master=master)
if chief_only_hooks:
hooks.extend(chief_only_hooks)
hooks.append(tf.train.StepCounterHook(output_dir=logdir))
if save_summaries_steps:
if logdir is None:
raise ValueError(
'logdir cannot be None when save_summaries_steps is None')
hooks.append(
tf.train.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
output_dir=logdir))
if save_checkpoint_secs:
if logdir is None:
raise ValueError(
'logdir cannot be None when save_checkpoint_secs is None')
hooks.append(
tf.train.CheckpointSaverHook(
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
else:
session_creator = tf.train.WorkerSessionCreator(
scaffold=scaffold, master=master)
with tf.train.MonitoredSession(
session_creator=session_creator, hooks=hooks) as session:
loss = None
while not session.should_stop():
# Run the domain classifier op X times.
for _ in range(hparams.discriminator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run(
[discriminator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Discriminator Loss = %.2f', np_global_step,
loss)
# Run the generator op X times.
for _ in range(hparams.generator_steps):
if session.should_stop():
return loss
loss, np_global_step = session.run([generator_train_op, global_step])
if np_global_step % FLAGS.print_loss_steps == 0:
tf.logging.info('Step %d: Generator Loss = %.2f', np_global_step,
loss)
return loss
def run_training(run_dir, checkpoint_dir, hparams):
"""Runs the training loop.
Args:
run_dir: The directory where training specific logs are placed
checkpoint_dir: The directory where the checkpoints and log files are
stored.
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for path in [run_dir, checkpoint_dir]:
if not tf.gfile.Exists(path):
tf.gfile.MakeDirs(path)
# Serialize hparams to log dir
hparams_filename = os.path.join(checkpoint_dir, 'hparams.json')
with tf.gfile.FastGFile(hparams_filename, 'w') as f:
f.write(hparams.to_json())
with tf.Graph().as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
target_images, _ = dataset_factory.provide_batch(
FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name='train',
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
hparams.batch_size, FLAGS.num_preprocessing_threads)
# Data provider provides 1 hot labels, but we expect categorical.
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Source and Target datasets must have same number of classes. '
'Are %d and %d' % (num_source_classes, num_target_classes))
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=True,
num_classes=num_target_classes)
#################################
# Get the variables to optimize #
#################################
generator_vars, generator_update_ops = _get_vars_and_update_ops(
hparams, 'generator')
discriminator_vars, discriminator_update_ops = _get_vars_and_update_ops(
hparams, 'discriminator')
########################
# Configure the losses #
########################
generator_loss = pixelda_losses.g_step_loss(
source_images,
source_labels,
end_points,
hparams,
num_classes=num_target_classes)
discriminator_loss = pixelda_losses.d_step_loss(
end_points, source_labels, num_target_classes, hparams)
###########################
# Create the training ops #
###########################
learning_rate = hparams.learning_rate
if hparams.lr_decay_steps:
learning_rate = tf.train.exponential_decay(
learning_rate,
slim.get_or_create_global_step(),
decay_steps=hparams.lr_decay_steps,
decay_rate=hparams.lr_decay_rate,
staircase=True)
tf.summary.scalar('Learning_rate', learning_rate)
if hparams.discriminator_steps == 0:
discriminator_train_op = tf.no_op()
else:
discriminator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
discriminator_train_op = slim.learning.create_train_op(
discriminator_loss,
discriminator_optimizer,
update_ops=discriminator_update_ops,
variables_to_train=discriminator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
if hparams.generator_steps == 0:
generator_train_op = tf.no_op()
else:
generator_optimizer = tf.train.AdamOptimizer(
learning_rate, beta1=hparams.adam_beta1)
generator_train_op = slim.learning.create_train_op(
generator_loss,
generator_optimizer,
update_ops=generator_update_ops,
variables_to_train=generator_vars,
clip_gradient_norm=hparams.clip_gradient_norm,
summarize_gradients=FLAGS.summarize_gradients)
#############
# Summaries #
#############
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summaries_color_distributions(end_points['transferred_images'],
'Transferred')
pixelda_utils.summaries_color_distributions(target_images, 'Target')
if source_images is not None:
pixelda_utils.summarize_transferred(source_images,
end_points['transferred_images'])
pixelda_utils.summaries_color_distributions(source_images, 'Source')
pixelda_utils.summaries_color_distributions(
tf.abs(source_images - end_points['transferred_images']),
'Abs(Source_minus_Transferred)')
number_of_steps = None
if hparams.num_training_examples:
# Want to control by amount of data seen, not # steps
number_of_steps = hparams.num_training_examples / hparams.batch_size
hooks = [tf.train.StepCounterHook(),]
chief_only_hooks = [
tf.train.CheckpointSaverHook(
saver=tf.train.Saver(),
checkpoint_dir=run_dir,
save_secs=FLAGS.save_interval_secs)
]
if number_of_steps:
hooks.append(tf.train.StopAtStepHook(last_step=number_of_steps))
_train(
discriminator_train_op,
generator_train_op,
logdir=run_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=None,
save_summaries_steps=FLAGS.save_summaries_steps,
hparams=hparams)
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_training(
run_dir=FLAGS.train_log_dir,
checkpoint_dir=FLAGS.train_log_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for PixelDA model."""
import math
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
def remove_depth(images):
"""Takes a batch of images and remove depth channel if present."""
if images.shape.as_list()[-1] == 4:
return images[:, :, :, 0:3]
return images
def image_grid(images, max_grid_size=4):
"""Given images and N, return first N^2 images as an NxN image grid.
Args:
images: a `Tensor` of size [batch_size, height, width, channels]
max_grid_size: Maximum image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
images = remove_depth(images)
batch_size = images.shape.as_list()[0]
grid_size = min(int(math.sqrt(batch_size)), max_grid_size)
assert images.shape.as_list()[0] >= grid_size * grid_size
# If we have a depth channel
if images.shape.as_list()[-1] == 4:
images = images[:grid_size * grid_size, :, :, 0:3]
depth = tf.image.grayscale_to_rgb(images[:grid_size * grid_size, :, :, 3:4])
images = tf.reshape(images, [-1, images.shape.as_list()[2], 3])
split = tf.split(0, grid_size, images)
depth = tf.reshape(depth, [-1, images.shape.as_list()[2], 3])
depth_split = tf.split(0, grid_size, depth)
grid = tf.concat(split + depth_split, 1)
return tf.expand_dims(grid, 0)
else:
images = images[:grid_size * grid_size, :, :, :]
images = tf.reshape(
images, [-1, images.shape.as_list()[2],
images.shape.as_list()[3]])
split = tf.split(images, grid_size, 0)
grid = tf.concat(split, 1)
return tf.expand_dims(grid, 0)
def source_and_output_image_grid(output_images,
source_images=None,
max_grid_size=4):
"""Create NxN image grid for output, concatenate source grid if given.
Makes grid out of output_images and, if provided, source_images, and
concatenates them.
Args:
output_images: [batch_size, h, w, c] tensor of images
source_images: optional[batch_size, h, w, c] tensor of images
max_grid_size: Image grid height/width
Returns:
Single image batch, of dim [1, h*n, w*n, c]
"""
output_grid = image_grid(output_images, max_grid_size=max_grid_size)
if source_images is not None:
source_grid = image_grid(source_images, max_grid_size=max_grid_size)
# Make sure they have the same # of channels before concat
# Assumes either 1 or 3 channels
if output_grid.shape.as_list()[-1] != source_grid.shape.as_list()[-1]:
if output_grid.shape.as_list()[-1] == 1:
output_grid = tf.tile(output_grid, [1, 1, 1, 3])
if source_grid.shape.as_list()[-1] == 1:
source_grid = tf.tile(source_grid, [1, 1, 1, 3])
output_grid = tf.concat([output_grid, source_grid], 1)
return output_grid
def summarize_model(end_points):
"""Summarizes the given model via its end_points.
Args:
end_points: A dictionary of end_point names to `Tensor`.
"""
tf.summary.histogram('domain_logits_transferred',
tf.sigmoid(end_points['transferred_domain_logits']))
tf.summary.histogram('domain_logits_target',
tf.sigmoid(end_points['target_domain_logits']))
def summarize_transferred_grid(transferred_images,
source_images=None,
name='Transferred'):
"""Produces a visual grid summarization of the image transferrence.
Args:
transferred_images: A `Tensor` of size [batch_size, height, width, c].
source_images: A `Tensor` of size [batch_size, height, width, c].
name: Name to use in summary name
"""
if source_images is not None:
grid = source_and_output_image_grid(transferred_images, source_images)
else:
grid = image_grid(transferred_images)
tf.summary.image('%s_Images_Grid' % name, grid, max_outputs=1)
def summarize_transferred(source_images,
transferred_images,
max_images=20,
name='Transferred'):
"""Produces a visual summary of the image transferrence.
This summary displays the source image, transferred image, and a grayscale
difference image which highlights the differences between input and output.
Args:
source_images: A `Tensor` of size [batch_size, height, width, channels].
transferred_images: A `Tensor` of size [batch_size, height, width, channels]
max_images: The number of images to show.
name: Name to use in summary name
Raises:
ValueError: If number of channels in source and target are incompatible
"""
source_channels = source_images.shape.as_list()[-1]
transferred_channels = transferred_images.shape.as_list()[-1]
if source_channels < transferred_channels:
if source_channels != 1:
raise ValueError(
'Source must be 1 channel or same # of channels as target')
source_images = tf.tile(source_images, [1, 1, 1, transferred_channels])
if transferred_channels < source_channels:
if transferred_channels != 1:
raise ValueError(
'Target must be 1 channel or same # of channels as source')
transferred_images = tf.tile(transferred_images, [1, 1, 1, source_channels])
diffs = tf.abs(source_images - transferred_images)
diffs = tf.reduce_max(diffs, reduction_indices=[3], keep_dims=True)
diffs = tf.tile(diffs, [1, 1, 1, max(source_channels, transferred_channels)])
transition_images = tf.concat([
source_images,
transferred_images,
diffs,
], 2)
tf.summary.image(
'%s_difference' % name, transition_images, max_outputs=max_images)
def summaries_color_distributions(images, name):
"""Produces a histogram of the color distributions of the images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
tf.summary.histogram('color_values/%s' % name, images)
def summarize_images(images, name):
"""Produces a visual summary of the given images.
Args:
images: A `Tensor` of size [batch_size, height, width, 3].
name: The name of the images being summarized.
"""
grid = image_grid(images)
tf.summary.image('%s_Images' % name, grid, max_outputs=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment