Unverified commit 0d4f35d9, authored by Joel Shor and committed by GitHub
Browse files

Merge pull request #3371 from joel-shor/master

Project import generated by Copybara. fixes #16593
parents 0b74e527 5f99e589
...@@ -63,6 +63,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None, ...@@ -63,6 +63,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run ' 'Number of times to run evaluation. If `None`, run '
'forever.') 'forever.')
flags.DEFINE_boolean('write_to_disk', True, 'If `True`, run images to disk.')
def main(_, run_eval_loop=True): def main(_, run_eval_loop=True):
# Fetch and generate images to run through Inception. # Fetch and generate images to run through Inception.
...@@ -97,7 +99,7 @@ def main(_, run_eval_loop=True): ...@@ -97,7 +99,7 @@ def main(_, run_eval_loop=True):
# Create ops that write images to disk. # Create ops that write images to disk.
image_write_ops = None image_write_ops = None
if FLAGS.conditional_eval: if FLAGS.conditional_eval and FLAGS.write_to_disk:
reshaped_imgs = util.get_image_grid( reshaped_imgs = util.get_image_grid(
generated_data, FLAGS.num_images_generated, num_classes, generated_data, FLAGS.num_images_generated, num_classes,
FLAGS.num_images_per_class) FLAGS.num_images_per_class)
...@@ -106,7 +108,7 @@ def main(_, run_eval_loop=True): ...@@ -106,7 +108,7 @@ def main(_, run_eval_loop=True):
'%s/%s'% (FLAGS.eval_dir, 'conditional_cifar10.png'), '%s/%s'% (FLAGS.eval_dir, 'conditional_cifar10.png'),
tf.image.encode_png(uint8_images[0])) tf.image.encode_png(uint8_images[0]))
else: else:
if FLAGS.num_images_generated >= 100: if FLAGS.num_images_generated >= 100 and FLAGS.write_to_disk:
reshaped_imgs = tfgan.eval.image_reshaper( reshaped_imgs = tfgan.eval.image_reshaper(
generated_data[:100], num_cols=FLAGS.num_images_per_class) generated_data[:100], num_cols=FLAGS.num_images_per_class)
uint8_images = data_provider.float_image_to_uint8(reshaped_imgs) uint8_images = data_provider.float_image_to_uint8(reshaped_imgs)
...@@ -147,7 +149,7 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes): ...@@ -147,7 +149,7 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
# In order for variables to load, use the same variable scope as in the # In order for variables to load, use the same variable scope as in the
# train job. # train job.
with tf.variable_scope('Generator'): with tf.variable_scope('Generator'):
data = generator_fn(generator_inputs) data = generator_fn(generator_inputs, is_training=False)
return data return data
......
...@@ -32,29 +32,35 @@ def _last_conv_layer(end_points): ...@@ -32,29 +32,35 @@ def _last_conv_layer(end_points):
return end_points[conv_list[-1]] return end_points[conv_list[-1]]
def generator(noise): def generator(noise, is_training=True):
"""Generator to produce CIFAR images. """Generator to produce CIFAR images.
Args: Args:
noise: A 2D Tensor of shape [batch size, noise dim]. Since this example noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
does not use conditioning, this Tensor represents a noise vector of some does not use conditioning, this Tensor represents a noise vector of some
kind that will be reshaped by the generator into CIFAR examples. kind that will be reshaped by the generator into CIFAR examples.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A single Tensor with a batch of generated CIFAR images. A single Tensor with a batch of generated CIFAR images.
""" """
images, _ = dcgan.generator(noise) images, _ = dcgan.generator(noise, is_training=is_training)
# Make sure output lies between [-1, 1]. # Make sure output lies between [-1, 1].
return tf.tanh(images) return tf.tanh(images)
def conditional_generator(inputs): def conditional_generator(inputs, is_training=True):
"""Generator to produce CIFAR images. """Generator to produce CIFAR images.
Args: Args:
inputs: A 2-tuple of Tensors (noise, one_hot_labels) and creates a inputs: A 2-tuple of Tensors (noise, one_hot_labels) and creates a
conditional generator. conditional generator.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A single Tensor with a batch of generated CIFAR images. A single Tensor with a batch of generated CIFAR images.
...@@ -62,13 +68,13 @@ def conditional_generator(inputs): ...@@ -62,13 +68,13 @@ def conditional_generator(inputs):
noise, one_hot_labels = inputs noise, one_hot_labels = inputs
noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels) noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
images, _ = dcgan.generator(noise) images, _ = dcgan.generator(noise, is_training=is_training)
# Make sure output lies between [-1, 1]. # Make sure output lies between [-1, 1].
return tf.tanh(images) return tf.tanh(images)
def discriminator(img, unused_conditioning): def discriminator(img, unused_conditioning, is_training=True):
"""Discriminator for CIFAR images. """Discriminator for CIFAR images.
Args: Args:
...@@ -79,20 +85,23 @@ def discriminator(img, unused_conditioning): ...@@ -79,20 +85,23 @@ def discriminator(img, unused_conditioning):
would require extra `condition` information to both the generator and the would require extra `condition` information to both the generator and the
discriminator. Since this example is not conditional, we do not use this discriminator. Since this example is not conditional, we do not use this
argument. argument.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A 1D Tensor of shape [batch size] representing the confidence that the A 1D Tensor of shape [batch size] representing the confidence that the
images are real. The output can lie in [-inf, inf], with positive values images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real. indicating high confidence that the images are real.
""" """
logits, _ = dcgan.discriminator(img) logits, _ = dcgan.discriminator(img, is_training=is_training)
return logits return logits
# (joelshor): This discriminator creates variables that aren't used, and # (joelshor): This discriminator creates variables that aren't used, and
# causes logging warnings. Improve `dcgan` nets to accept a target end layer, # causes logging warnings. Improve `dcgan` nets to accept a target end layer,
# so extraneous variables aren't created. # so extraneous variables aren't created.
def conditional_discriminator(img, conditioning): def conditional_discriminator(img, conditioning, is_training=True):
"""Discriminator for CIFAR images. """Discriminator for CIFAR images.
Args: Args:
...@@ -100,13 +109,16 @@ def conditional_discriminator(img, conditioning): ...@@ -100,13 +109,16 @@ def conditional_discriminator(img, conditioning):
either real or generated. It is the discriminator's goal to distinguish either real or generated. It is the discriminator's goal to distinguish
between the two. between the two.
conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels). conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A 1D Tensor of shape [batch size] representing the confidence that the A 1D Tensor of shape [batch size] representing the confidence that the
images are real. The output can lie in [-inf, inf], with positive values images are real. The output can lie in [-inf, inf], with positive values
indicating high confidence that the images are real. indicating high confidence that the images are real.
""" """
logits, end_points = dcgan.discriminator(img) logits, end_points = dcgan.discriminator(img, is_training=is_training)
# Condition the last convolution layer. # Condition the last convolution layer.
_, one_hot_labels = conditioning _, one_hot_labels = conditioning
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
......
...@@ -57,7 +57,8 @@ Banner () { ...@@ -57,7 +57,8 @@ Banner () {
echo -e "${green}${text}${nc}" echo -e "${green}${text}${nc}"
} }
# Download the dataset. # Download the dataset. You will be asked for an ImageNet username and password.
# To get one, register at http://www.image-net.org/.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet" bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet" ${DATASET_DIR} "./bazel-bin/download_and_convert_imagenet" ${DATASET_DIR}
......
...@@ -18,8 +18,8 @@ from __future__ import absolute_import ...@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from six.moves import xrange
import networks import networks
......
...@@ -49,6 +49,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None, ...@@ -49,6 +49,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run ' 'Number of times to run evaluation. If `None`, run '
'forever.') 'forever.')
flags.DEFINE_boolean('write_to_disk', True, 'If `True`, run images to disk.')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
NUM_CLASSES = 10 NUM_CLASSES = 10
...@@ -60,7 +62,8 @@ def main(_, run_eval_loop=True): ...@@ -60,7 +62,8 @@ def main(_, run_eval_loop=True):
# Generate images. # Generate images.
with tf.variable_scope('Generator'): # Same scope as in train job. with tf.variable_scope('Generator'): # Same scope as in train job.
images = networks.conditional_generator((noise, one_hot_labels)) images = networks.conditional_generator(
(noise, one_hot_labels), is_training=False)
# Visualize images. # Visualize images.
reshaped_img = tfgan.eval.image_reshaper( reshaped_img = tfgan.eval.image_reshaper(
...@@ -75,9 +78,12 @@ def main(_, run_eval_loop=True): ...@@ -75,9 +78,12 @@ def main(_, run_eval_loop=True):
images, one_hot_labels, FLAGS.classifier_filename)) images, one_hot_labels, FLAGS.classifier_filename))
# Write images to disk. # Write images to disk.
image_write_ops = None
if FLAGS.write_to_disk:
image_write_ops = tf.write_file( image_write_ops = tf.write_file(
'%s/%s'% (FLAGS.eval_dir, 'conditional_gan.png'), '%s/%s'% (FLAGS.eval_dir, 'conditional_gan.png'),
tf.image.encode_png(data_provider.float_image_to_uint8(reshaped_img[0]))) tf.image.encode_png(data_provider.float_image_to_uint8(
reshaped_img[0])))
# For unit testing, use `run_eval_loop=False`. # For unit testing, use `run_eval_loop=False`.
if not run_eval_loop: return if not run_eval_loop: return
......
...@@ -56,6 +56,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None, ...@@ -56,6 +56,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run ' 'Number of times to run evaluation. If `None`, run '
'forever.') 'forever.')
flags.DEFINE_boolean('write_to_disk', True, 'If `True`, run images to disk.')
def main(_, run_eval_loop=True): def main(_, run_eval_loop=True):
# Fetch real images. # Fetch real images.
...@@ -72,13 +74,14 @@ def main(_, run_eval_loop=True): ...@@ -72,13 +74,14 @@ def main(_, run_eval_loop=True):
# train job. # train job.
with tf.variable_scope('Generator'): with tf.variable_scope('Generator'):
images = networks.unconditional_generator( images = networks.unconditional_generator(
tf.random_normal([FLAGS.num_images_generated, FLAGS.noise_dims])) tf.random_normal([FLAGS.num_images_generated, FLAGS.noise_dims]),
is_training=False)
tf.summary.scalar('MNIST_Frechet_distance', tf.summary.scalar('MNIST_Frechet_distance',
util.mnist_frechet_distance( util.mnist_frechet_distance(
real_images, images, FLAGS.classifier_filename)) real_images, images, FLAGS.classifier_filename))
tf.summary.scalar('MNIST_Classifier_score', tf.summary.scalar('MNIST_Classifier_score',
util.mnist_score(images, FLAGS.classifier_filename)) util.mnist_score(images, FLAGS.classifier_filename))
if FLAGS.num_images_generated >= 100: if FLAGS.num_images_generated >= 100 and FLAGS.write_to_disk:
reshaped_images = tfgan.eval.image_reshaper( reshaped_images = tfgan.eval.image_reshaper(
images[:100, ...], num_cols=10) images[:100, ...], num_cols=10)
uint8_images = data_provider.float_image_to_uint8(reshaped_images) uint8_images = data_provider.float_image_to_uint8(reshaped_images)
......
...@@ -62,6 +62,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None, ...@@ -62,6 +62,8 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run ' 'Number of times to run evaluation. If `None`, run '
'forever.') 'forever.')
flags.DEFINE_boolean('write_to_disk', True, 'If `True`, run images to disk.')
CAT_SAMPLE_POINTS = np.arange(0, 10) CAT_SAMPLE_POINTS = np.arange(0, 10)
CONT_SAMPLE_POINTS = np.linspace(-2.0, 2.0, 10) CONT_SAMPLE_POINTS = np.linspace(-2.0, 2.0, 10)
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -79,7 +81,9 @@ def main(_, run_eval_loop=True): ...@@ -79,7 +81,9 @@ def main(_, run_eval_loop=True):
# Visualize the effect of each structured noise dimension on the generated # Visualize the effect of each structured noise dimension on the generated
# image. # image.
generator_fn = lambda x: networks.infogan_generator(x, len(CAT_SAMPLE_POINTS)) def generator_fn(inputs):
return networks.infogan_generator(
inputs, len(CAT_SAMPLE_POINTS), is_training=False)
with tf.variable_scope('Generator') as genscope: # Same scope as in training. with tf.variable_scope('Generator') as genscope: # Same scope as in training.
categorical_images = generator_fn(display_noise1) categorical_images = generator_fn(display_noise1)
reshaped_categorical_img = tfgan.eval.image_reshaper( reshaped_categorical_img = tfgan.eval.image_reshaper(
...@@ -106,6 +110,7 @@ def main(_, run_eval_loop=True): ...@@ -106,6 +110,7 @@ def main(_, run_eval_loop=True):
# Write images to disk. # Write images to disk.
image_write_ops = [] image_write_ops = []
if FLAGS.write_to_disk:
image_write_ops.append(_get_write_image_ops( image_write_ops.append(_get_write_image_ops(
FLAGS.eval_dir, 'categorical_infogan.png', reshaped_categorical_img[0])) FLAGS.eval_dir, 'categorical_infogan.png', reshaped_categorical_img[0]))
image_write_ops.append(_get_write_image_ops( image_write_ops.append(_get_write_image_ops(
......
...@@ -26,7 +26,7 @@ tfgan = tf.contrib.gan ...@@ -26,7 +26,7 @@ tfgan = tf.contrib.gan
def _generator_helper( def _generator_helper(
noise, is_conditional, one_hot_labels, weight_decay): noise, is_conditional, one_hot_labels, weight_decay, is_training):
"""Core MNIST generator. """Core MNIST generator.
This function is reused between the different GAN modes (unconditional, This function is reused between the different GAN modes (unconditional,
...@@ -37,6 +37,9 @@ def _generator_helper( ...@@ -37,6 +37,9 @@ def _generator_helper(
is_conditional: Whether to condition on labels. is_conditional: Whether to condition on labels.
one_hot_labels: Optional labels for conditioning. one_hot_labels: Optional labels for conditioning.
weight_decay: The value of the l2 weight decay. weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A generated image in the range [-1, 1]. A generated image in the range [-1, 1].
...@@ -45,6 +48,8 @@ def _generator_helper( ...@@ -45,6 +48,8 @@ def _generator_helper(
[layers.fully_connected, layers.conv2d_transpose], [layers.fully_connected, layers.conv2d_transpose],
activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
weights_regularizer=layers.l2_regularizer(weight_decay)): weights_regularizer=layers.l2_regularizer(weight_decay)):
with tf.contrib.framework.arg_scope(
[layers.batch_norm], is_training=is_training):
net = layers.fully_connected(noise, 1024) net = layers.fully_connected(noise, 1024)
if is_conditional: if is_conditional:
net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels) net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
...@@ -60,34 +65,42 @@ def _generator_helper( ...@@ -60,34 +65,42 @@ def _generator_helper(
return net return net
def unconditional_generator(noise, weight_decay=2.5e-5): def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
"""Generator to produce unconditional MNIST images. """Generator to produce unconditional MNIST images.
Args: Args:
noise: A single Tensor representing noise. noise: A single Tensor representing noise.
weight_decay: The value of the l2 weight decay. weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A generated image in the range [-1, 1]. A generated image in the range [-1, 1].
""" """
return _generator_helper(noise, False, None, weight_decay) return _generator_helper(noise, False, None, weight_decay, is_training)
def conditional_generator(inputs, weight_decay=2.5e-5): def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
"""Generator to produce MNIST images conditioned on class. """Generator to produce MNIST images conditioned on class.
Args: Args:
inputs: A 2-tuple of Tensors (noise, one_hot_labels). inputs: A 2-tuple of Tensors (noise, one_hot_labels).
weight_decay: The value of the l2 weight decay. weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A generated image in the range [-1, 1]. A generated image in the range [-1, 1].
""" """
noise, one_hot_labels = inputs noise, one_hot_labels = inputs
return _generator_helper(noise, True, one_hot_labels, weight_decay) return _generator_helper(
noise, True, one_hot_labels, weight_decay, is_training)
def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5): def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5,
is_training=True):
"""InfoGAN generator network on MNIST digits. """InfoGAN generator network on MNIST digits.
Based on a paper https://arxiv.org/abs/1606.03657, their code Based on a paper https://arxiv.org/abs/1606.03657, their code
...@@ -99,6 +112,9 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5): ...@@ -99,6 +112,9 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5):
2D, and `inputs[1]` must be 1D. All must have the same first dimension. 2D, and `inputs[1]` must be 1D. All must have the same first dimension.
categorical_dim: Dimensions of the incompressible categorical noise. categorical_dim: Dimensions of the incompressible categorical noise.
weight_decay: The value of the l2 weight decay. weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics.
Returns: Returns:
A generated image in the range [-1, 1]. A generated image in the range [-1, 1].
...@@ -107,7 +123,7 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5): ...@@ -107,7 +123,7 @@ def infogan_generator(inputs, categorical_dim, weight_decay=2.5e-5):
cat_noise_onehot = tf.one_hot(cat_noise, categorical_dim) cat_noise_onehot = tf.one_hot(cat_noise, categorical_dim)
all_noise = tf.concat( all_noise = tf.concat(
[unstructured_noise, cat_noise_onehot, cont_noise], axis=1) [unstructured_noise, cat_noise_onehot, cont_noise], axis=1)
return _generator_helper(all_noise, False, None, weight_decay) return _generator_helper(all_noise, False, None, weight_decay, is_training)
_leaky_relu = lambda x: tf.nn.leaky_relu(x, alpha=0.01) _leaky_relu = lambda x: tf.nn.leaky_relu(x, alpha=0.01)
......
...@@ -24,7 +24,7 @@ from __future__ import print_function ...@@ -24,7 +24,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
ds = tf.contrib.distributions ds = tf.contrib.distributions
......
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
import numpy as np import numpy as np
import scipy.misc import scipy.misc
from six.moves import xrange from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from mnist import data_provider from mnist import data_provider
...@@ -66,10 +66,16 @@ def _get_predict_input_fn(batch_size, noise_dims): ...@@ -66,10 +66,16 @@ def _get_predict_input_fn(batch_size, noise_dims):
return predict_input_fn return predict_input_fn
def _unconditional_generator(noise, mode):
"""MNIST generator with extra argument for tf.Estimator's `mode`."""
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
return networks.unconditional_generator(noise, is_training=is_training)
def main(_): def main(_):
# Initialize GANEstimator with options and hyperparameters. # Initialize GANEstimator with options and hyperparameters.
gan_estimator = tfgan.estimator.GANEstimator( gan_estimator = tfgan.estimator.GANEstimator(
generator_fn=networks.unconditional_generator, generator_fn=_unconditional_generator,
discriminator_fn=networks.unconditional_discriminator, discriminator_fn=networks.unconditional_discriminator,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss, generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment