Commit 4dcd5116 authored by Jing Li's avatar Jing Li

Removed deprecated research/gan. Please visit https://github.com/tensorflow/gan.

parent f673f7a8
# TFGAN Examples
[TFGAN](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan) is a lightweight library for training and evaluating Generative
Adversarial Networks (GANs). GANs have been used in a wide range of tasks,
including [image translation](https://arxiv.org/abs/1703.10593), [superresolution](https://arxiv.org/abs/1609.04802), and [data augmentation](https://arxiv.org/abs/1612.07828). This directory contains fully working examples
that demonstrate the ease and flexibility of TFGAN. Each subdirectory contains a
different working example. The sub-sections below describe each of the problems
and include some sample outputs. We've also included a [Jupyter notebook](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb), which
provides a walkthrough of TFGAN.
## Contacts
Maintainers of TFGAN:
* Joel Shor, github: [joel-shor](https://github.com/joel-shor)
## Table of contents
1. [MNIST](#mnist)
1. [MNIST with GANEstimator](#mnist_estimator)
1. [CIFAR10](#cifar10)
1. [Image compression](#compression)
## MNIST
<a id='mnist'></a>
We train a simple generator to produce [MNIST digits](http://yann.lecun.com/exdb/mnist/).
The unconditional case maps noise to MNIST digits. The conditional case maps
noise and digit class to MNIST digits. [InfoGAN](https://arxiv.org/abs/1606.03657) learns, without labels, to
produce digits of a given class and to control digit style. The
network architectures are defined [here](https://github.com/tensorflow/models/tree/master/research/gan/mnist/networks.py).
We use a classifier trained on MNIST digit classification for evaluation.
### Unconditional MNIST
<img src="g3doc/mnist_unconditional_gan.png" title="Unconditional GAN" width="330" />
### Conditional MNIST
<img src="g3doc/mnist_conditional_gan.png" title="Conditional GAN" width="330" />
### InfoGAN MNIST
<img src="g3doc/mnist_infogan.png" title="InfoGAN" width="330" />
## MNIST with GANEstimator
<a id='mnist_estimator'></a>
This setup is exactly the same as in the [unconditional MNIST example](#mnist), but
uses the `tf.Learn` `GANEstimator`.
<img src="g3doc/mnist_estimator_unconditional_gan.png" title="Unconditional GAN" width="330" />
## CIFAR10
<a id='cifar10'></a>
We train a [DCGAN generator](https://arxiv.org/abs/1511.06434) to produce [CIFAR10 images](https://www.cs.toronto.edu/~kriz/cifar.html).
The unconditional case maps noise to CIFAR10 images. The conditional case maps
noise and image class to CIFAR10 images. The
network architectures are defined [here](https://github.com/tensorflow/models/tree/master/research/gan/cifar/networks.py).
We use the [Inception Score](https://arxiv.org/abs/1606.03498) to evaluate the images.
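The Inception Score is the exponentiated average KL divergence between each image's conditional class distribution p(y|x), produced by a pretrained Inception classifier, and the marginal class distribution p(y). A minimal NumPy sketch (the function name and toy inputs here are illustrative, and assume you already have softmax outputs from a classifier):

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """exp(E_x[KL(p(y|x) || p(y))]) computed from classifier softmax outputs.

    probs: array of shape [num_images, num_classes] whose rows sum to 1.
    """
    marginal = probs.mean(axis=0)  # p(y), averaged over generated images.
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))

# Identical, confident predictions (mode collapse) give the minimum score 1.0,
same = np.tile([0.97, 0.01, 0.01, 0.01], (8, 1))
# while confident *and* diverse predictions approach num_classes.
diverse = np.full((8, 4), 0.01)
diverse[np.arange(8), np.arange(8) % 4] = 0.97
```

Higher is better: the score rewards samples that are individually recognizable (peaked p(y|x)) yet collectively diverse (broad p(y)).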
### Unconditional CIFAR10
<img src="g3doc/cifar_unconditional_gan.png" title="Unconditional GAN" width="330" />
### Conditional CIFAR10
<img src="g3doc/cifar_conditional_gan.png" title="Conditional GAN" width="330" />
## Image compression
<a id='compression'></a>
In neural image compression, we attempt to reduce an image to a smaller representation
such that we can recreate the original image as closely as possible. See [`Full Resolution Image Compression with Recurrent Neural Networks`](https://arxiv.org/abs/1608.05148) for more details on using neural networks
for image compression.
In this example, we train an encoder to compress images to a binary
representation and a decoder to map the binary representation back to the image.
We treat both systems together (encoder -> decoder) as the generator.
A typical image compression network trained with an L1 pixel loss decodes into
blurry images. We use an adversarial loss to force the outputs to be more
plausible.
This example also highlights the following infrastructure challenge:
* Keeping track of variables when parts of the model are built by custom code.
Some other notes on the problem:
* Since the network is fully convolutional, we train on image patches.
* The bottleneck layer is floating point during training and binarized during
evaluation.
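The train/eval split of the bottleneck can be sketched in NumPy. During training the code stays floating point but receives uniform noise in [-0.5, 0.5) as a differentiable stand-in for quantization (the trick from the Ballé et al. paper cited below); at eval time it is hard-binarized. The function name is illustrative, not from this repo:

```python
import numpy as np

def bottleneck(codes, is_training, rng=None):
    """Soft quantization during training, hard binarization during eval.

    codes: float array in [-1, 1] produced by the encoder.
    """
    if is_training:
        if rng is None:
            rng = np.random.default_rng(0)
        # Additive uniform noise imitates the information loss of rounding
        # while keeping the encoder output differentiable end to end.
        return codes + rng.uniform(-0.5, 0.5, size=codes.shape)
    # Eval: exactly one bit per bottleneck element.
    return np.where(codes >= 0, 1.0, -1.0)
```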
### Results
#### No adversarial loss
<img src="g3doc/compression_wf0.png" title="No adversarial loss" width="500" />
#### Adversarial loss
<img src="g3doc/compression_wf10000.png" title="With adversarial loss" width="500" />
### Architectures
#### Compression Network
The compression network is a DCGAN discriminator for the encoder and a DCGAN
generator for the decoder from [`Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks`](https://arxiv.org/abs/1511.06434).
The binarizer adds uniform noise during training then binarizes during eval, as in
[`End-to-end Optimized Image Compression`](https://arxiv.org/abs/1611.01704).
#### Discriminator
The discriminator looks at 70x70 patches, as in
[`Image-to-Image Translation with Conditional Adversarial Networks`](https://arxiv.org/abs/1611.07004).
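The 70x70 figure is the receptive field of the PatchGAN-style discriminator: each output logit sees only a 70x70 input patch. The receptive field of a conv stack can be computed layer by layer; the schedule below is the pix2pix 70x70 discriminator's (five 4x4 convs with strides 2, 2, 2, 1, 1), included here as a worked check rather than code from this repo:

```python
def receptive_field(layers):
    """Receptive field of stacked conv layers given (kernel, stride) pairs."""
    rf, jump = 1, 1  # jump: input-pixel distance between adjacent outputs.
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf

# Kernel/stride schedule of the pix2pix-style 70x70 PatchGAN discriminator.
patchgan_70 = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
```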
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(batch_size, dataset_dir, dataset_name='cifar10',
                 split_name='train', one_hot=True):
  """Provides batches of CIFAR data.

  Args:
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the CIFAR10 data can be found. If `None`,
      use default.
    dataset_name: Name of the dataset.
    split_name: Should be either 'train' or 'test'.
    one_hot: Output one hot vector instead of int32 label.

  Returns:
    images: A `Tensor` of size [batch_size, 32, 32, 3]. Output pixel values are
      in [-1, 1].
    labels: Either (1) one_hot_labels if `one_hot` is `True`
        A `Tensor` of size [batch_size, num_classes], where each row has a
        single element set to one and the rest set to zeros.
      Or (2) labels if `one_hot` is `False`
        A `Tensor` of size [batch_size], holding the labels as integers.
    num_samples: The number of total samples in the dataset.
    num_classes: The number of classes in the dataset.

  Raises:
    ValueError: if the split_name is not either 'train' or 'test'.
  """
  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=(split_name == 'train'))
  [image, label] = provider.get(['image', 'label'])

  # Preprocess the images: map pixel values from [0, 255] to [-1, 1].
  image = (tf.to_float(image) - 128.0) / 128.0

  # Creates a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=32,
      capacity=5 * batch_size)

  labels = tf.reshape(labels, [-1])

  if one_hot:
    labels = tf.one_hot(labels, dataset.num_classes)

  return images, labels, dataset.num_samples, dataset.num_classes
def float_image_to_uint8(image):
  """Convert float image in [-1, 1) to [0, 255] uint8.

  Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.

  Args:
    image: An image tensor. Values should be in [-1, 1).

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  image = (image * 128.0) + 128.0
  return tf.cast(image, tf.uint8)
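For reference, the conversion above behaves like this NumPy equivalent (a sketch for illustration, not part of the repo):

```python
import numpy as np

def float_image_to_uint8_np(image):
    # Same affine map as the TF op: [-1, 1) -> [0, 256), followed by a
    # truncating uint8 cast, so -1 -> 0 and 1 - 1/128 -> 255.
    return ((np.asarray(image) * 128.0) + 128.0).astype(np.uint8)
```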
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):

  def test_cifar10_train_set(self):
    dataset_dir = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cifar/testdata')
    batch_size = 4
    images, labels, num_samples, num_classes = data_provider.provide_data(
        batch_size, dataset_dir)
    self.assertEqual(50000, num_samples)
    self.assertEqual(10, num_classes)
    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out, labels_out = sess.run([images, labels])
        self.assertEqual(images_out.shape, (batch_size, 32, 32, 3))
        expected_label_shape = (batch_size, 10)
        self.assertEqual(expected_label_shape, labels_out.shape)
        # Check pixel range.
        self.assertTrue(np.all(np.abs(images_out) <= 1))


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained CIFAR model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
import data_provider
import networks
import util
FLAGS = flags.FLAGS
tfgan = tf.contrib.gan
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10/',
                    'Directory where the model was written to.')

flags.DEFINE_string('eval_dir', '/tmp/cifar10/',
                    'Directory where the results are saved to.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

flags.DEFINE_integer('num_images_generated', 100,
                     'Number of images to generate at once.')

flags.DEFINE_integer('num_inception_images', 10,
                     'The number of images to run through Inception at once.')

flags.DEFINE_boolean('eval_real_images', False,
                     'If `True`, run Inception network on real images.')

flags.DEFINE_boolean('conditional_eval', False,
                     'If `True`, set up a conditional GAN.')

flags.DEFINE_boolean('eval_frechet_inception_distance', True,
                     'If `True`, compute Frechet Inception distance using real '
                     'images and generated images.')

flags.DEFINE_integer('num_images_per_class', 10,
                     'When a conditional generator is used, this is the number '
                     'of images to display per class.')

flags.DEFINE_integer('max_number_of_evaluations', None,
                     'Number of times to run evaluation. If `None`, run '
                     'forever.')

flags.DEFINE_boolean('write_to_disk', True, 'If `True`, write images to disk.')

flags.DEFINE_integer(
    'inter_op_parallelism_threads', 0,
    'Number of threads to use for inter-op parallelism. If left as the default '
    'value of 0, the system will pick an appropriate number.')

flags.DEFINE_integer(
    'intra_op_parallelism_threads', 0,
    'Number of threads to use for intra-op parallelism. If left as the default '
    'value of 0, the system will pick an appropriate number.')
def main(_, run_eval_loop=True):
  # Fetch and generate images to run through Inception.
  with tf.name_scope('inputs'):
    real_data, num_classes = _get_real_data(
        FLAGS.num_images_generated, FLAGS.dataset_dir)
    generated_data = _get_generated_data(
        FLAGS.num_images_generated, FLAGS.conditional_eval, num_classes)

  # Compute Frechet Inception Distance.
  if FLAGS.eval_frechet_inception_distance:
    fid = util.get_frechet_inception_distance(
        real_data, generated_data, FLAGS.num_images_generated,
        FLAGS.num_inception_images)
    tf.summary.scalar('frechet_inception_distance', fid)

  # Compute normal Inception scores.
  if FLAGS.eval_real_images:
    inc_score = util.get_inception_scores(
        real_data, FLAGS.num_images_generated, FLAGS.num_inception_images)
  else:
    inc_score = util.get_inception_scores(
        generated_data, FLAGS.num_images_generated, FLAGS.num_inception_images)
  tf.summary.scalar('inception_score', inc_score)

  # If conditional, display an image grid of different classes.
  if FLAGS.conditional_eval and not FLAGS.eval_real_images:
    reshaped_imgs = util.get_image_grid(
        generated_data, FLAGS.num_images_generated, num_classes,
        FLAGS.num_images_per_class)
    tf.summary.image('generated_data', reshaped_imgs, max_outputs=1)

  # Create ops that write images to disk.
  image_write_ops = None
  if FLAGS.conditional_eval and FLAGS.write_to_disk:
    reshaped_imgs = util.get_image_grid(
        generated_data, FLAGS.num_images_generated, num_classes,
        FLAGS.num_images_per_class)
    uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
    image_write_ops = tf.write_file(
        '%s/%s' % (FLAGS.eval_dir, 'conditional_cifar10.png'),
        tf.image.encode_png(uint8_images[0]))
  else:
    if FLAGS.num_images_generated >= 100 and FLAGS.write_to_disk:
      reshaped_imgs = tfgan.eval.image_reshaper(
          generated_data[:100], num_cols=FLAGS.num_images_per_class)
      uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
      image_write_ops = tf.write_file(
          '%s/%s' % (FLAGS.eval_dir, 'unconditional_cifar10.png'),
          tf.image.encode_png(uint8_images[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop:
    return

  sess_config = tf.ConfigProto(
      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      master=FLAGS.master,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      config=sess_config,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)


def _get_real_data(num_images_generated, dataset_dir):
  """Gets real images."""
  data, _, _, num_classes = data_provider.provide_data(
      num_images_generated, dataset_dir)
  return data, num_classes


def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Gets generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs, is_training=False)
  return data


if __name__ == '__main__':
  app.run(main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.cifar.eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
import eval # pylint:disable=redefined-builtin
FLAGS = flags.FLAGS
mock = tf.test.mock
class EvalTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('RealData', True, False),
      ('GeneratedData', False, False),
      ('GeneratedDataConditional', False, True))
  def test_build_graph(self, eval_real_images, conditional_eval):
    FLAGS.eval_real_images = eval_real_images
    FLAGS.conditional_eval = conditional_eval
    # Mock `frechet_inception_distance` and `inception_score`, which are
    # expensive.
    with mock.patch.object(
        eval.util, 'get_frechet_inception_distance') as mock_fid:
      with mock.patch.object(eval.util, 'get_inception_scores') as mock_iscore:
        mock_fid.return_value = 1.0
        mock_iscore.return_value = 1.0
        eval.main(None, run_eval_loop=False)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the CIFAR dataset.
# 2. Trains an unconditional or conditional model on the CIFAR training set.
# 3. Evaluates the models and writes sample images to disk.
#
#
# With the default batch size and number of steps, train times are:
#
# Usage:
# cd models/research/gan/cifar
# ./launch_jobs.sh ${gan_type} ${git_repo}
set -e
# Type of GAN to run. Options are `unconditional` or `conditional`.
gan_type=$1
if ! [[ "$gan_type" =~ ^(unconditional|conditional)$ ]]; then
  echo "'gan_type' must be one of: 'unconditional', 'conditional'."
  exit 1
fi

# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
  echo "'git_repo' must not be empty."
  exit 1
fi
# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/cifar-model
# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/cifar-model/eval
# Where the dataset is saved to.
DATASET_DIR=/tmp/cifar-data
export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/gan:$git_repo/research/slim
# A helper function for printing pretty output.
Banner () {
local text=$1
local green='\033[0;32m'
local nc='\033[0m' # No color.
echo -e "${green}${text}${nc}"
}
# Download the dataset.
python "${git_repo}/research/slim/download_and_convert_data.py" \
--dataset_name=cifar10 \
--dataset_dir=${DATASET_DIR}
# Run unconditional GAN.
if [[ "$gan_type" == "unconditional" ]]; then
  UNCONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/unconditional"
  UNCONDITIONAL_EVAL_DIR="${EVAL_DIR}/unconditional"
  NUM_STEPS=10000

  # Run training. Note that train.py takes a `--conditional` boolean flag.
  Banner "Starting training unconditional GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/cifar/train.py" \
    --train_log_dir=${UNCONDITIONAL_TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --max_number_of_steps=${NUM_STEPS} \
    --conditional=false \
    --alsologtostderr
  Banner "Finished training unconditional GAN ${NUM_STEPS} steps."

  # Run evaluation.
  Banner "Starting evaluation of unconditional GAN..."
  python "${git_repo}/research/gan/cifar/eval.py" \
    --checkpoint_dir=${UNCONDITIONAL_TRAIN_DIR} \
    --eval_dir=${UNCONDITIONAL_EVAL_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --eval_real_images=false \
    --conditional_eval=false \
    --max_number_of_evaluations=1
  Banner "Finished unconditional evaluation. See ${UNCONDITIONAL_EVAL_DIR} for output images."
fi
# Run conditional GAN.
if [[ "$gan_type" == "conditional" ]]; then
  CONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/conditional"
  CONDITIONAL_EVAL_DIR="${EVAL_DIR}/conditional"
  NUM_STEPS=10000

  # Run training. Note that train.py takes a `--conditional` boolean flag.
  Banner "Starting training conditional GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/cifar/train.py" \
    --train_log_dir=${CONDITIONAL_TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --max_number_of_steps=${NUM_STEPS} \
    --conditional=true \
    --alsologtostderr
  Banner "Finished training conditional GAN ${NUM_STEPS} steps."

  # Run evaluation.
  Banner "Starting evaluation of conditional GAN..."
  python "${git_repo}/research/gan/cifar/eval.py" \
    --checkpoint_dir=${CONDITIONAL_TRAIN_DIR} \
    --eval_dir=${CONDITIONAL_EVAL_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --eval_real_images=false \
    --conditional_eval=true \
    --max_number_of_evaluations=1
  Banner "Finished conditional evaluation. See ${CONDITIONAL_EVAL_DIR} for output images."
fi
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN CIFAR example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.nets import dcgan
tfgan = tf.contrib.gan
def _last_conv_layer(end_points):
  """Returns the last convolutional layer from an endpoints dictionary."""
  conv_list = [k for k in end_points.keys() if k[:4] == 'conv']
  conv_list.sort()
  return end_points[conv_list[-1]]
def generator(noise, is_training=True):
  """Generator to produce CIFAR images.

  Args:
    noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
      does not use conditioning, this Tensor represents a noise vector of some
      kind that will be reshaped by the generator into CIFAR examples.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  images, _ = dcgan.generator(
      noise, is_training=is_training, fused_batch_norm=True)

  # Make sure output lies in [-1, 1].
  return tf.tanh(images)
def conditional_generator(inputs, is_training=True):
  """Conditional generator to produce CIFAR images.

  Args:
    inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to condition
      the generator.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  noise, one_hot_labels = inputs
  noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)

  images, _ = dcgan.generator(
      noise, is_training=is_training, fused_batch_norm=True)

  # Make sure output lies in [-1, 1].
  return tf.tanh(images)
def discriminator(img, unused_conditioning, is_training=True):
  """Discriminator for CIFAR images.

  Args:
    img: A Tensor of shape [batch size, width, height, channels], that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    unused_conditioning: The TFGAN API can help with conditional GANs, which
      would require extra `condition` information to both the generator and the
      discriminator. Since this example is not conditional, we do not use this
      argument.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  logits, _ = dcgan.discriminator(
      img, is_training=is_training, fused_batch_norm=True)
  return logits


# TODO(joelshor): This discriminator creates variables that aren't used, and
# causes logging warnings. Improve `dcgan` nets to accept a target end layer,
# so extraneous variables aren't created.
def conditional_discriminator(img, conditioning, is_training=True):
  """Conditional discriminator for CIFAR images.

  Args:
    img: A Tensor of shape [batch size, width, height, channels], that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  logits, end_points = dcgan.discriminator(
      img, is_training=is_training, fused_batch_norm=True)

  # Condition the last convolution layer.
  _, one_hot_labels = conditioning
  net = _last_conv_layer(end_points)
  net = tfgan.features.condition_tensor_from_onehot(
      tf.contrib.layers.flatten(net), one_hot_labels)
  logits = tf.contrib.layers.linear(net, 1)

  return logits
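The conditioning op used above adds a learned linear embedding of the labels to the features. Roughly, in NumPy (a sketch only; the function name is illustrative, and `embedding` stands in for the variable that `condition_tensor_from_onehot` learns):

```python
import numpy as np

def condition_from_onehot(features, one_hot_labels, embedding):
    """Add a linear label embedding to the features (shape-preserving).

    features: [batch, d]; one_hot_labels: [batch, num_classes];
    embedding: [num_classes, d] matrix (a trained variable in the real op).
    """
    # Each row selects its class's embedding vector and adds it in.
    return features + one_hot_labels @ embedding
```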
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.cifar.networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import networks
class NetworksTest(tf.test.TestCase):

  def test_generator(self):
    tf.set_random_seed(1234)
    batch_size = 100
    noise = tf.random_normal([batch_size, 64])
    image = networks.generator(noise)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      image_np = image.eval()

    self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
    self.assertTrue(np.all(np.abs(image_np) <= 1))

  def test_generator_conditional(self):
    tf.set_random_seed(1234)
    batch_size = 100
    noise = tf.random_normal([batch_size, 64])
    conditioning = tf.one_hot([0] * batch_size, 10)
    image = networks.conditional_generator((noise, conditioning))
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      image_np = image.eval()

    self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
    self.assertTrue(np.all(np.abs(image_np) <= 1))

  def test_discriminator(self):
    batch_size = 5
    image = tf.random_uniform([batch_size, 32, 32, 3], -1, 1)
    dis_output = networks.discriminator(image, None)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      dis_output_np = dis_output.eval()

    self.assertAllEqual([batch_size, 1], dis_output_np.shape)

  def test_discriminator_conditional(self):
    batch_size = 5
    image = tf.random_uniform([batch_size, 32, 32, 3], -1, 1)
    conditioning = (None, tf.one_hot([0] * batch_size, 10))
    dis_output = networks.conditional_discriminator(image, conditioning)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      dis_output_np = dis_output.eval()

    self.assertAllEqual([batch_size, 1], dis_output_np.shape)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a generator on CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl import logging
import tensorflow as tf
import data_provider
import networks
tfgan = tf.contrib.gan
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/cifar/',
                    'Directory where to write event logs.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

flags.DEFINE_integer('max_number_of_steps', 1000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_boolean(
    'conditional', False,
    'If `True`, set up a conditional GAN. If `False`, it is unconditional.')

# Sync replicas flags.
flags.DEFINE_boolean(
    'use_sync_replicas', True,
    'If `True`, use sync replicas. Otherwise use async.')

flags.DEFINE_integer(
    'worker_replicas', 10,
    'The number of gradients to collect before updating params. Only used '
    'with sync replicas.')

flags.DEFINE_integer(
    'backup_workers', 1,
    'Number of workers to be kept as backup in the sync replicas case.')

flags.DEFINE_integer(
    'inter_op_parallelism_threads', 0,
    'Number of threads to use for inter-op parallelism. If left as the default '
    'value of 0, the system will pick an appropriate number.')

flags.DEFINE_integer(
    'intra_op_parallelism_threads', 0,
    'Number of threads to use for intra-op parallelism. If left as the default '
    'value of 0, the system will pick an appropriate number.')
FLAGS = flags.FLAGS
def main(_):
if not tf.gfile.Exists(FLAGS.train_log_dir):
tf.gfile.MakeDirs(FLAGS.train_log_dir)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
# Force all input processing onto CPU in order to reserve the GPU for
# the forward inference and back-propagation.
with tf.name_scope('inputs'):
with tf.device('/cpu:0'):
images, one_hot_labels, _, _ = data_provider.provide_data(
FLAGS.batch_size, FLAGS.dataset_dir)
# Define the GANModel tuple.
noise = tf.random_normal([FLAGS.batch_size, 64])
if FLAGS.conditional:
generator_fn = networks.conditional_generator
discriminator_fn = networks.conditional_discriminator
generator_inputs = (noise, one_hot_labels)
else:
generator_fn = networks.generator
discriminator_fn = networks.discriminator
generator_inputs = noise
gan_model = tfgan.gan_model(
generator_fn,
discriminator_fn,
real_data=images,
generator_inputs=generator_inputs)
tfgan.eval.add_gan_model_image_summaries(gan_model)
# Get the GANLoss tuple. Use the selected GAN loss functions.
# TODO(joelshor): Put this block in `with tf.name_scope('loss'):` when
# cl/171610946 goes into the opensource release.
gan_loss = tfgan.gan_loss(gan_model,
gradient_penalty_weight=1.0,
add_summaries=True)
# Get the GANTrain ops using the custom optimizers and optional
# discriminator weight clipping.
with tf.name_scope('train'):
gen_lr, dis_lr = _learning_rate()
gen_opt, dis_opt = _optimizer(gen_lr, dis_lr, FLAGS.use_sync_replicas)
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=gen_opt,
discriminator_optimizer=dis_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.summary.scalar('generator_lr', gen_lr)
tf.summary.scalar('discriminator_lr', dis_lr)
# Run the alternating training loop. Skip it if no steps should be taken
# (used for graph construction tests).
sync_hooks = ([gen_opt.make_session_run_hook(FLAGS.task == 0),
dis_opt.make_session_run_hook(FLAGS.task == 0)]
if FLAGS.use_sync_replicas else [])
status_message = tf.string_join(
['Starting train step: ',
tf.as_string(tf.train.get_or_create_global_step())],
name='status_message')
if FLAGS.max_number_of_steps == 0: return
sess_config = tf.ConfigProto(
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
tfgan.gan_train(
train_ops,
hooks=(
[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)] +
sync_hooks),
logdir=FLAGS.train_log_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
config=sess_config)
def _learning_rate():
generator_lr = tf.train.exponential_decay(
learning_rate=0.0001,
global_step=tf.train.get_or_create_global_step(),
decay_steps=100000,
decay_rate=0.9,
staircase=True)
discriminator_lr = 0.001
return generator_lr, discriminator_lr
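The generator schedule above is a staircased exponential decay while the discriminator rate stays fixed. A plain-Python sketch of the staircase math (the `gen_lr_at` helper is illustrative, not part of this codebase):

```python
def gen_lr_at(step, base=0.0001, decay_steps=100000, decay_rate=0.9):
    # Staircase exponential decay: the rate drops by a factor of
    # `decay_rate` once per `decay_steps`, matching
    # tf.train.exponential_decay(..., staircase=True).
    return base * decay_rate ** (step // decay_steps)

print(gen_lr_at(0))        # 0.0001
print(gen_lr_at(250000))   # two full decay periods have elapsed
```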
def _optimizer(gen_lr, dis_lr, use_sync_replicas):
"""Get an optimizer, that's optionally synchronous."""
generator_opt = tf.train.RMSPropOptimizer(gen_lr, decay=.9, momentum=0.1)
discriminator_opt = tf.train.RMSPropOptimizer(dis_lr, decay=.95, momentum=0.1)
def _make_sync(opt):
return tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=FLAGS.worker_replicas-FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
if use_sync_replicas:
generator_opt = _make_sync(generator_opt)
discriminator_opt = _make_sync(discriminator_opt)
return generator_opt, discriminator_opt
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import train
FLAGS = flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('Unconditional', False, False),
('Conditional', True, False),
('SyncReplicas', False, True))
def test_build_graph(self, conditional, use_sync_replicas):
FLAGS.max_number_of_steps = 0
FLAGS.conditional = conditional
FLAGS.use_sync_replicas = use_sync_replicas
FLAGS.batch_size = 16
# Mock input pipeline.
mock_imgs = np.zeros([FLAGS.batch_size, 32, 32, 3], dtype=np.float32)
mock_lbls = np.concatenate(
(np.ones([FLAGS.batch_size, 1], dtype=np.int32),
np.zeros([FLAGS.batch_size, 9], dtype=np.int32)), axis=1)
with mock.patch.object(train, 'data_provider') as mock_data_provider:
mock_data_provider.provide_data.return_value = (
mock_imgs, mock_lbls, None, None)
train.main(None)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for training and evaluating a TFGAN CIFAR example."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
tfgan = tf.contrib.gan
def get_generator_conditioning(batch_size, num_classes):
"""Generates TFGAN conditioning inputs for evaluation.
Args:
batch_size: A Python integer. The desired batch size.
num_classes: A Python integer. The number of classes.
Returns:
A Tensor of one-hot vectors corresponding to an even distribution over
classes.
Raises:
ValueError: If `batch_size` isn't evenly divisible by `num_classes`.
"""
if batch_size % num_classes != 0:
raise ValueError('`batch_size` %i must be evenly divisible by '
'`num_classes` %i.' % (batch_size, num_classes))
labels = [lbl for lbl in xrange(num_classes)
for _ in xrange(batch_size // num_classes)]
return tf.one_hot(tf.constant(labels), num_classes)
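For intuition, the label layout produced above can be mirrored in plain numpy (the `even_class_labels` helper below is illustrative, not part of the library):

```python
import numpy as np

def even_class_labels(batch_size, num_classes):
    # Same layout as get_generator_conditioning: each class appears
    # batch_size // num_classes times, grouped by class.
    if batch_size % num_classes != 0:
        raise ValueError('batch_size must be divisible by num_classes')
    labels = [lbl for lbl in range(num_classes)
              for _ in range(batch_size // num_classes)]
    return np.eye(num_classes, dtype=np.float32)[labels]

cond = even_class_labels(12, 4)
print(cond.shape)  # (12, 4); first three rows are one-hot class 0
```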
def get_image_grid(images, batch_size, num_classes, num_images_per_class):
"""Combines images from each class in a single summary image.
Args:
images: Tensor of images that are arranged by class. The first
`batch_size / num_classes` images belong to the first class, the second
group belongs to the second class, etc. Shape is
[batch, width, height, channels].
batch_size: Python integer. Batch dimension.
num_classes: Number of classes to show.
num_images_per_class: Number of image examples per class to show.
Raises:
ValueError: If the batch dimension of `images` is known at graph
construction time, and it isn't `batch_size`.
ValueError: If there aren't enough images to show
`num_classes * num_images_per_class` images.
ValueError: If `batch_size` isn't divisible by `num_classes`.
Returns:
A single image.
"""
# Validate inputs.
images.shape[0:1].assert_is_compatible_with([batch_size])
if batch_size < num_classes * num_images_per_class:
raise ValueError('Not enough images in batch to show the desired number of '
'images.')
if batch_size % num_classes != 0:
raise ValueError('`batch_size` must be divisible by `num_classes`.')
# Only get a certain number of images per class.
num_batches = batch_size // num_classes
indices = [i * num_batches + j for i in xrange(num_classes)
for j in xrange(num_images_per_class)]
sampled_images = tf.gather(images, indices)
return tfgan.eval.image_reshaper(
sampled_images, num_cols=num_images_per_class)
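The index arithmetic above picks the first `num_images_per_class` examples from each class block; a quick sketch of just that selection (values chosen for illustration):

```python
batch_size, num_classes, num_images_per_class = 12, 4, 2
images_per_class = batch_size // num_classes  # size of each class block
# First `num_images_per_class` offsets inside every class block.
indices = [i * images_per_class + j for i in range(num_classes)
           for j in range(num_images_per_class)]
print(indices)  # [0, 1, 3, 4, 6, 7, 9, 10]
```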
def get_inception_scores(images, batch_size, num_inception_images):
"""Get Inception score for some images.
Args:
images: Image minibatch. Shape [batch size, width, height, channels]. Values
are in [-1, 1].
batch_size: Python integer. Batch dimension.
num_inception_images: Number of images to run through Inception at once.
Returns:
Inception scores. Tensor shape is [batch size].
Raises:
ValueError: If `batch_size` is incompatible with the first dimension of
`images`.
ValueError: If `batch_size` isn't divisible by `num_inception_images`.
"""
# Validate inputs.
images.shape[0:1].assert_is_compatible_with([batch_size])
if batch_size % num_inception_images != 0:
raise ValueError(
'`batch_size` must be divisible by `num_inception_images`.')
# Resize images.
size = 299
resized_images = tf.image.resize_bilinear(images, [size, size])
# Run images through Inception.
num_batches = batch_size // num_inception_images
inc_score = tfgan.eval.inception_score(
resized_images, num_batches=num_batches)
return inc_score
def get_frechet_inception_distance(real_images, generated_images, batch_size,
num_inception_images):
"""Get Frechet Inception Distance between real and generated images.
Args:
real_images: Real images minibatch. Shape [batch size, width, height,
channels]. Values are in [-1, 1].
generated_images: Generated images minibatch. Shape [batch size, width,
height, channels]. Values are in [-1, 1].
batch_size: Python integer. Batch dimension.
num_inception_images: Number of images to run through Inception at once.
Returns:
Frechet Inception distance. A floating-point scalar.
Raises:
ValueError: If the minibatch size is known at graph construction time, and
doesn't match `batch_size`.
"""
# Validate input dimensions.
real_images.shape[0:1].assert_is_compatible_with([batch_size])
generated_images.shape[0:1].assert_is_compatible_with([batch_size])
# Resize input images.
size = 299
resized_real_images = tf.image.resize_bilinear(real_images, [size, size])
resized_generated_images = tf.image.resize_bilinear(
generated_images, [size, size])
# Compute Frechet Inception Distance.
num_batches = batch_size // num_inception_images
fid = tfgan.eval.frechet_inception_distance(
resized_real_images, resized_generated_images, num_batches=num_batches)
return fid
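Both evaluation helpers split the batch into `batch_size // num_inception_images` sub-batches for the Inception pass; that divisibility check can be sketched as follows (the helper name is hypothetical):

```python
def inception_num_batches(batch_size, num_inception_images):
    # Mirrors the validation in get_inception_scores and
    # get_frechet_inception_distance.
    if batch_size % num_inception_images != 0:
        raise ValueError(
            'batch_size must be divisible by num_inception_images')
    return batch_size // num_inception_images

print(inception_num_batches(100, 10))  # 10
```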
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.cifar.util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import util
mock = tf.test.mock
class UtilTest(tf.test.TestCase):
def test_get_generator_conditioning(self):
conditioning = util.get_generator_conditioning(12, 4)
self.assertEqual([12, 4], conditioning.shape.as_list())
def test_get_image_grid(self):
util.get_image_grid(
tf.zeros([6, 28, 28, 1]),
batch_size=6,
num_classes=3,
num_images_per_class=1)
# Mock `inception_score` which is expensive.
@mock.patch.object(util.tfgan.eval, 'inception_score', autospec=True)
def test_get_inception_scores(self, mock_inception_score):
mock_inception_score.return_value = 1.0
util.get_inception_scores(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
# Mock `frechet_inception_distance` which is expensive.
@mock.patch.object(util.tfgan.eval, 'frechet_inception_distance',
autospec=True)
def test_get_frechet_inception_distance(self, mock_fid):
mock_fid.return_value = 1.0
util.get_frechet_inception_distance(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def normalize_image(image):
"""Rescale from range [0, 255] to [-1, 1]."""
return (tf.to_float(image) - 127.5) / 127.5
def undo_normalize_image(normalized_image):
"""Convert to a numpy array that can be read by PIL."""
# Convert from NHWC to HWC.
normalized_image = np.squeeze(normalized_image, axis=0)
return np.uint8(normalized_image * 127.5 + 127.5)
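`normalize_image` and `undo_normalize_image` are near-inverses; the same arithmetic round-trips in plain numpy for uint8 inputs (an illustrative sketch, no TensorFlow required):

```python
import numpy as np

img = np.arange(12, dtype=np.uint8).reshape(1, 2, 2, 3)  # NHWC, batch of 1
normalized = (img.astype(np.float32) - 127.5) / 127.5    # -> [-1, 1]
restored = np.squeeze(normalized, axis=0) * 127.5 + 127.5  # undo step
print(np.allclose(restored, img[0]))  # True
```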
def _sample_patch(image, patch_size):
"""Crop image to square shape and resize it to `patch_size`.
Args:
image: A 3D `Tensor` of HWC format.
patch_size: A Python scalar. The output image size.
Returns:
A 3D `Tensor` of HWC format which has the shape of
[patch_size, patch_size, 3].
"""
image_shape = tf.shape(image)
height, width = image_shape[0], image_shape[1]
target_size = tf.minimum(height, width)
image = tf.image.resize_image_with_crop_or_pad(image, target_size,
target_size)
# tf.image.resize_area only accepts 4D tensor, so expand dims first.
image = tf.expand_dims(image, axis=0)
image = tf.image.resize_images(image, [patch_size, patch_size])
image = tf.squeeze(image, axis=0)
# Force image num_channels = 3
image = tf.tile(image, [1, 1, tf.maximum(1, 4 - tf.shape(image)[2])])
image = tf.slice(image, [0, 0, 0], [patch_size, patch_size, 3])
return image
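The tile-and-slice trick at the end of `_sample_patch` forces three output channels whether the input is grayscale or RGB; in numpy the same arithmetic looks like this (illustrative):

```python
import numpy as np

for channels in (1, 3):
    patch = np.zeros((8, 8, channels), dtype=np.float32)
    reps = max(1, 4 - patch.shape[2])          # 3 for grayscale, 1 for RGB
    forced = np.tile(patch, (1, 1, reps))[:, :, :3]
    print(forced.shape)  # (8, 8, 3) in both cases
```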
def full_image_to_patch(image, patch_size):
image = normalize_image(image)
# Sample a patch of fixed size.
image_patch = _sample_patch(image, patch_size)
image_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
return image_patch
def _provide_custom_dataset(image_file_pattern,
batch_size,
shuffle=True,
num_threads=1,
patch_size=128):
"""Provides batches of custom image data.
Args:
image_file_pattern: A glob pattern string for image files.
batch_size: The number of images in each batch.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of mapping threads. Defaults to 1.
patch_size: Size of the patch to extract from the image. Defaults to 128.
Returns:
A tf.data.Dataset with Tensors of shape
[batch_size, patch_size, patch_size, 3] representing a batch of images.
Raises:
ValueError: If no files match `image_file_pattern`.
"""
if not tf.gfile.Glob(image_file_pattern):
raise ValueError('No files match %s.' % image_file_pattern)
filenames_ds = tf.data.Dataset.list_files(image_file_pattern)
bytes_ds = filenames_ds.map(tf.io.read_file, num_parallel_calls=num_threads)
images_ds = bytes_ds.map(
tf.image.decode_image, num_parallel_calls=num_threads)
patches_ds = images_ds.map(
lambda img: full_image_to_patch(img, patch_size),
num_parallel_calls=num_threads)
patches_ds = patches_ds.repeat()
if shuffle:
patches_ds = patches_ds.shuffle(5 * batch_size)
patches_ds = patches_ds.prefetch(5 * batch_size)
patches_ds = patches_ds.batch(batch_size)
return patches_ds
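The pipeline above maps, repeats, shuffles, prefetches, and only then batches, so every emitted batch is full. A plain-Python analogue of the repeat-then-batch step (illustrative, no tf.data):

```python
import itertools

def batched(stream, batch_size):
    # Group an (infinite) stream into fixed-size batches, like
    # patches_ds.batch(batch_size) applied after .repeat().
    while True:
        yield list(itertools.islice(stream, batch_size))

stream = itertools.cycle(['a', 'b', 'c'])  # .repeat() analogue
print(next(batched(stream, 4)))  # ['a', 'b', 'c', 'a']
```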
def provide_custom_datasets(image_file_patterns,
batch_size,
shuffle=True,
num_threads=1,
patch_size=128):
"""Provides multiple batches of custom image data.
Args:
image_file_patterns: A list of glob patterns of image files.
batch_size: The number of images in each batch.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of mapping threads. Defaults to 1.
patch_size: Size of the patch to extract from the image. Defaults to 128.
Returns:
A list of tf.data.Datasets, one for each entry in `image_file_patterns`.
Each dataset yields float `Tensor`s of shape
[batch_size, patch_size, patch_size, 3] representing a batch of images.
Raises:
ValueError: If image_file_patterns is not a list or tuple.
"""
if not isinstance(image_file_patterns, (list, tuple)):
raise ValueError(
'`image_file_patterns` should be either list or tuple, but was {}.'.
format(type(image_file_patterns)))
custom_datasets = []
for pattern in image_file_patterns:
custom_datasets.append(
_provide_custom_dataset(
pattern,
batch_size=batch_size,
shuffle=shuffle,
num_threads=num_threads,
patch_size=patch_size))
return custom_datasets
def provide_custom_data(image_file_patterns,
batch_size,
shuffle=True,
num_threads=1,
patch_size=128):
"""Provides multiple batches of custom image data.
Args:
image_file_patterns: A list of glob patterns of image files.
batch_size: The number of images in each batch.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of mapping threads. Defaults to 1.
patch_size: Size of the patch to extract from the image. Defaults to 128.
Returns:
A list of float `Tensor`s with the same length as `image_file_patterns`.
Each `Tensor` in the list has a shape of
[batch_size, patch_size, patch_size, 3] representing a batch of images. As a
side effect, the tf.Dataset initializer is added to the
tf.GraphKeys.TABLE_INITIALIZERS collection.
Raises:
ValueError: If image_file_patterns is not a list or tuple.
"""
datasets = provide_custom_datasets(
image_file_patterns, batch_size, shuffle, num_threads, patch_size)
tensors = []
for ds in datasets:
iterator = ds.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
tensors.append(iterator.get_next())
return tensors
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
import data_provider
mock = tf.test.mock
class DataProviderTest(tf.test.TestCase):
def setUp(self):
super(DataProviderTest, self).setUp()
self.testdata_dir = os.path.join(
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
def test_normalize_image(self):
image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32)
rescaled_image = data_provider.normalize_image(image)
self.assertEqual(tf.float32, rescaled_image.dtype)
self.assertListEqual(image.shape.as_list(), rescaled_image.shape.as_list())
with self.test_session(use_gpu=True) as sess:
rescaled_image_out = sess.run(rescaled_image)
self.assertTrue(np.all(np.abs(rescaled_image_out) <= 1.0))
def test_sample_patch(self):
image = tf.zeros(shape=(8, 8, 3))
patch1 = data_provider._sample_patch(image, 7)
patch2 = data_provider._sample_patch(image, 10)
image = tf.zeros(shape=(8, 8, 1))
patch3 = data_provider._sample_patch(image, 10)
with self.test_session(use_gpu=True) as sess:
self.assertTupleEqual((7, 7, 3), sess.run(patch1).shape)
self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape)
self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape)
def test_custom_dataset_provider(self):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images_ds = data_provider._provide_custom_dataset(
file_pattern, batch_size=batch_size, patch_size=patch_size)
self.assertListEqual([None, patch_size, patch_size, 3],
images_ds.output_shapes.as_list())
self.assertEqual(tf.float32, images_ds.output_types)
iterator = images_ds.make_initializable_iterator()
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
sess.run(iterator.initializer)
images_out = sess.run(iterator.get_next())
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_datasets_provider(self):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images_ds_list = data_provider.provide_custom_datasets(
[file_pattern, file_pattern],
batch_size=batch_size,
patch_size=patch_size)
for images_ds in images_ds_list:
self.assertListEqual([None, patch_size, patch_size, 3],
images_ds.output_shapes.as_list())
self.assertEqual(tf.float32, images_ds.output_types)
iterators = [x.make_initializable_iterator() for x in images_ds_list]
initializers = [x.initializer for x in iterators]
img_tensors = [x.get_next() for x in iterators]
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
sess.run(initializers)
images_out_list = sess.run(img_tensors)
for images_out in images_out_list:
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_data_provider(self):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images_list = data_provider.provide_custom_data(
[file_pattern, file_pattern],
batch_size=batch_size,
patch_size=patch_size)
for images in images_list:
self.assertListEqual([None, patch_size, patch_size, 3],
images.shape.as_list())
self.assertEqual(tf.float32, images.dtype)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
images_out_list = sess.run(images_list)
for images_out in images_out_list:
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Demo that makes inference requests against a running inference server."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import numpy as np
import PIL
import tensorflow as tf
import data_provider
import networks
tfgan = tf.contrib.gan
flags.DEFINE_string('checkpoint_path', '',
'CycleGAN checkpoint path created by train.py. '
'(e.g. "/mylogdir/model.ckpt-18442")')
flags.DEFINE_string(
'image_set_x_glob', '',
'Optional: Glob path to images of class X to feed through the CycleGAN.')
flags.DEFINE_string(
'image_set_y_glob', '',
'Optional: Glob path to images of class Y to feed through the CycleGAN.')
flags.DEFINE_string(
'generated_x_dir', '/tmp/generated_x/',
'If image_set_y_glob is defined, where to output the generated X '
'images.')
flags.DEFINE_string(
'generated_y_dir', '/tmp/generated_y/',
'If image_set_x_glob is defined, where to output the generated Y '
'images.')
flags.DEFINE_integer('patch_dim', 128,
'The patch size of images that was used in train.py.')
FLAGS = flags.FLAGS
def _make_dir_if_not_exists(dir_path):
"""Make a directory if it does not exist."""
if not tf.gfile.Exists(dir_path):
tf.gfile.MakeDirs(dir_path)
def _file_output_path(dir_path, input_file_path):
"""Create output path for an individual file."""
return os.path.join(dir_path, os.path.basename(input_file_path))
def make_inference_graph(model_name, patch_dim):
"""Build the inference graph for either the X2Y or Y2X GAN.
Args:
model_name: The var scope name 'ModelX2Y' or 'ModelY2X'.
patch_dim: An integer size of patches to feed to the generator.
Returns:
Tuple of (input_placeholder, generated_tensor).
"""
input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])
# Expand HWC to NHWC
images_x = tf.expand_dims(
data_provider.full_image_to_patch(input_hwc_pl, patch_dim), 0)
with tf.variable_scope(model_name):
with tf.variable_scope('Generator'):
generated = networks.generator(images_x)
return input_hwc_pl, generated
def export(sess, input_pl, output_tensor, input_file_pattern, output_dir):
"""Exports inference outputs to an output directory.
Args:
sess: tf.Session with variables already loaded.
input_pl: tf.Placeholder for input (HWC format).
output_tensor: Tensor for generated output images.
input_file_pattern: Glob file pattern for input images.
output_dir: Output directory.
"""
if output_dir:
_make_dir_if_not_exists(output_dir)
if input_file_pattern:
for file_path in tf.gfile.Glob(input_file_pattern):
# Grab a single image and run it through inference
input_np = np.asarray(PIL.Image.open(file_path))
output_np = sess.run(output_tensor, feed_dict={input_pl: input_np})
image_np = data_provider.undo_normalize_image(output_np)
output_path = _file_output_path(output_dir, file_path)
PIL.Image.fromarray(image_np).save(output_path)
def _validate_flags():
flags.register_validator('checkpoint_path', bool,
'Must provide `checkpoint_path`.')
flags.register_validator(
'generated_x_dir',
lambda x: bool(x) or not FLAGS.image_set_y_glob,
'Must provide `generated_x_dir`.')
flags.register_validator(
'generated_y_dir',
lambda x: bool(x) or not FLAGS.image_set_x_glob,
'Must provide `generated_y_dir`.')
def main(_):
_validate_flags()
images_x_hwc_pl, generated_y = make_inference_graph('ModelX2Y',
FLAGS.patch_dim)
images_y_hwc_pl, generated_x = make_inference_graph('ModelY2X',
FLAGS.patch_dim)
# Restore all the variables that were saved in the checkpoint.
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, FLAGS.checkpoint_path)
export(sess, images_x_hwc_pl, generated_y, FLAGS.image_set_x_glob,
FLAGS.generated_y_dir)
export(sess, images_y_hwc_pl, generated_x, FLAGS.image_set_y_glob,
FLAGS.generated_x_dir)
if __name__ == '__main__':
app.run()
"""Tests for CycleGAN inference demo."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import numpy as np
import PIL
import tensorflow as tf
import inference_demo
import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _basenames_from_glob(file_glob):
return [os.path.basename(file_path) for file_path in tf.gfile.Glob(file_glob)]
class InferenceDemoTest(tf.test.TestCase):
def setUp(self):
self._export_dir = os.path.join(FLAGS.test_tmpdir, 'export')
self._ckpt_path = os.path.join(self._export_dir, 'model.ckpt')
self._image_glob = os.path.join(
FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata', '*.jpg')
self._genx_dir = os.path.join(FLAGS.test_tmpdir, 'genx')
self._geny_dir = os.path.join(FLAGS.test_tmpdir, 'geny')
@mock.patch.object(tfgan, 'gan_train', autospec=True)
@mock.patch.object(
train.data_provider, 'provide_custom_data', autospec=True)
def testTrainingAndInferenceGraphsAreCompatible(
self, mock_provide_custom_data, unused_mock_gan_train):
# Training and inference graphs can get out of sync if changes are made
# to one but not the other. This test will keep them in sync.
# Save the training graph
train_sess = tf.Session()
FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
FLAGS.batch_size = 3
FLAGS.patch_size = 128
FLAGS.generator_lr = 0.02
FLAGS.discriminator_lr = 0.3
FLAGS.train_log_dir = self._export_dir
FLAGS.master = 'master'
FLAGS.task = 0
FLAGS.cycle_consistency_loss_weight = 2.0
FLAGS.max_number_of_steps = 1
mock_provide_custom_data.return_value = (
tf.zeros([3, 4, 4, 3]), tf.zeros([3, 4, 4, 3]))
train.main(None)
init_op = tf.global_variables_initializer()
train_sess.run(init_op)
train_saver = tf.train.Saver()
train_saver.save(train_sess, save_path=self._ckpt_path)
# Create inference graph
tf.reset_default_graph()
FLAGS.patch_dim = FLAGS.patch_size
logging.info('dir_path: %s', os.listdir(self._export_dir))
FLAGS.checkpoint_path = self._ckpt_path
FLAGS.image_set_x_glob = self._image_glob
FLAGS.image_set_y_glob = self._image_glob
FLAGS.generated_x_dir = self._genx_dir
FLAGS.generated_y_dir = self._geny_dir
inference_demo.main(None)
logging.info('gen x: %s', os.listdir(self._genx_dir))
# Check that the image names match
self.assertSetEqual(
set(_basenames_from_glob(FLAGS.image_set_x_glob)),
set(os.listdir(FLAGS.generated_y_dir)))
self.assertSetEqual(
set(_basenames_from_glob(FLAGS.image_set_y_glob)),
set(os.listdir(FLAGS.generated_x_dir)))
# Check that each image in the directory looks as expected
for directory in [FLAGS.generated_x_dir, FLAGS.generated_y_dir]:
for base_name in os.listdir(directory):
image_path = os.path.join(directory, base_name)
self.assertRealisticImage(image_path)
def assertRealisticImage(self, image_path):
logging.info('Testing %s for realism.', image_path)
# If the normalization is off or forgotten, then the generated image is
# all one pixel value. This tests that different pixel values are achieved.
input_np = np.asarray(PIL.Image.open(image_path))
self.assertEqual(len(input_np.shape), 3)
self.assertGreaterEqual(input_np.shape[0], 50)
self.assertGreaterEqual(input_np.shape[1], 50)
self.assertGreater(np.mean(input_np), 20)
self.assertGreater(np.var(input_np), 100)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a CycleGAN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
import data_provider
import networks
tfgan = tf.contrib.gan
flags.DEFINE_string('image_set_x_file_pattern', None,
'File pattern of images in image set X')
flags.DEFINE_string('image_set_y_file_pattern', None,
'File pattern of images in image set Y')
flags.DEFINE_integer('batch_size', 1, 'The number of images in each batch.')
flags.DEFINE_integer('patch_size', 64, 'The patch size of images.')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('train_log_dir', '/tmp/cyclegan/',
'Directory where to write event logs.')
flags.DEFINE_float('generator_lr', 0.0002,
'The generator learning rate.')
flags.DEFINE_float('discriminator_lr', 0.0001,
'The discriminator learning rate.')
flags.DEFINE_integer('max_number_of_steps', 500000,
'The maximum number of gradient steps.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_float('cycle_consistency_loss_weight', 10.0,
                   'The weight of the cycle consistency loss.')
FLAGS = flags.FLAGS
def _define_model(images_x, images_y):
"""Defines a CycleGAN model that maps between images_x and images_y.
Args:
images_x: A 4D float `Tensor` of NHWC format. Images in set X.
images_y: A 4D float `Tensor` of NHWC format. Images in set Y.
Returns:
A `CycleGANModel` namedtuple.
"""
cyclegan_model = tfgan.cyclegan_model(
generator_fn=networks.generator,
discriminator_fn=networks.discriminator,
data_x=images_x,
data_y=images_y)
# Add summaries for generated images.
tfgan.eval.add_cyclegan_image_summaries(cyclegan_model)
return cyclegan_model
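The `CycleGANModel` returned above carries `reconstructed_x` and `reconstructed_y`, the round-trip images that the cycle-consistency loss (weighted by the `cycle_consistency_loss_weight` flag) penalizes. As a rough illustration of that penalty, here is a minimal NumPy sketch of an L1 cycle-consistency term; this is a hand-written approximation for intuition, not TFGAN's exact implementation:

```python
import numpy as np

def cycle_consistency_loss(x, reconstructed_x, y, reconstructed_y, weight=10.0):
    """Illustrative L1 cycle-consistency penalty (a sketch, not TFGAN's code).

    CycleGAN penalizes the mean absolute error between each image batch and
    its round-trip reconstruction, scaled by a weight (10.0 here, matching
    the flag's default above).
    """
    loss_x = np.mean(np.abs(x - reconstructed_x))
    loss_y = np.mean(np.abs(y - reconstructed_y))
    return weight * (loss_x + loss_y)

# Perfect reconstruction of y, off-by-0.5 reconstruction of x.
x = np.zeros((1, 4, 4, 3))
rx = np.full((1, 4, 4, 3), 0.5)
print(cycle_consistency_loss(x, rx, x, x))  # 5.0
```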
def _get_lr(base_lr):
"""Returns a learning rate `Tensor`.
Args:
base_lr: A scalar float `Tensor` or a Python number. The base learning
rate.
Returns:
A scalar float `Tensor` of learning rate which equals `base_lr` when the
global training step is less than FLAGS.max_number_of_steps / 2, afterwards
it linearly decays to zero.
"""
global_step = tf.train.get_or_create_global_step()
lr_constant_steps = FLAGS.max_number_of_steps // 2
def _lr_decay():
return tf.train.polynomial_decay(
learning_rate=base_lr,
global_step=(global_step - lr_constant_steps),
decay_steps=(FLAGS.max_number_of_steps - lr_constant_steps),
end_learning_rate=0.0)
return tf.cond(global_step < lr_constant_steps, lambda: base_lr, _lr_decay)
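The schedule built by `_get_lr` is piecewise: constant at `base_lr` for the first half of training, then (via `tf.train.polynomial_decay` with its default power of 1) a linear decay to zero over the second half. A pure-Python sketch of the same schedule, useful for checking values without building a graph (names here are illustrative, not part of the module):

```python
def piecewise_linear_lr(base_lr, step, max_steps):
    """Pure-Python mirror of the _get_lr schedule (an illustrative sketch).

    Constant at base_lr while step < max_steps // 2, then linearly
    decaying to zero over the remaining steps.
    """
    constant_steps = max_steps // 2
    if step < constant_steps:
        return base_lr
    frac = (step - constant_steps) / (max_steps - constant_steps)
    return base_lr * max(0.0, 1.0 - frac)

# With max_steps=10: constant until step 5, then linear decay to 0.
print(piecewise_linear_lr(0.01, 2, 10))  # 0.01
print(piecewise_linear_lr(0.01, 9, 10))  # 0.002
```

These are the same two checkpoints the `test_get_lr` unit test verifies: the base rate before the midpoint, and one fifth of it at step 9 of 10.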
def _get_optimizer(gen_lr, dis_lr):
"""Returns generator optimizer and discriminator optimizer.
Args:
gen_lr: A scalar float `Tensor` or a Python number. The Generator learning
rate.
dis_lr: A scalar float `Tensor` or a Python number. The Discriminator
learning rate.
Returns:
A tuple of generator optimizer and discriminator optimizer.
"""
# beta1 follows
# https://github.com/junyanz/CycleGAN/blob/master/options.lua
gen_opt = tf.train.AdamOptimizer(gen_lr, beta1=0.5, use_locking=True)
dis_opt = tf.train.AdamOptimizer(dis_lr, beta1=0.5, use_locking=True)
return gen_opt, dis_opt
def _define_train_ops(cyclegan_model, cyclegan_loss):
"""Defines train ops that trains `cyclegan_model` with `cyclegan_loss`.
Args:
cyclegan_model: A `CycleGANModel` namedtuple.
cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
`cyclegan_model`.
Returns:
A `GANTrainOps` namedtuple.
"""
gen_lr = _get_lr(FLAGS.generator_lr)
dis_lr = _get_lr(FLAGS.discriminator_lr)
gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
train_ops = tfgan.gan_train_ops(
cyclegan_model,
cyclegan_loss,
generator_optimizer=gen_opt,
discriminator_optimizer=dis_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.summary.scalar('generator_lr', gen_lr)
tf.summary.scalar('discriminator_lr', dis_lr)
return train_ops
def main(_):
if not tf.gfile.Exists(FLAGS.train_log_dir):
tf.gfile.MakeDirs(FLAGS.train_log_dir)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
with tf.name_scope('inputs'):
images_x, images_y = data_provider.provide_custom_data(
[FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
batch_size=FLAGS.batch_size,
patch_size=FLAGS.patch_size)
# Set batch size for summaries.
images_x.set_shape([FLAGS.batch_size, None, None, None])
images_y.set_shape([FLAGS.batch_size, None, None, None])
# Define CycleGAN model.
cyclegan_model = _define_model(images_x, images_y)
# Define CycleGAN loss.
cyclegan_loss = tfgan.cyclegan_loss(
cyclegan_model,
cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
tensor_pool_fn=tfgan.features.tensor_pool)
# Define CycleGAN train ops.
train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)
# Training
train_steps = tfgan.GANTrainSteps(1, 1)
status_message = tf.string_join(
[
'Starting train step: ',
tf.as_string(tf.train.get_or_create_global_step())
],
name='status_message')
if not FLAGS.max_number_of_steps:
return
tfgan.gan_train(
train_ops,
FLAGS.train_log_dir,
get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
hooks=[
tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)
],
master=FLAGS.master,
is_chief=FLAGS.task == 0)
if __name__ == '__main__':
  flags.mark_flag_as_required('image_set_x_file_pattern')
  flags.mark_flag_as_required('image_set_y_file_pattern')
tf.app.run()
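Putting the flags above together, a training run might be launched as follows. The file patterns and log directory are placeholder examples; only `image_set_x_file_pattern` and `image_set_y_file_pattern` are required, and the other flags fall back to the defaults defined at the top of the file:

```shell
# Hypothetical invocation; substitute your own image sets.
python train.py \
  --image_set_x_file_pattern="/tmp/horses/*.jpg" \
  --image_set_y_file_pattern="/tmp/zebras/*.jpg" \
  --batch_size=1 \
  --patch_size=64 \
  --train_log_dir=/tmp/cyclegan/
```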
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cyclegan.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
import train
FLAGS = flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _test_generator(input_images):
"""Simple generator function."""
return input_images * tf.get_variable('dummy_g', initializer=2.0)
def _test_discriminator(image_batch, unused_conditioning=None):
"""Simple discriminator function."""
return tf.contrib.layers.flatten(
image_batch * tf.get_variable('dummy_d', initializer=2.0))
train.networks.generator = _test_generator
train.networks.discriminator = _test_discriminator
class TrainTest(tf.test.TestCase):
@mock.patch.object(tfgan, 'eval', autospec=True)
def test_define_model(self, mock_eval):
FLAGS.batch_size = 2
images_shape = [FLAGS.batch_size, 4, 4, 3]
images_x_np = np.zeros(shape=images_shape)
images_y_np = np.zeros(shape=images_shape)
images_x = tf.constant(images_x_np, dtype=tf.float32)
images_y = tf.constant(images_y_np, dtype=tf.float32)
cyclegan_model = train._define_model(images_x, images_y)
self.assertIsInstance(cyclegan_model, tfgan.CycleGANModel)
self.assertShapeEqual(images_x_np, cyclegan_model.reconstructed_x)
self.assertShapeEqual(images_y_np, cyclegan_model.reconstructed_y)
@mock.patch.object(train.networks, 'generator', autospec=True)
@mock.patch.object(train.networks, 'discriminator', autospec=True)
@mock.patch.object(
tf.train, 'get_or_create_global_step', autospec=True)
def test_get_lr(self, mock_get_or_create_global_step,
unused_mock_discriminator, unused_mock_generator):
FLAGS.max_number_of_steps = 10
base_lr = 0.01
with self.test_session(use_gpu=True) as sess:
mock_get_or_create_global_step.return_value = tf.constant(2)
lr_step2 = sess.run(train._get_lr(base_lr))
mock_get_or_create_global_step.return_value = tf.constant(9)
lr_step9 = sess.run(train._get_lr(base_lr))
self.assertAlmostEqual(base_lr, lr_step2)
self.assertAlmostEqual(base_lr * 0.2, lr_step9)
@mock.patch.object(tf.train, 'AdamOptimizer', autospec=True)
def test_get_optimizer(self, mock_adam_optimizer):
gen_lr, dis_lr = 0.1, 0.01
train._get_optimizer(gen_lr=gen_lr, dis_lr=dis_lr)
mock_adam_optimizer.assert_has_calls([
mock.call(gen_lr, beta1=mock.ANY, use_locking=True),
mock.call(dis_lr, beta1=mock.ANY, use_locking=True)
])
@mock.patch.object(tf.summary, 'scalar', autospec=True)
def test_define_train_ops(self, mock_summary_scalar):
FLAGS.batch_size = 2
FLAGS.generator_lr = 0.1
FLAGS.discriminator_lr = 0.01
images_shape = [FLAGS.batch_size, 4, 4, 3]
images_x = tf.zeros(images_shape, dtype=tf.float32)
images_y = tf.zeros(images_shape, dtype=tf.float32)
cyclegan_model = train._define_model(images_x, images_y)
cyclegan_loss = tfgan.cyclegan_loss(
cyclegan_model, cycle_consistency_loss_weight=10.0)
train_ops = train._define_train_ops(cyclegan_model, cyclegan_loss)
self.assertIsInstance(train_ops, tfgan.GANTrainOps)
mock_summary_scalar.assert_has_calls([
mock.call('generator_lr', mock.ANY),
mock.call('discriminator_lr', mock.ANY)
])
@mock.patch.object(tf, 'gfile', autospec=True)
@mock.patch.object(train, 'data_provider', autospec=True)
@mock.patch.object(train, '_define_model', autospec=True)
@mock.patch.object(tfgan, 'cyclegan_loss', autospec=True)
@mock.patch.object(train, '_define_train_ops', autospec=True)
@mock.patch.object(tfgan, 'gan_train', autospec=True)
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
mock_define_model, mock_data_provider, mock_gfile):
FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
FLAGS.batch_size = 3
FLAGS.patch_size = 8
FLAGS.generator_lr = 0.02
FLAGS.discriminator_lr = 0.3
FLAGS.train_log_dir = '/tmp/foo'
FLAGS.master = 'master'
FLAGS.task = 0
FLAGS.cycle_consistency_loss_weight = 2.0
FLAGS.max_number_of_steps = 1
mock_data_provider.provide_custom_data.return_value = (
tf.zeros([3, 2, 2, 3], dtype=tf.float32),
tf.zeros([3, 2, 2, 3], dtype=tf.float32))
train.main(None)
mock_data_provider.provide_custom_data.assert_called_once_with(
['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
mock_cyclegan_loss.assert_called_once_with(
mock_define_model.return_value,
cycle_consistency_loss_weight=2.0,
tensor_pool_fn=mock.ANY)
mock_define_train_ops.assert_called_once_with(
mock_define_model.return_value, mock_cyclegan_loss.return_value)
mock_gan_train.assert_called_once_with(
mock_define_train_ops.return_value,
'/tmp/foo',
get_hooks_fn=mock.ANY,
hooks=mock.ANY,
master='master',
is_chief=True)
if __name__ == '__main__':
tf.test.main()
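The tests above lean heavily on `mock.patch.object(..., autospec=True)` to stub out TFGAN and the network functions while still enforcing the real call signatures. A minimal self-contained sketch of that technique, with a toy class standing in for the patched modules:

```python
from unittest import mock

class Greeter:
    def greet(self, name):
        return 'hello ' + name

# autospec=True makes the mock mirror the real method's signature, so
# calls with the wrong arguments fail loudly instead of passing silently.
with mock.patch.object(Greeter, 'greet', autospec=True) as mock_greet:
    mock_greet.return_value = 'stubbed'
    g = Greeter()
    result = g.greet('world')

print(result)  # stubbed
# With autospec on an instance method, the recorded call includes self.
mock_greet.assert_called_once_with(g, 'world')
```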