Unverified commit 5266d031, authored by Neal Wu, committed by GitHub

Merge pull request #2862 from joel-shor/master

Add TFGAN examples and add `gan` to readme
parents 0cc98628 e959526a
@@ -29,6 +29,7 @@ installation](https://www.tensorflow.org/install).
- [differential_privacy](differential_privacy): privacy-preserving student
  models from multiple teachers.
- [domain_adaptation](domain_adaptation): domain separation networks.
- [gan](gan): generative adversarial networks.
- [im2txt](im2txt): image-to-text neural network for image captioning.
- [inception](inception): deep convolutional networks for computer vision.
- [learning_to_remember_rare_events](learning_to_remember_rare_events): a
......
# TFGAN Examples
[TFGAN](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan) is a lightweight library for training and evaluating Generative
Adversarial Networks (GANs). GANs have been used in a wide range of tasks
including [image translation](https://arxiv.org/abs/1703.10593), [super-resolution](https://arxiv.org/abs/1609.04802), and [data augmentation](https://arxiv.org/abs/1612.07828). This directory contains fully working examples
that demonstrate the ease and flexibility of TFGAN. Each subdirectory contains a
different working example. The sub-sections below describe each of the problems,
and include some sample outputs. We've also included a [Jupyter notebook](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb), which
provides a walkthrough of TFGAN.
## Contacts
Maintainers of TFGAN:
* Joel Shor,
github: [joel-shor](https://github.com/joel-shor)
## Table of contents
1. [MNIST](#mnist)
1. [MNIST with GANEstimator](#mnist_estimator)
1. [CIFAR10](#cifar10)
1. [Image compression (coming soon)](#compression)
## MNIST {#mnist}
We train a simple generator to produce [MNIST digits](http://yann.lecun.com/exdb/mnist/).
The unconditional case maps noise to MNIST digits. The conditional case maps
noise and digit class to MNIST digits. [InfoGAN](https://arxiv.org/abs/1606.03657) learns to produce
digits of a given class without labels, as well as to control style. The
network architectures are defined [here](https://github.com/tensorflow/models/tree/master/research/gan/mnist/networks.py).
We use a classifier trained on MNIST digit classification for evaluation.
### Unconditional MNIST
![Unconditional GAN](g3doc/mnist_unconditional_gan.png "unconditional GAN")
### Conditional MNIST
![Conditional GAN](g3doc/mnist_conditional_gan.png "conditional GAN")
### InfoGAN MNIST
![InfoGAN](g3doc/mnist_infogan.png "InfoGAN")
## MNIST with GANEstimator {#mnist_estimator}
This setup is exactly the same as in the [unconditional MNIST example](#mnist), but
uses the `tf.Learn` `GANEstimator`.
![Unconditional GAN](g3doc/mnist_estimator_unconditional_gan.png "unconditional GAN")
## CIFAR10 {#cifar10}
We train a [DCGAN generator](https://arxiv.org/abs/1511.06434) to produce [CIFAR10 images](https://www.cs.toronto.edu/~kriz/cifar.html).
The unconditional case maps noise to CIFAR10 images. The conditional case maps
noise and image class to CIFAR10 images. The
network architectures are defined [here](https://github.com/tensorflow/models/tree/master/research/gan/cifar/networks.py).
We use the [Inception Score](https://arxiv.org/abs/1606.03498) to evaluate the images.
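The Inception Score summarizes a classifier's predictions on generated images: it is `exp(E_x[KL(p(y|x) || p(y))])`, which is high only when individual predictions are confident *and* the predicted classes are diverse. As a rough illustration (not this repo's `util.get_inception_scores`, which runs a real Inception network), here is a minimal NumPy sketch that computes the score from a stand-in matrix of per-image class probabilities:

```python
import numpy as np

def inception_score(class_probs, eps=1e-12):
    """exp(mean_x KL(p(y|x) || p(y))) for a [num_images, num_classes] matrix."""
    p_y = class_probs.mean(axis=0)  # marginal label distribution p(y)
    kl = np.sum(class_probs * (np.log(class_probs + eps) - np.log(p_y + eps)),
                axis=1)             # per-image KL(p(y|x) || p(y))
    return float(np.exp(kl.mean()))

# Uniform predictions carry no information: the score is 1 (the minimum).
uniform = np.full((8, 10), 0.1)
# Confident and diverse predictions: the score approaches the class count.
confident = np.eye(10)
```

Real evaluations replace `class_probs` with softmax outputs of the Inception network on generated images, batched as in the `num_inception_images` flag below.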
### Unconditional CIFAR10
![Unconditional GAN](g3doc/cifar_unconditional_gan.png "unconditional GAN")
### Conditional CIFAR10
![Conditional GAN](g3doc/cifar_conditional_gan.png "conditional GAN"){width="330"}
## Image compression {#compression}
In neural image compression, we attempt to reduce an image to a smaller representation
such that we can recreate the original image as closely as possible. See [`Full Resolution Image Compression with Recurrent Neural Networks`](https://arxiv.org/abs/1608.05148) for more details on using neural networks
for image compression.
In this example, we train an encoder to compress images to a compressed binary
representation and a decoder to map the binary representation back to the image.
We treat both systems together (encoder -> decoder) as the generator.
A typical image compression network trained on an L1 pixel loss decodes into
blurry images. We use an adversarial loss to force the outputs to be more
plausible.
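The blur can be seen in a one-pixel thought experiment: if a pixel's true value is 0 in half the plausible decodings and 1 in the other half, the single reconstruction minimizing the average pixel loss is a non-committal gray rather than either sharp answer. A small NumPy check (illustrative only, not part of the example code):

```python
import numpy as np

# Two equally plausible "true" pixel values for the same compressed code.
targets = np.array([0.0, 1.0])
candidates = np.linspace(0.0, 1.0, 101)

# Average loss of predicting one constant value against both targets.
l2_loss = [np.mean((targets - c) ** 2) for c in candidates]
l1_loss = [np.mean(np.abs(targets - c)) for c in candidates]

best_l2 = candidates[np.argmin(l2_loss)]  # the mean, 0.5: a gray blur
# For L1, every value in [0, 1] ties (the loss is flat), so nothing pushes
# the decoder toward a sharp 0 or 1. An adversarial loss supplies that push.
```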
This example also highlights the following infrastructure challenge:
* Keeping track of and restoring variables when custom code manages them
Some other notes on the problem:
* Since the network is fully convolutional, we train on image patches.
* The bottleneck layer is floating point during training and binarized during
evaluation.
### Results
#### No adversarial loss
![compression_no_adversarial](g3doc/compression_wf0.png "no adversarial loss")
#### Adversarial loss
![compression_adversarial](g3doc/compression_wf10000.png "with adversarial loss")
### Architectures
#### Compression Network
The compression network is a DCGAN discriminator for the encoder and a DCGAN
generator for the decoder from [`Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks`](https://arxiv.org/abs/1511.06434).
The binarizer adds uniform noise during training then binarizes during eval, as in
[`End-to-end Optimized Image Compression`](https://arxiv.org/abs/1611.01704).
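The noise-then-binarize trick keeps the bottleneck differentiable during training while producing hard bits at eval time. A minimal NumPy sketch of the idea (function name hypothetical; the real binarizer lives in the TF graph):

```python
import numpy as np

def binarize(bottleneck, is_training, rng=None):
    """Noisy identity during training; hard {-1, +1} binarization at eval."""
    if is_training:
        rng = rng if rng is not None else np.random.default_rng(0)
        # Additive uniform noise is a differentiable stand-in for
        # quantization, as in Balle et al. (arXiv:1611.01704).
        return bottleneck + rng.uniform(-0.5, 0.5, size=bottleneck.shape)
    # At eval, commit to hard bits: this is the compressed representation.
    return np.where(bottleneck >= 0, 1.0, -1.0)
```

The key property is that the training-time op has a useful gradient with respect to the bottleneck (the noise is constant), while the eval-time op produces the discrete codes that actually get stored.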
#### Discriminator
The discriminator looks at 70x70 patches, as in
[`Image-to-Image Translation with Conditional Adversarial Networks`](https://arxiv.org/abs/1611.07004).
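The 70x70 figure is the receptive field of the discriminator's convolution stack. Assuming the common PatchGAN layout from the pix2pix paper (four 4x4 convolutions with strides 2, 2, 2, 1 plus a final stride-1 4x4 layer; these numbers come from that paper, not from code in this directory), the receptive field can be computed by walking backward through the layers:

```python
def receptive_field(layers):
    """rf_in = (rf_out - 1) * stride + kernel, applied from output to input."""
    rf = 1  # one output "pixel" of the discriminator
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

# (kernel, stride) per conv layer in a 70x70 PatchGAN discriminator.
patchgan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
```

Each output logit therefore judges the realism of one 70x70 patch, and averaging over all logits scores the whole image.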
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(batch_size, dataset_dir, dataset_name='cifar10',
                 split_name='train', one_hot=True):
  """Provides batches of CIFAR data.

  Args:
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the CIFAR10 data can be found. If `None`,
      use default.
    dataset_name: Name of the dataset.
    split_name: Should be either 'train' or 'test'.
    one_hot: Output one hot vector instead of int32 label.

  Returns:
    images: A `Tensor` of size [batch_size, 32, 32, 3]. Output pixel values are
      in [-1, 1].
    labels: Either (1) one_hot_labels if `one_hot` is `True`
        A `Tensor` of size [batch_size, num_classes], where each row has a
        single element set to one and the rest set to zeros.
      Or (2) labels if `one_hot` is `False`
        A `Tensor` of size [batch_size], holding the labels as integers.
    num_samples: The number of total samples in the dataset.
    num_classes: The number of classes in the dataset.

  Raises:
    ValueError: if the split_name is not either 'train' or 'test'.
  """
  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=(split_name == 'train'))
  [image, label] = provider.get(['image', 'label'])

  # Preprocess the images.
  image = (tf.to_float(image) - 128.0) / 128.0

  # Creates a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=32,
      capacity=5 * batch_size)

  labels = tf.reshape(labels, [-1])

  if one_hot:
    labels = tf.one_hot(labels, dataset.num_classes)

  return images, labels, dataset.num_samples, dataset.num_classes


def float_image_to_uint8(image):
  """Convert float image in [-1, 1) to [0, 255] uint8.

  Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.

  Args:
    image: An image tensor. Values should be in [-1, 1).

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  image = (image * 128.0) + 128.0
  return tf.cast(image, tf.uint8)
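A quick NumPy check of the mapping above (illustrative only; it mirrors `float_image_to_uint8` and the `(x - 128) / 128` preprocessing rather than reusing the TF ops):

```python
import numpy as np

def float_to_uint8(x):
    # Same affine map as the TF version: [-1, 1) -> [0, 256), then cast.
    return (x * 128.0 + 128.0).astype(np.uint8)

pixels = np.array([-1.0, 0.0, 1.0 - 1.0 / 128.0])
# -1 -> 0, 0 -> 128, 1 - 1/128 -> 255; exactly 1.0 would wrap around to 0,
# which is why the docstring warns that inputs should stay in [-1, 1).
```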
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):

  def test_cifar10_train_set(self):
    dataset_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cifar/testdata')

    batch_size = 4
    images, labels, num_samples, num_classes = data_provider.provide_data(
        batch_size, dataset_dir)
    self.assertEqual(50000, num_samples)
    self.assertEqual(10, num_classes)

    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out, labels_out = sess.run([images, labels])
    self.assertEqual(images_out.shape, (batch_size, 32, 32, 3))
    expected_label_shape = (batch_size, 10)
    self.assertEqual(expected_label_shape, labels_out.shape)

    # Check range.
    self.assertTrue(np.all(np.abs(images_out) <= 1))


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained CIFAR model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
import util
flags = tf.flags
FLAGS = tf.flags.FLAGS
tfgan = tf.contrib.gan
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10/',
                    'Directory where the model was written to.')

flags.DEFINE_string('eval_dir', '/tmp/cifar10/',
                    'Directory where the results are saved to.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

flags.DEFINE_integer('num_images_generated', 100,
                     'Number of images to generate at once.')

flags.DEFINE_integer('num_inception_images', 10,
                     'The number of images to run through Inception at once.')

flags.DEFINE_boolean('eval_real_images', False,
                     'If `True`, run Inception network on real images.')

flags.DEFINE_boolean('conditional_eval', False,
                     'If `True`, set up a conditional GAN.')

flags.DEFINE_boolean('eval_frechet_inception_distance', True,
                     'If `True`, compute Frechet Inception distance using real '
                     'images and generated images.')

flags.DEFINE_integer('num_images_per_class', 10,
                     'When a conditional generator is used, this is the number '
                     'of images to display per class.')

flags.DEFINE_integer('max_number_of_evaluations', None,
                     'Number of times to run evaluation. If `None`, run '
                     'forever.')
def main(_, run_eval_loop=True):
  # Fetch and generate images to run through Inception.
  with tf.name_scope('inputs'):
    real_data, num_classes = _get_real_data(
        FLAGS.num_images_generated, FLAGS.dataset_dir)
    generated_data = _get_generated_data(
        FLAGS.num_images_generated, FLAGS.conditional_eval, num_classes)

  # Compute Frechet Inception Distance.
  if FLAGS.eval_frechet_inception_distance:
    fid = util.get_frechet_inception_distance(
        real_data, generated_data, FLAGS.num_images_generated,
        FLAGS.num_inception_images)
    tf.summary.scalar('frechet_inception_distance', fid)

  # Compute normal Inception scores.
  if FLAGS.eval_real_images:
    inc_score = util.get_inception_scores(
        real_data, FLAGS.num_images_generated, FLAGS.num_inception_images)
  else:
    inc_score = util.get_inception_scores(
        generated_data, FLAGS.num_images_generated, FLAGS.num_inception_images)
  tf.summary.scalar('inception_score', inc_score)

  # If conditional, display an image grid of different classes.
  if FLAGS.conditional_eval and not FLAGS.eval_real_images:
    reshaped_imgs = util.get_image_grid(
        generated_data, FLAGS.num_images_generated, num_classes,
        FLAGS.num_images_per_class)
    tf.summary.image('generated_data', reshaped_imgs, max_outputs=1)

  # Create ops that write images to disk.
  image_write_ops = None
  if FLAGS.conditional_eval:
    reshaped_imgs = util.get_image_grid(
        generated_data, FLAGS.num_images_generated, num_classes,
        FLAGS.num_images_per_class)
    uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
    image_write_ops = tf.write_file(
        '%s/%s' % (FLAGS.eval_dir, 'conditional_cifar10.png'),
        tf.image.encode_png(uint8_images[0]))
  else:
    if FLAGS.num_images_generated >= 100:
      reshaped_imgs = tfgan.eval.image_reshaper(
          generated_data[:100], num_cols=FLAGS.num_images_per_class)
      uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
      image_write_ops = tf.write_file(
          '%s/%s' % (FLAGS.eval_dir, 'unconditional_cifar10.png'),
          tf.image.encode_png(uint8_images[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop: return
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      master=FLAGS.master,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)


def _get_real_data(num_images_generated, dataset_dir):
  """Get real images."""
  data, _, _, num_classes = data_provider.provide_data(
      num_images_generated, dataset_dir)
  return data, num_classes


def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Get generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs)
  return data


if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.cifar.eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import eval # pylint:disable=redefined-builtin
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
class EvalTest(tf.test.TestCase):

  def _test_build_graph_helper(self, eval_real_images, conditional_eval):
    FLAGS.eval_real_images = eval_real_images
    FLAGS.conditional_eval = conditional_eval
    # Mock `frechet_inception_distance` and `inception_score`, which are
    # expensive.
    with mock.patch.object(
        eval.util, 'get_frechet_inception_distance') as mock_fid:
      with mock.patch.object(eval.util, 'get_inception_scores') as mock_iscore:
        mock_fid.return_value = 1.0
        mock_iscore.return_value = 1.0
        eval.main(None, run_eval_loop=False)

  def test_build_graph_realdata(self):
    self._test_build_graph_helper(True, False)

  def test_build_graph_generateddata(self):
    self._test_build_graph_helper(False, False)

  def test_build_graph_generateddataconditional(self):
    self._test_build_graph_helper(False, True)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the CIFAR dataset.
# 2. Trains an unconditional or conditional model on the CIFAR training set.
# 3. Evaluates the models and writes sample images to disk.
#
#
# With the default batch size and number of steps, train times are:
#
# Usage:
# cd models/research/gan/cifar
# ./launch_jobs.sh ${gan_type} ${git_repo}
set -e
# Type of GAN to run. Right now, options are `unconditional` or `conditional`.
gan_type=$1
if ! [[ "$gan_type" =~ ^(unconditional|conditional)$ ]]; then
  echo "'gan_type' must be one of: 'unconditional', 'conditional'."
  exit 1
fi
# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
  echo "'git_repo' must not be empty."
  exit 1
fi
# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/cifar-model
# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/cifar-model/eval
# Where the dataset is saved to.
DATASET_DIR=/tmp/cifar-data
export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/gan:$git_repo/research/slim
# A helper function for printing pretty output.
Banner () {
  local text=$1
  local green='\033[0;32m'
  local nc='\033[0m'  # No color.
  echo -e "${green}${text}${nc}"
}
# Download the dataset.
python "${git_repo}/research/slim/download_and_convert_data.py" \
--dataset_name=cifar10 \
--dataset_dir=${DATASET_DIR}
# Run unconditional GAN.
if [[ "$gan_type" == "unconditional" ]]; then
  UNCONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/unconditional"
  UNCONDITIONAL_EVAL_DIR="${EVAL_DIR}/unconditional"
  NUM_STEPS=10000

  # Run training.
  Banner "Starting training unconditional GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/cifar/train.py" \
    --train_log_dir=${UNCONDITIONAL_TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --max_number_of_steps=${NUM_STEPS} \
    --conditional=false \
    --alsologtostderr
  Banner "Finished training unconditional GAN ${NUM_STEPS} steps."

  # Run evaluation.
  Banner "Starting evaluation of unconditional GAN..."
  python "${git_repo}/research/gan/cifar/eval.py" \
    --checkpoint_dir=${UNCONDITIONAL_TRAIN_DIR} \
    --eval_dir=${UNCONDITIONAL_EVAL_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --eval_real_images=false \
    --conditional_eval=false \
    --max_number_of_evaluations=1
  Banner "Finished unconditional evaluation. See ${UNCONDITIONAL_EVAL_DIR} for output images."
fi
# Run conditional GAN.
if [[ "$gan_type" == "conditional" ]]; then
  CONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/conditional"
  CONDITIONAL_EVAL_DIR="${EVAL_DIR}/conditional"
  NUM_STEPS=10000

  # Run training.
  Banner "Starting training conditional GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/cifar/train.py" \
    --train_log_dir=${CONDITIONAL_TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --max_number_of_steps=${NUM_STEPS} \
    --conditional=true \
    --alsologtostderr
  Banner "Finished training conditional GAN ${NUM_STEPS} steps."

  # Run evaluation.
  Banner "Starting evaluation of conditional GAN..."
  python "${git_repo}/research/gan/cifar/eval.py" \
    --checkpoint_dir=${CONDITIONAL_TRAIN_DIR} \
    --eval_dir=${CONDITIONAL_EVAL_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --eval_real_images=false \
    --conditional_eval=true \
    --max_number_of_evaluations=1
  Banner "Finished conditional evaluation. See ${CONDITIONAL_EVAL_DIR} for output images."
fi
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN CIFAR example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.nets import dcgan
tfgan = tf.contrib.gan
def _last_conv_layer(end_points):
  """Returns the last convolutional layer from an endpoints dictionary."""
  conv_list = sorted(k for k in end_points.keys() if k[:4] == 'conv')
  return end_points[conv_list[-1]]


def generator(noise):
  """Generator to produce CIFAR images.

  Args:
    noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
      does not use conditioning, this Tensor represents a noise vector of some
      kind that will be reshaped by the generator into CIFAR examples.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  images, _ = dcgan.generator(noise)

  # Make sure output lies between [-1, 1].
  return tf.tanh(images)


def conditional_generator(inputs):
  """Generator to produce CIFAR images.

  Args:
    inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to create a
      conditional generator.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  noise, one_hot_labels = inputs
  noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
  images, _ = dcgan.generator(noise)

  # Make sure output lies between [-1, 1].
  return tf.tanh(images)


def discriminator(img, unused_conditioning):
  """Discriminator for CIFAR images.

  Args:
    img: A Tensor of shape [batch size, width, height, channels], that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    unused_conditioning: The TFGAN API can help with conditional GANs, which
      would require extra `condition` information to both the generator and the
      discriminator. Since this example is not conditional, we do not use this
      argument.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  logits, _ = dcgan.discriminator(img)
  return logits


# TODO(joelshor): This discriminator creates variables that aren't used, and
# causes logging warnings. Improve `dcgan` nets to accept a target end layer,
# so extraneous variables aren't created.
def conditional_discriminator(img, conditioning):
  """Discriminator for CIFAR images.

  Args:
    img: A Tensor of shape [batch size, width, height, channels], that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  logits, end_points = dcgan.discriminator(img)

  # Condition the last convolution layer.
  _, one_hot_labels = conditioning
  net = _last_conv_layer(end_points)
  net = tfgan.features.condition_tensor_from_onehot(
      tf.contrib.layers.flatten(net), one_hot_labels)
  logits = tf.contrib.layers.linear(net, 1)

  return logits
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.cifar.networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import networks
class NetworksTest(tf.test.TestCase):

  def test_generator(self):
    tf.set_random_seed(1234)
    batch_size = 100
    noise = tf.random_normal([batch_size, 64])
    image = networks.generator(noise)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      image_np = image.eval()

    self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
    self.assertTrue(np.all(np.abs(image_np) <= 1))

  def test_generator_conditional(self):
    tf.set_random_seed(1234)
    batch_size = 100
    noise = tf.random_normal([batch_size, 64])
    conditioning = tf.one_hot([0] * batch_size, 10)
    image = networks.conditional_generator((noise, conditioning))
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      image_np = image.eval()

    self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
    self.assertTrue(np.all(np.abs(image_np) <= 1))

  def test_discriminator(self):
    batch_size = 5
    image = tf.random_uniform([batch_size, 32, 32, 3], -1, 1)
    dis_output = networks.discriminator(image, None)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      dis_output_np = dis_output.eval()

    self.assertAllEqual([batch_size, 1], dis_output_np.shape)

  def test_discriminator_conditional(self):
    batch_size = 5
    image = tf.random_uniform([batch_size, 32, 32, 3], -1, 1)
    conditioning = (None, tf.one_hot([0] * batch_size, 10))
    dis_output = networks.conditional_discriminator(image, conditioning)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      dis_output_np = dis_output.eval()

    self.assertAllEqual([batch_size, 1], dis_output_np.shape)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a generator on CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
tfgan = tf.contrib.gan
flags = tf.flags
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/cifar/',
                    'Directory where to write event logs.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

flags.DEFINE_integer('max_number_of_steps', 1000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

flags.DEFINE_boolean(
    'conditional', False,
    'If `True`, set up a conditional GAN. If False, it is unconditional.')

# Sync replicas flags.
flags.DEFINE_boolean(
    'use_sync_replicas', True,
    'If `True`, use sync replicas. Otherwise use async.')

flags.DEFINE_integer(
    'worker_replicas', 10,
    'The number of gradients to collect before updating params. Only used '
    'with sync replicas.')

flags.DEFINE_integer(
    'backup_workers', 1,
    'Number of workers to be kept as backup in the sync replicas case.')

FLAGS = flags.FLAGS
def main(_):
if not tf.gfile.Exists(FLAGS.train_log_dir):
tf.gfile.MakeDirs(FLAGS.train_log_dir)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
# Force all input processing onto CPU in order to reserve the GPU for
# the forward inference and back-propagation.
with tf.name_scope('inputs'):
with tf.device('/cpu:0'):
images, one_hot_labels, _, _ = data_provider.provide_data(
FLAGS.batch_size, FLAGS.dataset_dir)
# Define the GANModel tuple.
noise = tf.random_normal([FLAGS.batch_size, 64])
if FLAGS.conditional:
generator_fn = networks.conditional_generator
discriminator_fn = networks.conditional_discriminator
generator_inputs = (noise, one_hot_labels)
else:
generator_fn = networks.generator
discriminator_fn = networks.discriminator
generator_inputs = noise
gan_model = tfgan.gan_model(
generator_fn,
discriminator_fn,
real_data=images,
generator_inputs=generator_inputs)
tfgan.eval.add_gan_model_image_summaries(gan_model)
# Get the GANLoss tuple. Use the selected GAN loss functions.
# (joelshor): Put this block in `with tf.name_scope('loss'):` when
# cl/171610946 goes into the opensource release.
gan_loss = tfgan.gan_loss(gan_model,
gradient_penalty_weight=1.0,
add_summaries=True)
# Get the GANTrain ops using the custom optimizers and optional
# discriminator weight clipping.
with tf.name_scope('train'):
gen_lr, dis_lr = _learning_rate()
gen_opt, dis_opt = _optimizer(gen_lr, dis_lr, FLAGS.use_sync_replicas)
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=gen_opt,
discriminator_optimizer=dis_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.summary.scalar('generator_lr', gen_lr)
tf.summary.scalar('discriminator_lr', dis_lr)
# Run the alternating training loop. Skip it if no steps should be taken
# (used for graph construction tests).
sync_hooks = ([gen_opt.make_session_run_hook(FLAGS.task == 0),
dis_opt.make_session_run_hook(FLAGS.task == 0)]
if FLAGS.use_sync_replicas else [])
status_message = tf.string_join(
['Starting train step: ',
tf.as_string(tf.train.get_or_create_global_step())],
name='status_message')
if FLAGS.max_number_of_steps == 0: return
tfgan.gan_train(
train_ops,
hooks=(
[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)] +
sync_hooks),
logdir=FLAGS.train_log_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0)
def _learning_rate():
generator_lr = tf.train.exponential_decay(
learning_rate=0.0001,
global_step=tf.train.get_or_create_global_step(),
decay_steps=100000,
decay_rate=0.9,
staircase=True)
discriminator_lr = 0.001
return generator_lr, discriminator_lr
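For intuition, the staircase schedule used for the generator above can be sketched in plain Python (a minimal sketch; `staircase_decay` is a hypothetical name, and `tf.train.exponential_decay` handles the graph-mode bookkeeping):

```python
def staircase_decay(initial_lr, global_step, decay_steps, decay_rate):
    # With staircase=True the exponent is the integer number of completed
    # decay periods, so the learning rate drops in discrete steps.
    return initial_lr * decay_rate ** (global_step // decay_steps)

# The generator rate above: 1e-4, multiplied by 0.9 every 100k steps.
```

The discriminator, by contrast, uses a constant rate of 0.001.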
def _optimizer(gen_lr, dis_lr, use_sync_replicas):
"""Get an optimizer, that's optionally synchronous."""
generator_opt = tf.train.RMSPropOptimizer(gen_lr, decay=.9, momentum=0.1)
discriminator_opt = tf.train.RMSPropOptimizer(dis_lr, decay=.95, momentum=0.1)
def _make_sync(opt):
return tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=FLAGS.worker_replicas-FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
if use_sync_replicas:
generator_opt = _make_sync(generator_opt)
discriminator_opt = _make_sync(discriminator_opt)
return generator_opt, discriminator_opt
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
def _test_build_graph_helper(self, conditional, use_sync_replicas):
FLAGS.max_number_of_steps = 0
FLAGS.conditional = conditional
FLAGS.use_sync_replicas = use_sync_replicas
FLAGS.batch_size = 16
# Mock input pipeline.
mock_imgs = np.zeros([FLAGS.batch_size, 32, 32, 3], dtype=np.float32)
mock_lbls = np.concatenate(
(np.ones([FLAGS.batch_size, 1], dtype=np.int32),
np.zeros([FLAGS.batch_size, 9], dtype=np.int32)), axis=1)
with mock.patch.object(train, 'data_provider') as mock_data_provider:
mock_data_provider.provide_data.return_value = (
mock_imgs, mock_lbls, None, None)
train.main(None)
def test_build_graph_unconditional(self):
self._test_build_graph_helper(False, False)
def test_build_graph_conditional(self):
self._test_build_graph_helper(True, False)
def test_build_graph_syncreplicas(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for training and evaluating a TFGAN CIFAR example."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
tfgan = tf.contrib.gan
def get_generator_conditioning(batch_size, num_classes):
"""Generates TFGAN conditioning inputs for evaluation.
Args:
batch_size: A Python integer. The desired batch size.
num_classes: A Python integer. The number of classes.
Returns:
A Tensor of one-hot vectors corresponding to an even distribution over
classes.
Raises:
ValueError: If `batch_size` isn't evenly divisible by `num_classes`.
"""
if batch_size % num_classes != 0:
raise ValueError('`batch_size` %i must be evenly divisible by '
'`num_classes` %i.' % (batch_size, num_classes))
  labels = [lbl for lbl in range(num_classes)
            for _ in range(batch_size // num_classes)]
return tf.one_hot(tf.constant(labels), num_classes)
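The label layout above can be checked without a TensorFlow session. Here is a pure-Python sketch of the same evenly distributed conditioning (`even_one_hot` is a hypothetical name; the function above returns a graph `Tensor` via `tf.one_hot` instead of a list of lists):

```python
def even_one_hot(batch_size, num_classes):
    # Each class label appears batch_size // num_classes times, in order;
    # each row is a one-hot vector, mirroring tf.one_hot on the label list.
    if batch_size % num_classes != 0:
        raise ValueError('`batch_size` must be evenly divisible by `num_classes`.')
    labels = [lbl for lbl in range(num_classes)
              for _ in range(batch_size // num_classes)]
    return [[1.0 if i == lbl else 0.0 for i in range(num_classes)]
            for lbl in labels]
```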
def get_image_grid(images, batch_size, num_classes, num_images_per_class):
"""Combines images from each class in a single summary image.
Args:
images: Tensor of images that are arranged by class. The first
`batch_size / num_classes` images belong to the first class, the second
group belong to the second class, etc. Shape is
[batch, width, height, channels].
batch_size: Python integer. Batch dimension.
num_classes: Number of classes to show.
num_images_per_class: Number of image examples per class to show.
Raises:
ValueError: If the batch dimension of `images` is known at graph
construction, and it isn't `batch_size`.
ValueError: If there aren't enough images to show
`num_classes * num_images_per_class` images.
ValueError: If `batch_size` isn't divisible by `num_classes`.
Returns:
A single image.
"""
# Validate inputs.
images.shape[0:1].assert_is_compatible_with([batch_size])
if batch_size < num_classes * num_images_per_class:
raise ValueError('Not enough images in batch to show the desired number of '
'images.')
if batch_size % num_classes != 0:
raise ValueError('`batch_size` must be divisible by `num_classes`.')
# Only get a certain number of images per class.
num_batches = batch_size // num_classes
  indices = [i * num_batches + j for i in range(num_classes)
             for j in range(num_images_per_class)]
sampled_images = tf.gather(images, indices)
return tfgan.eval.image_reshaper(
sampled_images, num_cols=num_images_per_class)
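The gather indices above (take the first `num_images_per_class` images from each class-contiguous block) can be exercised in isolation (a sketch; `grid_indices` is a hypothetical name mirroring the index arithmetic in `get_image_grid`):

```python
def grid_indices(batch_size, num_classes, num_images_per_class):
    # Images arrive grouped by class in blocks of batch_size // num_classes;
    # select the first num_images_per_class indices from each block.
    per_class = batch_size // num_classes
    return [i * per_class + j
            for i in range(num_classes)
            for j in range(num_images_per_class)]
```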
def get_inception_scores(images, batch_size, num_inception_images):
"""Get Inception score for some images.
Args:
images: Image minibatch. Shape [batch size, width, height, channels]. Values
are in [-1, 1].
batch_size: Python integer. Batch dimension.
num_inception_images: Number of images to run through Inception at once.
Returns:
Inception scores. Tensor shape is [batch size].
Raises:
ValueError: If `batch_size` is incompatible with the first dimension of
`images`.
ValueError: If `batch_size` isn't divisible by `num_inception_images`.
"""
# Validate inputs.
images.shape[0:1].assert_is_compatible_with([batch_size])
if batch_size % num_inception_images != 0:
raise ValueError(
'`batch_size` must be divisible by `num_inception_images`.')
# Resize images.
size = 299
resized_images = tf.image.resize_bilinear(images, [size, size])
# Run images through Inception.
num_batches = batch_size // num_inception_images
inc_score = tfgan.eval.inception_score(
resized_images, num_batches=num_batches)
return inc_score
def get_frechet_inception_distance(real_images, generated_images, batch_size,
num_inception_images):
"""Get Frechet Inception Distance between real and generated images.
Args:
    real_images: Real images minibatch. Shape [batch size, width, height,
      channels]. Values are in [-1, 1].
generated_images: Generated images minibatch. Shape [batch size, width,
height, channels]. Values are in [-1, 1].
batch_size: Python integer. Batch dimension.
num_inception_images: Number of images to run through Inception at once.
Returns:
Frechet Inception distance. A floating-point scalar.
Raises:
    ValueError: If the minibatch size is known at graph construction time, and
      doesn't match `batch_size`.
"""
# Validate input dimensions.
real_images.shape[0:1].assert_is_compatible_with([batch_size])
generated_images.shape[0:1].assert_is_compatible_with([batch_size])
# Resize input images.
size = 299
resized_real_images = tf.image.resize_bilinear(real_images, [size, size])
resized_generated_images = tf.image.resize_bilinear(
generated_images, [size, size])
# Compute Frechet Inception Distance.
num_batches = batch_size // num_inception_images
fid = tfgan.eval.frechet_inception_distance(
resized_real_images, resized_generated_images, num_batches=num_batches)
return fid
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.cifar.util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import util
mock = tf.test.mock
class UtilTest(tf.test.TestCase):
def test_get_generator_conditioning(self):
conditioning = util.get_generator_conditioning(12, 4)
self.assertEqual([12, 4], conditioning.shape.as_list())
def test_get_image_grid(self):
util.get_image_grid(
tf.zeros([6, 28, 28, 1]),
batch_size=6,
num_classes=3,
num_images_per_class=1)
def test_get_inception_scores(self):
# Mock `inception_score` which is expensive.
with mock.patch.object(
util.tfgan.eval, 'inception_score') as mock_inception_score:
mock_inception_score.return_value = 1.0
util.get_inception_scores(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
def test_get_frechet_inception_distance(self):
# Mock `frechet_inception_distance` which is expensive.
with mock.patch.object(
util.tfgan.eval, 'frechet_inception_distance') as mock_fid:
mock_fid.return_value = 1.0
util.get_frechet_inception_distance(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a conditional TFGAN trained MNIST model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
import util
flags = tf.flags
tfgan = tf.contrib.gan
flags.DEFINE_string('checkpoint_dir', '/tmp/mnist/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/mnist/',
'Directory where the results are saved to.')
flags.DEFINE_integer('num_images_per_class', 10,
'Number of images to generate per class.')
flags.DEFINE_integer('noise_dims', 64,
                     'Dimensions of the generator noise vector.')
flags.DEFINE_string('classifier_filename', None,
'Location of the pretrained classifier. If `None`, use '
'default.')
flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'forever.')
FLAGS = flags.FLAGS
NUM_CLASSES = 10
def main(_, run_eval_loop=True):
with tf.name_scope('inputs'):
noise, one_hot_labels = _get_generator_inputs(
FLAGS.num_images_per_class, NUM_CLASSES, FLAGS.noise_dims)
# Generate images.
with tf.variable_scope('Generator'): # Same scope as in train job.
images = networks.conditional_generator((noise, one_hot_labels))
# Visualize images.
reshaped_img = tfgan.eval.image_reshaper(
images, num_cols=FLAGS.num_images_per_class)
tf.summary.image('generated_images', reshaped_img, max_outputs=1)
# Calculate evaluation metrics.
tf.summary.scalar('MNIST_Classifier_score',
util.mnist_score(images, FLAGS.classifier_filename))
tf.summary.scalar('MNIST_Cross_entropy',
util.mnist_cross_entropy(
images, one_hot_labels, FLAGS.classifier_filename))
# Write images to disk.
image_write_ops = tf.write_file(
      '%s/%s' % (FLAGS.eval_dir, 'conditional_gan.png'),
tf.image.encode_png(data_provider.float_image_to_uint8(reshaped_img[0])))
# For unit testing, use `run_eval_loop=False`.
if not run_eval_loop: return
tf.contrib.training.evaluate_repeatedly(
FLAGS.checkpoint_dir,
hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
tf.contrib.training.StopAfterNEvalsHook(1)],
eval_ops=image_write_ops,
max_number_of_evaluations=FLAGS.max_number_of_evaluations)
def _get_generator_inputs(num_images_per_class, num_classes, noise_dims):
# Since we want a grid of numbers for the conditional generator, manually
# construct the desired class labels.
num_images_generated = num_images_per_class * num_classes
noise = tf.random_normal([num_images_generated, noise_dims])
labels = [lbl for lbl in range(num_classes) for _
in range(num_images_per_class)]
one_hot_labels = tf.one_hot(tf.constant(labels), num_classes)
return noise, one_hot_labels
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.mnist.conditional_eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import conditional_eval
class ConditionalEvalTest(tf.test.TestCase):
def test_build_graph(self):
conditional_eval.main(None, run_eval_loop=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(split_name, batch_size, dataset_dir, num_readers=1,
num_threads=1):
"""Provides batches of MNIST digits.
Args:
split_name: Either 'train' or 'test'.
batch_size: The number of images in each batch.
dataset_dir: The directory where the MNIST data can be found.
num_readers: Number of dataset readers.
num_threads: Number of prefetching threads.
Returns:
images: A `Tensor` of size [batch_size, 28, 28, 1]
one_hot_labels: A `Tensor` of size [batch_size, mnist.NUM_CLASSES], where
each row has a single element set to one and the rest set to zeros.
num_samples: The number of total samples in the dataset.
Raises:
ValueError: If `split_name` is not either 'train' or 'test'.
"""
dataset = datasets.get_dataset('mnist', split_name, dataset_dir=dataset_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
common_queue_capacity=2 * batch_size,
common_queue_min=batch_size,
shuffle=(split_name == 'train'))
[image, label] = provider.get(['image', 'label'])
# Preprocess the images.
image = (tf.to_float(image) - 128.0) / 128.0
# Creates a QueueRunner for the pre-fetching operation.
images, labels = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size)
one_hot_labels = tf.one_hot(labels, dataset.num_classes)
return images, one_hot_labels, dataset.num_samples
def float_image_to_uint8(image):
"""Convert float image in [-1, 1) to [0, 255] uint8.
Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.
Args:
image: An image tensor. Values should be in [-1, 1).
Returns:
Input image cast to uint8 and with integer values in [0, 255].
"""
image = (image * 128.0) + 128.0
return tf.cast(image, tf.uint8)
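As a sanity check on the mapping, here is a scalar pure-Python sketch (`float_to_uint8` is a hypothetical name; the function above operates on whole tensors, where the `tf.uint8` cast wraps out-of-range values the same way):

```python
def float_to_uint8(value):
    # Affine map [-1, 1) -> [0, 256), then truncate like the uint8 cast.
    # Note that 1.0 itself wraps around to 0, as the docstring above warns.
    return int(value * 128.0 + 128.0) % 256
```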