Commit 4dcd5116 authored by Jing Li's avatar Jing Li
Browse files

Removed deprecated research/gan. Please visit https://github.com/tensorflow/gan.

parent f673f7a8
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a GANEstimator on MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
import scipy.misc
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from mnist import data_provider
from mnist import networks
# Alias for the TF-GAN library shipped in tf.contrib.
tfgan = tf.contrib.gan

# Command-line configuration for the MNIST GANEstimator example.
flags.DEFINE_integer('batch_size', 32,
                     'The number of images in each train batch.')
flags.DEFINE_integer('max_number_of_steps', 20000,
                     'The maximum number of gradient steps.')
flags.DEFINE_integer(
    'noise_dims', 64, 'Dimensions of the generator noise vector')
flags.DEFINE_string('dataset_dir', None, 'Location of data.')
flags.DEFINE_string('eval_dir', '/tmp/mnist-estimator/',
                    'Directory where the results are saved to.')

FLAGS = flags.FLAGS
def _get_train_input_fn(batch_size, noise_dims, dataset_dir=None,
                        num_threads=4):
  """Returns an Estimator `input_fn` yielding (noise, real MNIST images)."""
  def train_input_fn():
    # Keep the data-reading pipeline on the CPU.
    with tf.device('/cpu:0'):
      images, _, _ = data_provider.provide_data(
          'train', batch_size, dataset_dir, num_threads=num_threads)
    noise = tf.random_normal([batch_size, noise_dims])
    return noise, images
  return train_input_fn
def _get_predict_input_fn(batch_size, noise_dims):
  """Returns an Estimator `input_fn` that yields generator noise only."""
  def predict_input_fn():
    return tf.random_normal([batch_size, noise_dims])
  return predict_input_fn
def _unconditional_generator(noise, mode):
  """MNIST generator with extra argument for tf.Estimator's `mode`."""
  # Only enable training-specific behavior (e.g. batch norm updates) in
  # TRAIN mode.
  return networks.unconditional_generator(
      noise, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
def main(_):
  """Trains a GANEstimator on MNIST and writes a tiled image of samples."""
  # Initialize GANEstimator with options and hyperparameters.
  gan_estimator = tfgan.estimator.GANEstimator(
      generator_fn=_unconditional_generator,
      discriminator_fn=networks.unconditional_discriminator,
      generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
      generator_optimizer=tf.train.AdamOptimizer(0.001, 0.5),
      discriminator_optimizer=tf.train.AdamOptimizer(0.0001, 0.5),
      add_summaries=tfgan.estimator.SummaryType.IMAGES)

  # Train estimator.
  train_input_fn = _get_train_input_fn(
      FLAGS.batch_size, FLAGS.noise_dims, FLAGS.dataset_dir)
  gan_estimator.train(train_input_fn, max_steps=FLAGS.max_number_of_steps)

  # Run inference.
  predict_input_fn = _get_predict_input_fn(36, FLAGS.noise_dims)
  prediction_iterable = gan_estimator.predict(predict_input_fn)
  # BUG FIX: `iterator.next()` exists only in Python 2; the builtin `next()`
  # works on both 2 and 3, matching this file's __future__/six 2-and-3
  # compatibility imports.
  predictions = [next(prediction_iterable) for _ in xrange(36)]

  # Nicely tile the 36 samples into a 6x6 grid.
  image_rows = [np.concatenate(predictions[i:i+6], axis=0) for i in
                range(0, 36, 6)]
  tiled_image = np.concatenate(image_rows, axis=1)

  # Write to disk.
  if not tf.gfile.Exists(FLAGS.eval_dir):
    tf.gfile.MakeDirs(FLAGS.eval_dir)
  # NOTE(review): scipy.misc.imsave is removed in modern SciPy; kept here for
  # parity with the pinned environment this example targets.
  scipy.misc.imsave(os.path.join(FLAGS.eval_dir, 'unconditional_gan.png'),
                    np.squeeze(tiled_image, axis=2))


if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mnist_estimator.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
import train
FLAGS = flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):

  @mock.patch.object(train, 'data_provider', autospec=True)
  def test_full_flow(self, mock_data_provider):
    """Smoke test: the estimator example runs end to end on mocked data."""
    FLAGS.eval_dir = self.get_temp_dir()
    FLAGS.batch_size = 16
    FLAGS.max_number_of_steps = 2
    FLAGS.noise_dims = 3

    # Construct mock inputs: zero images plus one-hot labels for class 0.
    mock_imgs = np.zeros([FLAGS.batch_size, 28, 28, 1], dtype=np.float32)
    mock_lbls = np.concatenate(
        (np.ones([FLAGS.batch_size, 1], dtype=np.int32),
         np.zeros([FLAGS.batch_size, 9], dtype=np.int32)), axis=1)
    mock_data_provider.provide_data.return_value = (mock_imgs, mock_lbls, None)

    train.main(None)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Imagenet dataset.
# 2. Trains image compression model on patches from Imagenet.
# 3. Evaluates the models and writes sample images to disk.
#
# Usage:
# cd models/research/gan/image_compression
# ./launch_jobs.sh ${weight_factor} ${git_repo}
set -e

# Weight of the adversarial loss.
weight_factor=$1
if [[ "$weight_factor" == "" ]]; then
  echo "'weight_factor' must not be empty."
  # BUG FIX: a bare `exit` returns status 0, so callers (and `set -e`
  # pipelines) would see success despite the missing argument.
  exit 1
fi

# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
  echo "'git_repo' must not be empty."
  exit 1  # BUG FIX: report failure to the caller (was plain `exit`).
fi

# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/compression-model

# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/compression-model/eval

# Where the dataset is saved to.
DATASET_DIR=/tmp/imagenet-data

export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/slim:$git_repo/research/slim/nets

# A helper function for printing pretty output.
Banner () {
  local text=$1
  local green='\033[0;32m'
  local nc='\033[0m'  # No color.
  echo -e "${green}${text}${nc}"
}

# Download the dataset.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet" ${DATASET_DIR}

# Run the pix2pix model.
NUM_STEPS=10000
MODEL_TRAIN_DIR="${TRAIN_DIR}/wt${weight_factor}"
Banner "Starting training an image compression model for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/image_compression/train.py" \
  --train_log_dir=${MODEL_TRAIN_DIR} \
  --dataset_dir=${DATASET_DIR} \
  --max_number_of_steps=${NUM_STEPS} \
  --weight_factor=${weight_factor} \
  --alsologtostderr
Banner "Finished training pix2pix model ${NUM_STEPS} steps."
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN Pix2Pix example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.nets import cyclegan
from slim.nets import pix2pix
def generator(input_images):
  """Thin wrapper around CycleGAN generator to conform to the TFGAN API.

  Args:
    input_images: A batch of images to translate. Images should be normalized
      already. Shape is [batch, height, width, channels].

  Returns:
    Returns generated image batch.

  Raises:
    ValueError: If shape of last dimension (channels) is not defined.
  """
  input_images.shape.assert_has_rank(4)
  static_shape = input_images.shape.as_list()
  num_channels = static_shape[-1]
  if num_channels is None:
    raise ValueError(
        'Last dimension shape must be known but is None: %s' % static_shape)
  # Reuse slim's CycleGAN resnet generator, keeping the channel count.
  with tf.contrib.framework.arg_scope(cyclegan.cyclegan_arg_scope()):
    output_images, _ = cyclegan.cyclegan_generator_resnet(
        input_images, num_outputs=num_channels)
  return output_images
def discriminator(image_batch, unused_conditioning=None):
  """A thin wrapper around the Pix2Pix discriminator to conform to TFGAN API."""
  with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
    logits_4d, _ = pix2pix.pix2pix_discriminator(
        image_batch, num_filters=[64, 128, 256, 512])
  logits_4d.shape.assert_has_rank(4)
  # Output of logits is 4D. Reshape to 2D, for TFGAN.
  return tf.contrib.layers.flatten(logits_4d)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.networks.networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import networks
class Pix2PixTest(tf.test.TestCase):
  """Tests for the pix2pix generator/discriminator wrappers."""

  def test_generator_run(self):
    img_batch = tf.zeros([3, 128, 128, 3])
    model_output = networks.generator(img_batch)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(model_output)

  def test_generator_graph(self):
    # Output shape must match input shape for several sizes.
    for shape in ([4, 32, 32], [3, 128, 128], [2, 80, 400]):
      tf.reset_default_graph()
      img = tf.ones(shape + [3])
      output_imgs = networks.generator(img)
      self.assertAllEqual(shape + [3], output_imgs.shape.as_list())

  def test_generator_graph_unknown_batch_dim(self):
    img = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    output_imgs = networks.generator(img)
    self.assertAllEqual([None, 32, 32, 3], output_imgs.shape.as_list())

  def test_generator_invalid_input(self):
    with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
      networks.generator(tf.zeros([28, 28, 3]))

  def test_generator_run_multi_channel(self):
    img_batch = tf.zeros([3, 128, 128, 5])
    model_output = networks.generator(img_batch)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(model_output)

  def test_generator_invalid_channels(self):
    with self.assertRaisesRegexp(
        ValueError, 'Last dimension shape must be known but is None'):
      img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
      networks.generator(img)

  def test_discriminator_run(self):
    img_batch = tf.zeros([3, 70, 70, 3])
    disc_output = networks.discriminator(img_batch)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(disc_output)

  def test_discriminator_graph(self):
    # Check graph construction for a number of image size/depths and batch
    # sizes.
    for batch_size, patch_size in zip([3, 6], [70, 128]):
      tf.reset_default_graph()
      img = tf.ones([batch_size, patch_size, patch_size, 3])
      disc_output = networks.discriminator(img)
      self.assertEqual(2, disc_output.shape.ndims)
      self.assertEqual(batch_size, disc_output.shape.as_list()[0])

  def test_discriminator_invalid_input(self):
    with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
      networks.discriminator(tf.zeros([28, 28, 3]))


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains an image-to-image translation network with an adversarial loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
import data_provider
import networks
# Alias for the TF-GAN library shipped in tf.contrib.
tfgan = tf.contrib.gan

# Command-line configuration for the pix2pix training binary.
flags.DEFINE_integer('batch_size', 10, 'The number of images in each batch.')
flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('train_log_dir', '/tmp/pix2pix/',
                    'Directory where to write event logs.')
flags.DEFINE_float('generator_lr', 0.00001,
                   'The compression model learning rate.')
flags.DEFINE_float('discriminator_lr', 0.00001,
                   'The discriminator learning rate.')
flags.DEFINE_integer('max_number_of_steps', 2000000,
                     'The maximum number of gradient steps.')
flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')
flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')
flags.DEFINE_float(
    'weight_factor', 0.0,
    'How much to weight the adversarial loss relative to pixel loss.')
flags.DEFINE_string('dataset_dir', None, 'Location of data.')

FLAGS = flags.FLAGS
def main(_):
  """Builds the pix2pix training graph and runs the alternating train loop."""
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    # Get real and distorted images.
    with tf.device('/cpu:0'), tf.name_scope('inputs'):
      real_images = data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
          patch_size=FLAGS.patch_size)
    distorted_images = _distort_images(
        real_images, downscale_size=int(FLAGS.patch_size / 2),
        upscale_size=FLAGS.patch_size)

    # Create a GANModel tuple.
    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=real_images,
        generator_inputs=distorted_images)
    tfgan.eval.add_image_comparison_summaries(
        gan_model, num_comparisons=3, display_diffs=True)
    tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=3)

    # Define the GANLoss tuple using standard library functions.
    with tf.name_scope('losses'):
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.least_squares_generator_loss,
          discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)

      # Define the standard L1 pixel loss.
      l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
                              ord=1) / FLAGS.patch_size ** 2

      # Modify the loss tuple to include the pixel loss. Add summaries as well.
      gan_loss = tfgan.losses.combine_adversarial_loss(
          gan_loss, gan_model, l1_pixel_loss,
          weight_factor=FLAGS.weight_factor)

    with tf.name_scope('train_ops'):
      # Get the GANTrain ops using the custom optimizers and optional
      # discriminator weight clipping.
      gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
      gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
      train_ops = tfgan.gan_train_ops(
          gan_model,
          gan_loss,
          generator_optimizer=gen_opt,
          discriminator_optimizer=dis_opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True,
          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N,
          transform_grads_fn=tf.contrib.training.clip_gradient_norms_fn(1e3))
      tf.summary.scalar('generator_lr', gen_lr)
      tf.summary.scalar('discriminator_lr', dis_lr)

    # Use GAN train step function if using adversarial loss, otherwise
    # only train the generator.
    train_steps = tfgan.GANTrainSteps(
        generator_train_steps=1,
        discriminator_train_steps=int(FLAGS.weight_factor > 0))

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    if FLAGS.max_number_of_steps == 0: return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
               tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def _optimizer(gen_lr, dis_lr):
  """Returns Adam optimizers (generator, discriminator) with GAN betas."""
  adam_kwargs = {'beta1': 0.5, 'beta2': 0.999}
  return (tf.train.AdamOptimizer(gen_lr, **adam_kwargs),
          tf.train.AdamOptimizer(dis_lr, **adam_kwargs))
def _lr(gen_lr_base, dis_lr_base):
  """Return the generator and discriminator learning rates."""
  # Generator LR decays in steps; discriminator LR stays constant.
  gen_lr = tf.train.exponential_decay(
      learning_rate=gen_lr_base,
      global_step=tf.train.get_or_create_global_step(),
      decay_steps=100000,
      decay_rate=0.8,
      staircase=True,)
  return gen_lr, dis_lr_base
def _distort_images(images, downscale_size, upscale_size):
  """Degrades images by box down-scaling then up-scaling them back."""
  shrunk = tf.image.resize_area(images, [downscale_size] * 2)
  return tf.image.resize_area(shrunk, [upscale_size] * 2)


if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pix2pix.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import train
FLAGS = flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase, parameterized.TestCase):
  """Graph-construction tests for the pix2pix training binary."""

  @parameterized.named_parameters(
      ('NoAdversarialLoss', 0.0),
      ('AdversarialLoss', 1.0))
  def test_build_graph(self, weight_factor):
    # max_number_of_steps=0 builds the graph without running training.
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    mock_imgs = np.zeros(
        [FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3],
        dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as mock_data_provider:
      mock_data_provider.provide_data.return_value = mock_imgs
      train.main(None)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loading and preprocessing image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
def normalize_image(image):
  """Rescales image from range [0, 255] to [-1, 1]."""
  # Center at 127.5 then scale, keeping the exact original arithmetic.
  return (tf.to_float(image) - 127.5) / 127.5
def sample_patch(image, patch_height, patch_width, colors):
  """Crops image to the desired aspect ratio shape and resizes it.

  If the image has shape H x W, crops a square in the center of
  shape min(H,W) x min(H,W).

  Args:
    image: A 3D `Tensor` of HWC format.
    patch_height: A Python integer. The output images height.
    patch_width: A Python integer. The output images width.
    colors: Number of output image channels. Defaults to 3.

  Returns:
    A 3D `Tensor` of HWC format with shape [patch_height, patch_width, colors].
  """
  image_shape = tf.shape(image)
  h, w = image_shape[0], image_shape[1]

  # Two candidate crop sizes with the target aspect ratio: one that keeps the
  # full height, and one that keeps the full width.
  h_major_target_h = h
  h_major_target_w = tf.maximum(1, tf.to_int32(
      (h * patch_width) / patch_height))
  w_major_target_h = tf.maximum(1, tf.to_int32(
      (w * patch_height) / patch_width))
  w_major_target_w = w
  # Pick whichever candidate fits inside the image.
  target_hw = tf.cond(
      h_major_target_w <= w,
      lambda: tf.convert_to_tensor([h_major_target_h, h_major_target_w]),
      lambda: tf.convert_to_tensor([w_major_target_h, w_major_target_w]))

  # Cut a patch of shape (target_h, target_w).
  image = tf.image.resize_image_with_crop_or_pad(image, target_hw[0],
                                                 target_hw[1])
  # Resize the cropped image to (patch_h, patch_w).
  image = tf.image.resize_images([image], [patch_height, patch_width])[0]

  # Force number of channels: repeat the channel dimension enough
  # number of times and then slice the first `colors` channels.
  num_repeats = tf.to_int32(tf.ceil(colors / image_shape[2]))
  image = tf.tile(image, [1, 1, num_repeats])
  image = tf.slice(image, [0, 0, 0], [-1, -1, colors])
  image.set_shape([patch_height, patch_width, colors])
  return image
def batch_images(image, patch_height, patch_width, colors, batch_size, shuffle,
                 num_threads):
  """Creates a batch of images.

  Args:
    image: A 3D `Tensor` of HWC format.
    patch_height: A Python integer. The output images height.
    patch_width: A Python integer. The output images width.
    colors: Number of channels.
    batch_size: The number of images in each minibatch. Defaults to 32.
    shuffle: Whether to shuffle the read images.
    num_threads: Number of prefetching threads.

  Returns:
    A float `Tensor`s with shape [batch_size, patch_height, patch_width, colors]
    representing a batch of images.
  """
  patch = sample_patch(image, patch_height, patch_width, colors)
  if shuffle:
    images = tf.train.shuffle_batch(
        [patch],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size,
        min_after_dequeue=batch_size)
  else:
    images = tf.train.batch(
        [patch],
        batch_size=batch_size,
        num_threads=1,  # no threads so it's deterministic
        capacity=5 * batch_size)
  images.set_shape([batch_size, patch_height, patch_width, colors])
  return images
def provide_data(dataset_name='cifar10',
                 split_name='train',
                 # BUG FIX: `dataset_dir` had no default while the preceding
                 # parameters did, which is a SyntaxError in Python
                 # ("non-default argument follows default argument"). The
                 # docstring already documents `None` as "use default", and
                 # existing callers pass it by keyword, so this default is
                 # backward compatible.
                 dataset_dir=None,
                 batch_size=32,
                 shuffle=True,
                 num_threads=1,
                 patch_height=32,
                 patch_width=32,
                 colors=3):
  """Provides a batch of image data from predefined dataset.

  Args:
    dataset_name: A string of dataset name. Defaults to 'cifar10'.
    split_name: Either 'train' or 'validation'. Defaults to 'train'.
    dataset_dir: The directory where the data can be found. If `None`, use
      default.
    batch_size: The number of images in each minibatch. Defaults to 32.
    shuffle: Whether to shuffle the read images. Defaults to True.
    num_threads: Number of prefetching threads. Defaults to 1.
    patch_height: A Python integer. The read images height. Defaults to 32.
    patch_width: A Python integer. The read images width. Defaults to 32.
    colors: Number of channels. Defaults to 3.

  Returns:
    A float `Tensor`s with shape [batch_size, patch_height, patch_width, colors]
    representing a batch of images.
  """
  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=1,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=shuffle)
  return batch_images(
      image=normalize_image(provider.get(['image'])[0]),
      patch_height=patch_height,
      patch_width=patch_width,
      colors=colors,
      batch_size=batch_size,
      shuffle=shuffle,
      num_threads=num_threads)
def provide_data_from_image_files(file_pattern,
                                  batch_size=32,
                                  shuffle=True,
                                  num_threads=1,
                                  patch_height=32,
                                  patch_width=32,
                                  colors=3):
  """Provides a batch of image data from image files.

  Args:
    file_pattern: A file pattern (glob), or 1D `Tensor` of file patterns.
    batch_size: The number of images in each minibatch. Defaults to 32.
    shuffle: Whether to shuffle the read images. Defaults to True.
    num_threads: Number of prefetching threads. Defaults to 1.
    patch_height: A Python integer. The read images height. Defaults to 32.
    patch_width: A Python integer. The read images width. Defaults to 32.
    colors: Number of channels. Defaults to 3.

  Returns:
    A float `Tensor` of shape [batch_size, patch_height, patch_width, 3]
    representing a batch of images.
  """
  # Queue of matching filenames, then raw file bytes from a whole-file reader.
  filename_queue = tf.train.string_input_producer(
      tf.train.match_filenames_once(file_pattern),
      shuffle=shuffle,
      capacity=5 * batch_size)
  _, image_bytes = tf.WholeFileReader().read(filename_queue)
  decoded = normalize_image(tf.image.decode_image(image_bytes))
  return batch_images(
      image=decoded,
      patch_height=patch_height,
      patch_width=patch_width,
      colors=colors,
      batch_size=batch_size,
      shuffle=shuffle,
      num_threads=num_threads)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):
  """Tests for the progressive GAN data_provider module."""

  def setUp(self):
    super(DataProviderTest, self).setUp()
    self.testdata_dir = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/progressive_gan/testdata/')

  def test_normalize_image(self):
    image_np = np.asarray([0, 255, 210], dtype=np.uint8)
    normalized_image = data_provider.normalize_image(tf.constant(image_np))
    self.assertEqual(normalized_image.dtype, tf.float32)
    self.assertEqual(normalized_image.shape.as_list(), [3])
    with self.test_session(use_gpu=True) as sess:
      normalized_image_np = sess.run(normalized_image)
    self.assertNDArrayNear(normalized_image_np, [-1, 1, 0.6470588235], 1.0e-6)

  def test_sample_patch_large_patch_returns_upscaled_image(self):
    image_np = np.reshape(np.arange(2 * 2), [2, 2, 1])
    image = tf.constant(image_np, dtype=tf.float32)
    image_patch = data_provider.sample_patch(
        image, patch_height=3, patch_width=3, colors=1)
    with self.test_session(use_gpu=True) as sess:
      image_patch_np = sess.run(image_patch)
    expected_np = np.asarray([[[0.], [0.66666669], [1.]], [[1.33333337], [2.],
                                                           [2.33333349]],
                              [[2.], [2.66666675], [3.]]])
    self.assertNDArrayNear(image_patch_np, expected_np, 1.0e-6)

  def test_sample_patch_small_patch_returns_downscaled_image(self):
    image_np = np.reshape(np.arange(3 * 3), [3, 3, 1])
    image = tf.constant(image_np, dtype=tf.float32)
    image_patch = data_provider.sample_patch(
        image, patch_height=2, patch_width=2, colors=1)
    with self.test_session(use_gpu=True) as sess:
      image_patch_np = sess.run(image_patch)
    expected_np = np.asarray([[[0.], [1.5]], [[4.5], [6.]]])
    self.assertNDArrayNear(image_patch_np, expected_np, 1.0e-6)

  def test_batch_images(self):
    image_np = np.reshape(np.arange(3 * 3), [3, 3, 1])
    image = tf.constant(image_np, dtype=tf.float32)
    images = data_provider.batch_images(
        image,
        patch_height=2,
        patch_width=2,
        colors=1,
        batch_size=2,
        shuffle=False,
        num_threads=1)
    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_np = sess.run(images)
    expected_np = np.asarray([[[[0.], [1.5]], [[4.5], [6.]]], [[[0.], [1.5]],
                                                               [[4.5], [6.]]]])
    self.assertNDArrayNear(images_np, expected_np, 1.0e-6)

  def test_provide_data(self):
    images = data_provider.provide_data(
        'mnist',
        'train',
        dataset_dir=self.testdata_dir,
        batch_size=2,
        shuffle=False,
        patch_height=3,
        patch_width=3,
        colors=1)
    self.assertEqual(images.shape.as_list(), [2, 3, 3, 1])
    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_np = sess.run(images)
    self.assertEqual(images_np.shape, (2, 3, 3, 1))

  def test_provide_data_from_image_files_a_single_pattern(self):
    file_pattern = os.path.join(self.testdata_dir, '*.jpg')
    images = data_provider.provide_data_from_image_files(
        file_pattern,
        batch_size=2,
        shuffle=False,
        patch_height=3,
        patch_width=3,
        colors=1)
    self.assertEqual(images.shape.as_list(), [2, 3, 3, 1])
    with self.test_session(use_gpu=True) as sess:
      # match_filenames_once uses a local variable; initialize it first.
      sess.run(tf.local_variables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_np = sess.run(images)
    self.assertEqual(images_np.shape, (2, 3, 3, 1))

  def test_provide_data_from_image_files_a_list_of_patterns(self):
    file_pattern = [os.path.join(self.testdata_dir, '*.jpg')]
    images = data_provider.provide_data_from_image_files(
        file_pattern,
        batch_size=2,
        shuffle=False,
        patch_height=3,
        patch_width=3,
        colors=1)
    self.assertEqual(images.shape.as_list(), [2, 3, 3, 1])
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.local_variables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_np = sess.run(images)
    self.assertEqual(images_np.shape, (2, 3, 3, 1))


if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for a progressive GAN model.
This module contains basic building blocks to build a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def pixel_norm(images, epsilon=1.0e-8):
  """Normalizes each pixel of `images` across its channels.

  For each pixel a[i,j,k] of an image in HWC format the value becomes
  b[i,j,k] = a[i,j,k] / SQRT(SUM_k(a[i,j,k]^2) / C + eps).

  Args:
    images: A 4D `Tensor` of NHWC format.
    epsilon: A small positive number to avoid division by zero.

  Returns:
    A 4D `Tensor` with pixel-wise normalized channels.
  """
  # Mean of squared channel values, kept as a [N, H, W, 1] tensor so it
  # broadcasts back over the channel axis.
  mean_square = tf.reduce_mean(tf.square(images), axis=3, keepdims=True)
  return images * tf.rsqrt(mean_square + epsilon)
def _get_validated_scale(scale):
"""Returns the scale guaranteed to be a positive integer."""
scale = int(scale)
if scale <= 0:
raise ValueError('`scale` must be a positive integer.')
return scale
def downscale(images, scale):
  """Box downscaling of images.

  Args:
    images: A 4D `Tensor` in NHWC format.
    scale: A positive integer scale.

  Returns:
    A 4D `Tensor` of `images` down scaled by a factor `scale`.

  Raises:
    ValueError: If `scale` is not a positive integer.
  """
  scale = _get_validated_scale(scale)
  # Average pooling over non-overlapping scale x scale windows implements
  # box downscaling; a scale of 1 is the identity.
  if scale > 1:
    images = tf.nn.avg_pool(
        images,
        ksize=[1, scale, scale, 1],
        strides=[1, scale, scale, 1],
        padding='VALID')
  return images
def upscale(images, scale):
  """Box upscaling (also called nearest neighbors) of images.

  Args:
    images: A 4D `Tensor` in NHWC format.
    scale: A positive integer scale.

  Returns:
    A 4D `Tensor` of `images` up scaled by a factor `scale`.

  Raises:
    ValueError: If `scale` is not a positive integer.
  """
  scale = _get_validated_scale(scale)
  if scale == 1:
    return images
  # Replicate the batch scale^2 times, then let batch_to_space interleave
  # the copies spatially, which yields nearest-neighbor upscaling.
  tiled = tf.tile(images, [scale**2, 1, 1, 1])
  return tf.batch_to_space(
      tiled, crops=[[0, 0], [0, 0]], block_size=scale)
def minibatch_mean_stddev(x):
  """Computes the average of per-feature batch standard deviations.

  This is used by the discriminator as a form of batch discrimination.

  Args:
    x: A `Tensor` for which to compute the standard deviation average. The first
      dimension must be batch size.

  Returns:
    A scalar `Tensor`: the mean over all features of the standard deviation of
    `x` taken across the batch dimension.
  """
  # Only the per-feature variance across the batch axis is needed.
  _, variance = tf.nn.moments(x, axes=[0])
  return tf.reduce_mean(tf.sqrt(variance))
def scalar_concat(tensor, scalar):
  """Concatenates a scalar to the last dimension of a tensor.

  Args:
    tensor: A `Tensor`.
    scalar: a scalar `Tensor` to concatenate to tensor `tensor`.

  Returns:
    A `Tensor`. If `tensor` has shape [...,N], the result R has shape
    [...,N+1] and R[...,N] = scalar.

  Raises:
    ValueError: If `tensor` is a scalar `Tensor`.
  """
  ndims = tensor.shape.ndims
  if ndims < 1:
    raise ValueError('`tensor` must have number of dimensions >= 1.')
  shape = tf.shape(tensor)
  # Build a [..., 1] plane filled with `scalar` that matches the leading
  # dimensions of `tensor`, then append it along the last axis.
  leading_dims = [shape[i] for i in range(ndims - 1)]
  scalar_plane = tf.ones(leading_dims + [1]) * scalar
  return tf.concat([tensor, scalar_plane], axis=ndims - 1)
def he_initializer_scale(shape, slope=1.0):
"""The scale of He neural network initializer.
Args:
shape: A list of ints representing the dimensions of a tensor.
slope: A float representing the slope of the ReLu following the layer.
Returns:
A float of he initializer scale.
"""
fan_in = np.prod(shape[:-1])
return np.sqrt(2. / ((1. + slope**2) * fan_in))
def _custom_layer_impl(apply_kernel, kernel_shape, bias_shape, activation,
                       he_initializer_slope, use_weight_scaling):
  """Helper function to implement custom_xxx layer.

  Args:
    apply_kernel: A function that transforms kernel to output.
    kernel_shape: An integer tuple or list of the kernel shape.
    bias_shape: An integer tuple or list of the bias shape.
    activation: An activation function to be applied. None means no
      activation.
    he_initializer_slope: A float slope for the He initializer.
    use_weight_scaling: Whether to apply weight scaling.

  Returns:
    A `Tensor` computed as apply_kernel(kernel) + bias where kernel is a
    `Tensor` variable with shape `kernel_shape`, bias is a `Tensor` variable
    with shape `bias_shape`.
  """
  kernel_scale = he_initializer_scale(kernel_shape, he_initializer_slope)
  # With weight scaling the kernel is drawn with unit stddev and the He
  # scale is applied at run time instead, which equalizes learning rates.
  if use_weight_scaling:
    init_scale, post_scale = 1.0, kernel_scale
  else:
    init_scale, post_scale = kernel_scale, 1.0
  kernel_initializer = tf.random_normal_initializer(stddev=init_scale)

  bias = tf.get_variable(
      'bias', shape=bias_shape, initializer=tf.zeros_initializer())

  output = post_scale * apply_kernel(kernel_shape, kernel_initializer) + bias
  return output if activation is None else activation(output)
def custom_conv2d(x,
                  filters,
                  kernel_size,
                  strides=(1, 1),
                  padding='SAME',
                  activation=None,
                  he_initializer_slope=1.0,
                  use_weight_scaling=True,
                  scope='custom_conv2d',
                  reuse=None):
  """Custom conv2d layer.

  In comparison with tf.layers.conv2d this implementation uses the He
  initializer for the convolutional kernel and (when `use_weight_scaling` is
  True) the weight scaling trick to equalize learning rates. See
  https://arxiv.org/abs/1710.10196 for more details.

  Args:
    x: A `Tensor` of NHWC format.
    filters: An int of output channels.
    kernel_size: An integer or a int tuple of [kernel_height, kernel_width].
    strides: A list of strides.
    padding: One of "VALID" or "SAME".
    activation: An activation function to be applied. None means no
      activation. Defaults to None.
    he_initializer_slope: A float slope for the He initializer. Defaults to 1.0.
    use_weight_scaling: Whether to apply weight scaling. Defaults to True.
    scope: A string or variable scope.
    reuse: Whether to reuse the weights. Defaults to None.

  Returns:
    A `Tensor` of NHWC format where the last dimension has size `filters`.
  """
  # Accept a scalar kernel size and normalize everything to a 2-element list.
  if isinstance(kernel_size, (list, tuple)):
    kernel_size = list(kernel_size)
  else:
    kernel_size = [kernel_size] * 2

  def _apply_kernel(kernel_shape, kernel_initializer):
    # kernel_shape is [kh, kw, in_channels, out_channels]; only the spatial
    # part is handed to tf.layers.conv2d.
    return tf.layers.conv2d(
        x,
        filters=filters,
        kernel_size=kernel_shape[0:2],
        strides=strides,
        padding=padding,
        use_bias=False,
        kernel_initializer=kernel_initializer)

  with tf.variable_scope(scope, reuse=reuse):
    input_channels = x.shape.as_list()[3]
    return _custom_layer_impl(
        _apply_kernel,
        kernel_shape=kernel_size + [input_channels, filters],
        bias_shape=(filters,),
        activation=activation,
        he_initializer_slope=he_initializer_slope,
        use_weight_scaling=use_weight_scaling)
def custom_dense(x,
                 units,
                 activation=None,
                 he_initializer_slope=1.0,
                 use_weight_scaling=True,
                 scope='custom_dense',
                 reuse=None):
  """Custom dense layer.

  In comparison with tf.layers.dense this implementation uses the He
  initializer for the weights and (when `use_weight_scaling` is True) the
  weight scaling trick to equalize learning rates. See
  https://arxiv.org/abs/1710.10196 for more details.

  Args:
    x: A `Tensor`.
    units: An int of the last dimension size of output.
    activation: An activation function to be applied. None means no
      activation. Defaults to None.
    he_initializer_slope: A float slope for the He initializer. Defaults to 1.0.
    use_weight_scaling: Whether to apply weight scaling. Defaults to True.
    scope: A string or variable scope.
    reuse: Whether to reuse the weights. Defaults to None.

  Returns:
    A `Tensor` where the last dimension has size `units`.
  """
  # Collapse all non-batch dimensions so the dense layer sees a 2D input.
  x = tf.contrib.layers.flatten(x)

  def _apply_kernel(kernel_shape, kernel_initializer):
    return tf.layers.dense(
        x,
        kernel_shape[1],
        use_bias=False,
        kernel_initializer=kernel_initializer)

  with tf.variable_scope(scope, reuse=reuse):
    flattened_size = x.shape.as_list()[-1]
    return _custom_layer_impl(
        _apply_kernel,
        kernel_shape=(flattened_size, units),
        bias_shape=(units,),
        activation=activation,
        he_initializer_slope=he_initializer_slope,
        use_weight_scaling=use_weight_scaling)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import layers
# Alias for the mock module that ships with tf.test.
mock = tf.test.mock
def dummy_apply_kernel(kernel_shape, kernel_initializer):
  """Test stand-in for apply_kernel: sums a kernel variable and adds 1.5."""
  weights = tf.get_variable(
      'kernel', shape=kernel_shape, initializer=kernel_initializer)
  return 1.5 + tf.reduce_sum(weights)
class LayersTest(tf.test.TestCase):
  """Tests for the progressive GAN building blocks in `layers`."""

  def test_pixel_norm_4d_images_returns_channel_normalized_images(self):
    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    with self.test_session(use_gpu=True) as sess:
      output_np = sess.run(layers.pixel_norm(x))

    expected_np = [[[[0.46291006, 0.92582011, 1.38873017],
                     [0.78954202, 0.98692751, 1.18431306]],
                    [[0.87047803, 0.99483204, 1.11918604],
                     [0.90659684, 0.99725652, 1.08791625]]],
                   [[[0., 0., 0.], [-0.46291006, -0.92582011, -1.38873017]],
                    [[0.57735026, -1.15470052, 1.15470052],
                     [0.56195146, 1.40487862, 0.84292722]]]]
    self.assertNDArrayNear(output_np, expected_np, 1.0e-5)

  def test_get_validated_scale_invalid_scale_throws_exception(self):
    with self.assertRaises(ValueError):
      layers._get_validated_scale(0)

  def test_get_validated_scale_float_scale_returns_integer(self):
    self.assertEqual(layers._get_validated_scale(5.5), 5)

  def test_downscale_invalid_scale_throws_exception(self):
    with self.assertRaises(ValueError):
      layers.downscale(tf.constant([]), -1)

  def test_downscale_4d_images_returns_downscaled_images(self):
    x_np = np.array(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=np.float32)
    with self.test_session(use_gpu=True) as sess:
      x1_np, x2_np = sess.run(
          [layers.downscale(tf.constant(x_np), n) for n in [1, 2]])

    expected2_np = [[[[5.5, 6.5, 7.5]]], [[[0.5, 0.25, 0.5]]]]
    self.assertNDArrayNear(x1_np, x_np, 1.0e-5)
    self.assertNDArrayNear(x2_np, expected2_np, 1.0e-5)

  def test_upscale_invalid_scale_throws_exception(self):
    with self.assertRaises(ValueError):
      # BUGFIX: previously the raising call was wrapped in a spurious extra
      # `self.assertRaises(...)`, which was dead code — the call raises before
      # assertRaises is evaluated. The bare call inside the context is correct.
      layers.upscale(tf.constant([]), -1)

  def test_upscale_4d_images_returns_upscaled_images(self):
    x_np = np.array([[[[1, 2, 3]]], [[[4, 5, 6]]]], dtype=np.float32)
    with self.test_session(use_gpu=True) as sess:
      x1_np, x2_np = sess.run(
          [layers.upscale(tf.constant(x_np), n) for n in [1, 2]])

    expected2_np = [[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
                    [[[4, 5, 6], [4, 5, 6]], [[4, 5, 6], [4, 5, 6]]]]
    self.assertNDArrayNear(x1_np, x_np, 1.0e-5)
    self.assertNDArrayNear(x2_np, expected2_np, 1.0e-5)

  def test_minibatch_mean_stddev_4d_images_returns_scalar(self):
    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    with self.test_session(use_gpu=True) as sess:
      output_np = sess.run(layers.minibatch_mean_stddev(x))

    self.assertAlmostEqual(output_np, 3.0416667, 5)

  def test_scalar_concat_invalid_input_throws_exception(self):
    with self.assertRaises(ValueError):
      layers.scalar_concat(tf.constant(1.2), 2.0)

  def test_scalar_concat_4d_images_and_scalar(self):
    x = tf.constant(
        [[[[1, 2], [4, 5]], [[7, 8], [10, 11]]], [[[0, 0], [-1, -2]],
                                                  [[1, -2], [2, 5]]]],
        dtype=tf.float32)
    with self.test_session(use_gpu=True) as sess:
      output_np = sess.run(layers.scalar_concat(x, 7))

    expected_np = [[[[1, 2, 7], [4, 5, 7]], [[7, 8, 7], [10, 11, 7]]],
                   [[[0, 0, 7], [-1, -2, 7]], [[1, -2, 7], [2, 5, 7]]]]
    self.assertNDArrayNear(output_np, expected_np, 1.0e-5)

  def test_he_initializer_scale_slope_linear(self):
    self.assertAlmostEqual(
        layers.he_initializer_scale([3, 4, 5, 6], 1.0), 0.1290994, 5)

  def test_he_initializer_scale_slope_relu(self):
    self.assertAlmostEqual(
        layers.he_initializer_scale([3, 4, 5, 6], 0.0), 0.1825742, 5)

  @mock.patch.object(tf, 'random_normal_initializer', autospec=True)
  @mock.patch.object(tf, 'zeros_initializer', autospec=True)
  def test_custom_layer_impl_with_weight_scaling(
      self, mock_zeros_initializer, mock_random_normal_initializer):
    # Replace the random initializers with constants so the output is exact.
    mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
    mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)

    output = layers._custom_layer_impl(
        apply_kernel=dummy_apply_kernel,
        kernel_shape=(25, 6),
        bias_shape=(),
        activation=lambda x: 2.0 * x,
        he_initializer_slope=1.0,
        use_weight_scaling=True)
    mock_zeros_initializer.assert_called_once_with()
    # With weight scaling, the kernel is initialized with unit stddev.
    mock_random_normal_initializer.assert_called_once_with(stddev=1.0)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)

    self.assertAlmostEqual(output_np, 182.6, 3)

  @mock.patch.object(tf, 'random_normal_initializer', autospec=True)
  @mock.patch.object(tf, 'zeros_initializer', autospec=True)
  def test_custom_layer_impl_no_weight_scaling(self, mock_zeros_initializer,
                                               mock_random_normal_initializer):
    mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
    mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)

    output = layers._custom_layer_impl(
        apply_kernel=dummy_apply_kernel,
        kernel_shape=(25, 6),
        bias_shape=(),
        activation=lambda x: 2.0 * x,
        he_initializer_slope=1.0,
        use_weight_scaling=False)
    mock_zeros_initializer.assert_called_once_with()
    # Without weight scaling, the He scale goes into the initializer stddev.
    mock_random_normal_initializer.assert_called_once_with(stddev=0.2)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)

    self.assertAlmostEqual(output_np, 905.0, 3)

  @mock.patch.object(tf.layers, 'conv2d', autospec=True)
  def test_custom_conv2d_passes_conv2d_options(self, mock_conv2d):
    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    layers.custom_conv2d(x, 1, 2)
    mock_conv2d.assert_called_once_with(
        x,
        filters=1,
        kernel_size=[2, 2],
        strides=(1, 1),
        padding='SAME',
        use_bias=False,
        kernel_initializer=mock.ANY)

  @mock.patch.object(layers, '_custom_layer_impl', autospec=True)
  def test_custom_conv2d_passes_custom_layer_options(self,
                                                     mock_custom_layer_impl):
    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    layers.custom_conv2d(x, 1, 2)
    mock_custom_layer_impl.assert_called_once_with(
        mock.ANY,
        kernel_shape=[2, 2, 3, 1],
        bias_shape=(1,),
        activation=None,
        he_initializer_slope=1.0,
        use_weight_scaling=True)

  @mock.patch.object(tf, 'random_normal_initializer', autospec=True)
  @mock.patch.object(tf, 'zeros_initializer', autospec=True)
  def test_custom_conv2d_scalar_kernel_size(self, mock_zeros_initializer,
                                            mock_random_normal_initializer):
    mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
    mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)

    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    output = layers.custom_conv2d(x, 1, 2)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)

    expected_np = [[[[68.54998016], [42.56921768]], [[50.36344528],
                                                     [29.57883835]]],
                   [[[5.33012676], [4.46410179]], [[10.52627945],
                                                   [9.66025352]]]]
    self.assertNDArrayNear(output_np, expected_np, 1.0e-5)

  @mock.patch.object(tf, 'random_normal_initializer', autospec=True)
  @mock.patch.object(tf, 'zeros_initializer', autospec=True)
  def test_custom_conv2d_list_kernel_size(self, mock_zeros_initializer,
                                          mock_random_normal_initializer):
    mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
    mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)

    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    output = layers.custom_conv2d(x, 1, [2, 3])
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)

    expected_np = [[
        [[56.15432739], [56.15432739]],
        [[41.30508804], [41.30508804]],
    ], [[[4.53553391], [4.53553391]], [[8.7781744], [8.7781744]]]]
    self.assertNDArrayNear(output_np, expected_np, 1.0e-5)

  @mock.patch.object(layers, '_custom_layer_impl', autospec=True)
  def test_custom_dense_passes_custom_layer_options(self,
                                                    mock_custom_layer_impl):
    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    layers.custom_dense(x, 3)
    mock_custom_layer_impl.assert_called_once_with(
        mock.ANY,
        kernel_shape=(12, 3),
        bias_shape=(3,),
        activation=None,
        he_initializer_slope=1.0,
        use_weight_scaling=True)

  @mock.patch.object(tf, 'random_normal_initializer', autospec=True)
  @mock.patch.object(tf, 'zeros_initializer', autospec=True)
  def test_custom_dense_output_is_correct(self, mock_zeros_initializer,
                                          mock_random_normal_initializer):
    mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
    mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)

    x = tf.constant(
        [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
         [[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
        dtype=tf.float32)
    output = layers.custom_dense(x, 3)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      output_np = sess.run(output)

    expected_np = [[68.54998016, 68.54998016, 68.54998016],
                   [5.33012676, 5.33012676, 5.33012676]]
    self.assertNDArrayNear(output_np, expected_np, 1.0e-5)
# Run all tf.test.TestCase tests in this file when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generator and discriminator for a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
import layers
class ResolutionSchedule(object):
  """Image resolution upscaling schedule."""

  def __init__(self, start_resolutions=(4, 4), scale_base=2, num_resolutions=4):
    """Initializer.

    Args:
      start_resolutions: An tuple of integers of HxW format for start image
        resolutions. Defaults to (4, 4).
      scale_base: An integer of resolution base multiplier. Defaults to 2.
      num_resolutions: An integer of how many progressive resolutions (including
        `start_resolutions`). Defaults to 4.
    """
    self._start_resolutions = start_resolutions
    self._scale_base = scale_base
    self._num_resolutions = num_resolutions

  @property
  def start_resolutions(self):
    # Returned as a tuple so callers cannot mutate the stored value.
    return tuple(self._start_resolutions)

  @property
  def scale_base(self):
    return self._scale_base

  @property
  def num_resolutions(self):
    return self._num_resolutions

  @property
  def final_resolutions(self):
    """Returns the final resolutions."""
    growth = self._scale_base**(self._num_resolutions - 1)
    return tuple(r * growth for r in self._start_resolutions)

  def scale_factor(self, block_id):
    """Returns the scale factor for network block `block_id`."""
    if not 1 <= block_id <= self._num_resolutions:
      raise ValueError('`block_id` must be in [1, {}]'.format(
          self._num_resolutions))
    return self._scale_base**(self._num_resolutions - block_id)
def block_name(block_id):
  """Returns the scope name for the network block `block_id`."""
  return 'progressive_gan_block_' + str(block_id)
def min_total_num_images(stable_stage_num_images, transition_stage_num_images,
                         num_blocks):
  """Returns the minimum total number of images.

  Computes the minimum total number of images required to reach the desired
  `resolution`.

  Args:
    stable_stage_num_images: Number of images in the stable stage.
    transition_stage_num_images: Number of images in the transition stage.
    num_blocks: Number of network blocks.

  Returns:
    An integer of the minimum total number of images.
  """
  # Every block has a stable stage; only the num_blocks - 1 resolution jumps
  # have a transition stage.
  stable_total = num_blocks * stable_stage_num_images
  transition_total = (num_blocks - 1) * transition_stage_num_images
  return stable_total + transition_total
def compute_progress(current_image_id, stable_stage_num_images,
                     transition_stage_num_images, num_blocks):
  """Computes the training progress.

  The training alternates between stable phase and transition phase.
  The `progress` indicates the training progress, i.e. the training is at
  - a stable phase p if progress = p
  - a transition stage between p and p + 1 if progress = p + fraction
  where p = 0,1,2.,...

  Note the max value of progress is `num_blocks` - 1.

  In terms of LOD (of the original implementation):
  progress = `num_blocks` - 1 - LOD

  Args:
    current_image_id: An scalar integer `Tensor` of the current image id, count
      from 0.
    stable_stage_num_images: An integer representing the number of images in
      each stable stage.
    transition_stage_num_images: An integer representing the number of images in
      each transition stage.
    num_blocks: Number of network blocks.

  Returns:
    A scalar float `Tensor` of the training progress.
  """
  # Once current_image_id >= min_total_num_images - 1 the model is at the
  # highest resolution, so the id is capped to keep progress constant.
  max_image_id = min_total_num_images(
      stable_stage_num_images, transition_stage_num_images, num_blocks) - 1
  capped_image_id = tf.minimum(current_image_id, max_image_id)

  stage_num_images = stable_stage_num_images + transition_stage_num_images
  # Integer part: how many full (stable + transition) stages have passed.
  whole_stages = tf.floordiv(capped_image_id, stage_num_images)
  # Fractional part: position inside the current transition stage, clamped
  # to 0 while still in the stable stage.
  position_in_stage = tf.mod(capped_image_id, stage_num_images)
  fraction = tf.maximum(
      0.0,
      tf.to_float(position_in_stage - stable_stage_num_images) /
      tf.to_float(transition_stage_num_images))
  return tf.to_float(whole_stages) + fraction
def _generator_alpha(block_id, progress):
  """Returns the block output parameter for the generator network.

  The generator has N blocks with `block_id` = 1,2,...,N. Each block
  block_id outputs a fake data output(block_id). The generator output is a
  linear combination of all block outputs, i.e.
  SUM_block_id(output(block_id) * alpha(block_id, progress)) where
  alpha(block_id, progress) = _generator_alpha(block_id, progress). Note it
  guarantees that SUM_block_id(alpha(block_id, progress)) = 1 for any progress.

  With a fixed block_id, the plot of alpha(block_id, progress) against progress
  is a 'triangle' with its peak at (block_id - 1, 1).

  Args:
    block_id: An integer of generator block id.
    progress: A scalar float `Tensor` of training progress.

  Returns:
    A scalar float `Tensor` of block output parameter.
  """
  # Ramps up from 0 to 1 as progress goes block_id-2 -> block_id-1, then
  # back down to 0 at block_id; clamped at 0 elsewhere.
  rising = progress - (block_id - 2)
  falling = block_id - progress
  return tf.maximum(0.0, tf.minimum(rising, falling))
def _discriminator_alpha(block_id, progress):
  """Returns the block input parameter for discriminator network.

  The discriminator has N blocks with `block_id` = 1,2,...,N. Each block
  block_id accepts an
  - input(block_id) transformed from the real data and
  - the output of block block_id + 1, i.e. output(block_id + 1)
  The final input is a linear combination of them,
  i.e. alpha * input(block_id) + (1 - alpha) * output(block_id + 1)
  where alpha = _discriminator_alpha(block_id, progress).

  With a fixed block_id, alpha(block_id, progress) stays to be 1
  when progress <= block_id - 1, then linear decays to 0 when
  block_id - 1 < progress <= block_id, and finally stays at 0
  when progress > block_id.

  Args:
    block_id: An integer of generator block id.
    progress: A scalar float `Tensor` of training progress.

  Returns:
    A scalar float `Tensor` of block input parameter.
  """
  # The linear decay, clipped into [0, 1].
  alpha = block_id - progress
  return tf.clip_by_value(alpha, 0.0, 1.0)
def blend_images(x, progress, resolution_schedule, num_blocks):
  """Blends images of different resolutions according to `progress`.

  When training `progress` is at a stable stage for resolution r, returns
  image `x` downscaled to resolution r and then upscaled to `final_resolutions`,
  call it x'(r).

  Otherwise when training `progress` is at a transition stage from resolution
  r to 2r, returns a linear combination of x'(r) and x'(2r).

  Args:
    x: An image `Tensor` of NHWC format with resolution `final_resolutions`.
    progress: A scalar float `Tensor` of training progress.
    resolution_schedule: An object of `ResolutionSchedule`.
    num_blocks: An integer of number of blocks.

  Returns:
    An image `Tensor` which is a blend of images of different resolutions.
  """
  blended = []
  for block_id in range(1, num_blocks + 1):
    weight = _generator_alpha(block_id, progress)
    scale = resolution_schedule.scale_factor(block_id)
    # Down- then up-scaling produces x at block_id's effective resolution.
    resampled = layers.upscale(layers.downscale(x, scale), scale)
    blended.append(weight * resampled)
  return tf.add_n(blended)
def num_filters(block_id, fmap_base=4096, fmap_decay=1.0, fmap_max=256):
  """Computes number of filters of block `block_id`."""
  # Filter count decays exponentially with depth, capped at fmap_max.
  decayed = fmap_base / math.pow(2.0, block_id * fmap_decay)
  return int(min(decayed, fmap_max))
def generator(z,
              progress,
              num_filters_fn,
              resolution_schedule,
              num_blocks=None,
              kernel_size=3,
              colors=3,
              to_rgb_activation=None,
              scope='progressive_gan_generator',
              reuse=None):
  """Generator network for the progressive GAN model.

  Args:
    z: A `Tensor` of latent vector. The first dimension must be batch size.
    progress: A scalar float `Tensor` of training progress.
    num_filters_fn: A function that maps `block_id` to # of filters for the
      block.
    resolution_schedule: An object of `ResolutionSchedule`.
    num_blocks: An integer of number of blocks. None means maximum number of
      blocks, i.e. `resolution.schedule.num_resolutions`. Defaults to None.
    kernel_size: An integer of convolution kernel size.
    colors: Number of output color channels. Defaults to 3.
    to_rgb_activation: Activation function applied when output rgb.
    scope: A string or variable scope.
    reuse: Whether to reuse `scope`. Defaults to None which means to inherit
      the reuse option of the parent scope.

  Returns:
    A `Tensor` of model output and a dictionary of model end points.
  """
  if num_blocks is None:
    num_blocks = resolution_schedule.num_resolutions
  start_h, start_w = resolution_schedule.start_resolutions
  final_h, final_w = resolution_schedule.final_resolutions

  # Pixel-normalized leaky-ReLU conv used inside every generator block.
  def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
    return layers.custom_conv2d(
        x=x,
        filters=filters,
        kernel_size=kernel_size,
        padding=padding,
        activation=lambda x: layers.pixel_norm(tf.nn.leaky_relu(x)),
        he_initializer_slope=0.0,
        scope=scope)

  # 1x1 conv projecting features to `colors` output channels.
  def _to_rgb(x):
    return layers.custom_conv2d(
        x=x,
        filters=colors,
        kernel_size=1,
        padding='SAME',
        activation=to_rgb_activation,
        scope='to_rgb')

  end_points = {}
  with tf.variable_scope(scope, reuse=reuse):
    with tf.name_scope('input'):
      x = tf.contrib.layers.flatten(z)
      end_points['latent_vector'] = x
    with tf.variable_scope(block_name(1)):
      # Treat the latent vector as a 1 x 1 "image" with ndim(z) channels.
      x = tf.expand_dims(tf.expand_dims(x, 1), 1)
      x = layers.pixel_norm(x)
      # Pad the 1 x 1 image to 2 * (start_h - 1) x 2 * (start_w - 1)
      # with zeros for the next conv.
      x = tf.pad(x, [[0] * 2, [start_h - 1] * 2, [start_w - 1] * 2, [0] * 2])
      # The output is start_h x start_w x num_filters_fn(1).
      x = _conv2d('conv0', x, (start_h, start_w), num_filters_fn(1), 'VALID')
      x = _conv2d('conv1', x, kernel_size, num_filters_fn(1))
      lods = [x]
    # Each subsequent block doubles (scale_base) the resolution.
    for block_id in range(2, num_blocks + 1):
      with tf.variable_scope(block_name(block_id)):
        x = layers.upscale(x, resolution_schedule.scale_base)
        x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
        x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id))
        lods.append(x)
    # Blend every block's RGB output, weighted by its progress-dependent
    # alpha, into the final prediction at the final resolution.
    outputs = []
    for block_id in range(1, num_blocks + 1):
      with tf.variable_scope(block_name(block_id)):
        lod = _to_rgb(lods[block_id - 1])
        scale = resolution_schedule.scale_factor(block_id)
        lod = layers.upscale(lod, scale)
        end_points['upscaled_rgb_{}'.format(block_id)] = lod
        # alpha_i is used to replace lod_select. Note sum(alpha_i) is
        # guaranteed to be 1.
        alpha = _generator_alpha(block_id, progress)
        end_points['alpha_{}'.format(block_id)] = alpha
        outputs.append(lod * alpha)
    predictions = tf.add_n(outputs)
    batch_size = z.shape[0].value
    predictions.set_shape([batch_size, final_h, final_w, colors])
    end_points['predictions'] = predictions
  return predictions, end_points
def discriminator(x,
                  progress,
                  num_filters_fn,
                  resolution_schedule,
                  num_blocks=None,
                  kernel_size=3,
                  scope='progressive_gan_discriminator',
                  reuse=None):
  """Discriminator network for the progressive GAN model.

  Args:
    x: A `Tensor`of NHWC format representing images of size `resolution`.
    progress: A scalar float `Tensor` of training progress.
    num_filters_fn: A function that maps `block_id` to # of filters for the
      block.
    resolution_schedule: An object of `ResolutionSchedule`.
    num_blocks: An integer of number of blocks. None means maximum number of
      blocks, i.e. `resolution.schedule.num_resolutions`. Defaults to None.
    kernel_size: An integer of convolution kernel size.
    scope: A string or variable scope.
    reuse: Whether to reuse `scope`. Defaults to None which means to inherit
      the reuse option of the parent scope.

  Returns:
    A `Tensor` of model output and a dictionary of model end points.
  """
  if num_blocks is None:
    num_blocks = resolution_schedule.num_resolutions

  # Leaky-ReLU conv used inside every discriminator block.
  def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
    return layers.custom_conv2d(
        x=x,
        filters=filters,
        kernel_size=kernel_size,
        padding=padding,
        activation=tf.nn.leaky_relu,
        he_initializer_slope=0.0,
        scope=scope)

  # 1x1 conv projecting RGB input to the block's feature width.
  def _from_rgb(x, block_id):
    return _conv2d('from_rgb', x, 1, num_filters_fn(block_id))

  end_points = {}
  with tf.variable_scope(scope, reuse=reuse):
    x0 = x
    end_points['rgb'] = x0
    # For each block (highest resolution first) build the downscaled RGB
    # input and its progress-dependent blending weight.
    lods = []
    for block_id in range(num_blocks, 0, -1):
      with tf.variable_scope(block_name(block_id)):
        scale = resolution_schedule.scale_factor(block_id)
        lod = layers.downscale(x0, scale)
        end_points['downscaled_rgb_{}'.format(block_id)] = lod
        lod = _from_rgb(lod, block_id)
        # alpha_i is used to replace lod_select.
        alpha = _discriminator_alpha(block_id, progress)
        end_points['alpha_{}'.format(block_id)] = alpha
        lods.append((lod, alpha))
    # BUGFIX: use the builtin next() rather than the Python 2-only
    # iterator .next() method so this also runs under Python 3.
    lods_iter = iter(lods)
    x, _ = next(lods_iter)
    for block_id in range(num_blocks, 1, -1):
      with tf.variable_scope(block_name(block_id)):
        x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
        x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id - 1))
        x = layers.downscale(x, resolution_schedule.scale_base)
        lod, alpha = next(lods_iter)
        # Blend the lower-resolution RGB input with the features computed
        # from the higher-resolution path.
        x = alpha * lod + (1.0 - alpha) * x
    with tf.variable_scope(block_name(1)):
      # Batch-discrimination feature: append the mean batch stddev.
      x = layers.scalar_concat(x, layers.minibatch_mean_stddev(x))
      x = _conv2d('conv0', x, kernel_size, num_filters_fn(1))
      x = _conv2d('conv1', x, resolution_schedule.start_resolutions,
                  num_filters_fn(0), 'VALID')
      end_points['last_conv'] = x
      logits = layers.custom_dense(x=x, units=1, scope='logits')
      end_points['logits'] = logits
  return logits, end_points
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import layers
import networks
def _get_grad_norm(ys, xs):
  """Computes the 2-norm of the gradients dys / dxs."""
  gradients = tf.gradients(ys, xs)
  squared_sums = [tf.reduce_sum(tf.square(g)) for g in gradients]
  return tf.sqrt(tf.add_n(squared_sums))
def _num_filters_stub(block_id):
  # Tiny filter schedule (base 8, decay 1, max 8) keeps test networks small.
  return networks.num_filters(block_id, 8, 1, 8)
class NetworksTest(tf.test.TestCase):
  """Tests for the progressive GAN `networks` module."""

  def test_resolution_schedule_correct(self):
    """`ResolutionSchedule` exposes the expected derived properties."""
    rs = networks.ResolutionSchedule(
        start_resolutions=[5, 3], scale_base=2, num_resolutions=3)
    self.assertEqual(rs.start_resolutions, (5, 3))
    self.assertEqual(rs.scale_base, 2)
    self.assertEqual(rs.num_resolutions, 3)
    self.assertEqual(rs.final_resolutions, (20, 12))
    self.assertEqual(rs.scale_factor(1), 4)
    self.assertEqual(rs.scale_factor(2), 2)
    self.assertEqual(rs.scale_factor(3), 1)
    # Block ids outside [1, num_resolutions] are rejected.
    with self.assertRaises(ValueError):
      rs.scale_factor(0)
    with self.assertRaises(ValueError):
      rs.scale_factor(4)

  def test_block_name(self):
    """Block names follow the 'progressive_gan_block_<id>' convention."""
    self.assertEqual(networks.block_name(10), 'progressive_gan_block_10')

  def test_min_total_num_images(self):
    """Checks the minimum image count needed to finish all stages."""
    self.assertEqual(networks.min_total_num_images(7, 8, 4), 52)

  def test_compute_progress(self):
    """`compute_progress` maps image ids onto the piecewise-linear schedule."""
    current_image_id_ph = tf.placeholder(tf.int32, [])
    progress = networks.compute_progress(
        current_image_id_ph,
        stable_stage_num_images=7,
        transition_stage_num_images=8,
        num_blocks=2)
    with self.test_session(use_gpu=True) as sess:
      progress_output = [
          sess.run(progress, feed_dict={current_image_id_ph: current_image_id})
          for current_image_id in [0, 3, 6, 7, 8, 10, 15, 29, 100]
      ]
    self.assertArrayNear(progress_output,
                         [0.0, 0.0, 0.0, 0.0, 0.125, 0.375, 1.0, 1.0, 1.0],
                         1.0e-6)

  def test_generator_alpha(self):
    """Checks the generator blending weights as functions of progress."""
    with self.test_session(use_gpu=True) as sess:
      alpha_fixed_block_id = [
          sess.run(
              networks._generator_alpha(2, tf.constant(progress, tf.float32)))
          for progress in [0, 0.2, 1, 1.2, 2, 2.2, 3]
      ]
      alpha_fixed_progress = [
          sess.run(
              networks._generator_alpha(block_id, tf.constant(1.2, tf.float32)))
          for block_id in range(1, 5)
      ]
    self.assertArrayNear(alpha_fixed_block_id, [0, 0.2, 1, 0.8, 0, 0, 0],
                         1.0e-6)
    self.assertArrayNear(alpha_fixed_progress, [0, 0.8, 0.2, 0], 1.0e-6)

  def test_discriminator_alpha(self):
    """Checks the discriminator blending weights as functions of progress."""
    with self.test_session(use_gpu=True) as sess:
      alpha_fixed_block_id = [
          sess.run(
              networks._discriminator_alpha(2, tf.constant(
                  progress, tf.float32)))
          for progress in [0, 0.2, 1, 1.2, 2, 2.2, 3]
      ]
      alpha_fixed_progress = [
          sess.run(
              networks._discriminator_alpha(block_id,
                                            tf.constant(1.2, tf.float32)))
          for block_id in range(1, 5)
      ]
    self.assertArrayNear(alpha_fixed_block_id, [1, 1, 1, 0.8, 0, 0, 0], 1.0e-6)
    self.assertArrayNear(alpha_fixed_progress, [0, 0.8, 1, 1], 1.0e-6)

  def test_blend_images_in_stable_stage(self):
    """At progress 0, blending equals downscale-then-upscale of the input."""
    x_np = np.random.normal(size=[2, 8, 8, 3])
    x = tf.constant(x_np, tf.float32)
    x_blend = networks.blend_images(
        x,
        progress=tf.constant(0.0),
        resolution_schedule=networks.ResolutionSchedule(
            scale_base=2, num_resolutions=2),
        num_blocks=2)
    with self.test_session(use_gpu=True) as sess:
      x_blend_np = sess.run(x_blend)
      x_blend_expected_np = sess.run(layers.upscale(layers.downscale(x, 2), 2))
    self.assertNDArrayNear(x_blend_np, x_blend_expected_np, 1.0e-6)

  def test_blend_images_in_transition_stage(self):
    """At progress 0.2, blending mixes coarse and original images 0.8/0.2."""
    x_np = np.random.normal(size=[2, 8, 8, 3])
    x = tf.constant(x_np, tf.float32)
    x_blend = networks.blend_images(
        x,
        tf.constant(0.2),
        resolution_schedule=networks.ResolutionSchedule(
            scale_base=2, num_resolutions=2),
        num_blocks=2)
    with self.test_session(use_gpu=True) as sess:
      x_blend_np = sess.run(x_blend)
      x_blend_expected_np = 0.8 * sess.run(
          layers.upscale(layers.downscale(x, 2), 2)) + 0.2 * x_np
    self.assertNDArrayNear(x_blend_np, x_blend_expected_np, 1.0e-6)

  def test_num_filters(self):
    """Filter counts decay with block id and are capped by fmap_max."""
    self.assertEqual(networks.num_filters(1, 4096, 1, 256), 256)
    self.assertEqual(networks.num_filters(5, 4096, 1, 256), 128)

  def test_generator_grad_norm_progress(self):
    """Gradients flow into each generator block only once its stage starts."""
    stable_stage_num_images = 2
    transition_stage_num_images = 3
    current_image_id_ph = tf.placeholder(tf.int32, [])
    progress = networks.compute_progress(
        current_image_id_ph,
        stable_stage_num_images,
        transition_stage_num_images,
        num_blocks=3)
    z = tf.random_normal([2, 10], dtype=tf.float32)
    x, _ = networks.generator(
        z, progress, _num_filters_stub,
        networks.ResolutionSchedule(
            start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
    fake_loss = tf.reduce_sum(tf.square(x))
    grad_norms = [
        _get_grad_norm(
            fake_loss, tf.trainable_variables('.*/progressive_gan_block_1/.*')),
        _get_grad_norm(
            fake_loss, tf.trainable_variables('.*/progressive_gan_block_2/.*')),
        _get_grad_norm(
            fake_loss, tf.trainable_variables('.*/progressive_gan_block_3/.*'))
    ]
    grad_norms_output = None
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      x1_np = sess.run(x, feed_dict={current_image_id_ph: 0.12})
      x2_np = sess.run(x, feed_dict={current_image_id_ph: 1.8})
      grad_norms_output = np.array([
          sess.run(grad_norms, feed_dict={current_image_id_ph: i})
          for i in range(15)  # total num of images
      ])
    self.assertEqual((2, 16, 16, 3), x1_np.shape)
    self.assertEqual((2, 16, 16, 3), x2_np.shape)
    # The gradient of block_1 is always on.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 0] > 0), 0,
        'gradient norms {} for block 1 is not always on'.format(
            grad_norms_output[:, 0]))
    # The gradient of block_2 is on after 1 stable stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 is not on at step 3'.format(
            grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stage + 1 transition stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 is not on at step 8'.format(
            grad_norms_output[:, 2]))

  def test_discriminator_grad_norm_progress(self):
    """Gradients flow into each discriminator block per the same schedule."""
    stable_stage_num_images = 2
    transition_stage_num_images = 3
    current_image_id_ph = tf.placeholder(tf.int32, [])
    progress = networks.compute_progress(
        current_image_id_ph,
        stable_stage_num_images,
        transition_stage_num_images,
        num_blocks=3)
    x = tf.random_normal([2, 16, 16, 3])
    logits, _ = networks.discriminator(
        x, progress, _num_filters_stub,
        networks.ResolutionSchedule(
            start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
    fake_loss = tf.reduce_sum(tf.square(logits))
    grad_norms = [
        _get_grad_norm(
            fake_loss, tf.trainable_variables('.*/progressive_gan_block_1/.*')),
        _get_grad_norm(
            fake_loss, tf.trainable_variables('.*/progressive_gan_block_2/.*')),
        _get_grad_norm(
            fake_loss, tf.trainable_variables('.*/progressive_gan_block_3/.*'))
    ]
    grad_norms_output = None
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      grad_norms_output = np.array([
          sess.run(grad_norms, feed_dict={current_image_id_ph: i})
          for i in range(15)  # total num of images
      ])
    # The gradient of block_1 is always on.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 0] > 0), 0,
        'gradient norms {} for block 1 is not always on'.format(
            grad_norms_output[:, 0]))
    # The gradient of block_2 is on after 1 stable stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 is not on at step 3'.format(
            grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stage + 1 transition stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 is not on at step 8'.format(
            grad_norms_output[:, 2]))
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import numpy as np
import tensorflow as tf
import networks
tfgan = tf.contrib.gan
def make_train_sub_dir(stage_id, **kwargs):
  """Returns the log directory for training stage `stage_id`."""
  # Zero-pad to five digits so stage directories sort lexicographically.
  stage_dir_name = 'stage_{:05d}'.format(stage_id)
  return os.path.join(kwargs['train_root_dir'], stage_dir_name)
def make_resolution_schedule(**kwargs):
  """Returns an object of `ResolutionSchedule`."""
  start_resolutions = (kwargs['start_height'], kwargs['start_width'])
  return networks.ResolutionSchedule(
      start_resolutions=start_resolutions,
      scale_base=kwargs['scale_base'],
      num_resolutions=kwargs['num_resolutions'])
def get_stage_ids(**kwargs):
  """Returns a list of stage ids.

  Args:
    **kwargs: A dictionary of
        'train_root_dir': A string of root directory of training logs.
        'num_resolutions': An integer of number of progressive resolutions.
  """
  root_dir = kwargs['train_root_dir']
  num_existing_stages = sum(
      1 for entry in tf.gfile.ListDirectory(root_dir)
      if entry.startswith('stage_'))
  # Fresh start (no stage directories) begins at stage 0; otherwise resume
  # from the last stage already present, i.e. stage n - 1.
  start_stage_id = max(0, num_existing_stages - 1)
  return range(start_stage_id, get_total_num_stages(**kwargs))
def get_total_num_stages(**kwargs):
  """Returns total number of training stages.

  One stable stage per resolution plus one transition stage between each
  consecutive pair of resolutions.
  """
  num_resolutions = kwargs['num_resolutions']
  return num_resolutions + (num_resolutions - 1)
def get_batch_size(stage_id, **kwargs):
  """Returns batch size for each stage.

  It is expected that `len(batch_size_schedule) == num_resolutions`. Each stage
  corresponds to a resolution and hence a batch size. However if
  `len(batch_size_schedule) < num_resolutions`, pad `batch_size_schedule` in
  the beginning with the first batch size.

  Args:
    stage_id: An integer of training stage index.
    **kwargs: A dictionary of
        'batch_size_schedule': A list of integer, each element is the batch
          size for the current training image resolution.
        'num_resolutions': An integer of number of progressive resolutions.

  Returns:
    An integer batch size for the `stage_id`.
  """
  schedule = list(kwargs['batch_size_schedule'])
  num_missing = kwargs['num_resolutions'] - len(schedule)
  if num_missing > 0:
    # Pad on the left with the smallest-resolution batch size.
    schedule = [schedule[0]] * num_missing + schedule
  # Stages 2k and 2k+1 (stable and transition) share resolution k's batch size.
  return int(schedule[(stage_id + 1) // 2])
def get_stage_info(stage_id, **kwargs):
  """Returns information for a training stage.

  Args:
    stage_id: An integer of training stage index.
    **kwargs: A dictionary of
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.

  Returns:
    A tuple of integers. The first entry is the number of blocks. The second
    entry is the accumulated total number of training images when stage
    `stage_id` is finished.

  Raises:
    ValueError: If `stage_id` is not in [0, total number of stages).
  """
  total_num_stages = get_total_num_stages(**kwargs)
  if stage_id < 0 or stage_id >= total_num_stages:
    raise ValueError(
        '`stage_id` must be in [0, {0}), but instead was {1}'.format(
            total_num_stages, stage_id))
  # Even stage ids are stable stages; odd stage ids are transition stages.
  num_stable_stages = stage_id // 2 + 1
  num_transition_stages = (stage_id + 1) // 2
  num_blocks = num_transition_stages + 1
  num_images = (
      num_stable_stages * kwargs['stable_stage_num_images'] +
      num_transition_stages * kwargs['transition_stage_num_images'])
  total_num_images = kwargs['total_num_images']
  # The final stage consumes whatever image budget remains; earlier stages
  # never exceed the total budget.
  if stage_id >= total_num_stages - 1:
    num_images = total_num_images
  num_images = min(num_images, total_num_images)
  return num_blocks, num_images
def make_latent_vectors(num, **kwargs):
  """Returns a batch of `num` random latent vectors."""
  latent_vector_size = kwargs['latent_vector_size']
  return tf.random_normal([num, latent_vector_size], dtype=tf.float32)
def make_interpolated_latent_vectors(num_rows, num_columns, **kwargs):
  """Returns a batch of linearly interpolated latent vectors.

  Given two randomly generated latent vector za and zb, it can generate
  a row of `num_columns` interpolated latent vectors, i.e.
  [..., za + (zb - za) * i / (num_columns - 1), ...] where
  i = 0, 1, ..., `num_columns` - 1.

  This function produces `num_rows` such rows and returns a (flattened)
  batch of latent vectors with batch size `num_rows * num_columns`.

  Args:
    num_rows: An integer. Number of rows of interpolated latent vectors.
    num_columns: An integer. Number of interpolated latent vectors in each row.
    **kwargs: A dictionary of
        'latent_vector_size': An integer of latent vector size.

  Returns:
    A `Tensor` of shape `[num_rows * num_columns, latent_vector_size]`.
  """
  rows = []
  for _ in range(num_rows):
    # Two random endpoints per row.
    endpoints = tf.random_normal([2, kwargs['latent_vector_size']])
    # Interpolation coefficients i / (num_columns - 1), shape [num_columns, 1].
    coefficients = tf.reshape(
        tf.to_float(tf.range(num_columns)) / (num_columns - 1), [-1, 1])
    delta = endpoints[1] - endpoints[0]
    rows.append(endpoints[0] + tf.stack([delta] * num_columns) * coefficients)
  return tf.concat(rows, axis=0)
def define_loss(gan_model, **kwargs):
  """Defines progressive GAN losses.

  The generator and discriminator both use wasserstein loss. In addition,
  a small penalty term is added to the discriminator loss to prevent it
  getting too large.

  Args:
    gan_model: A `GANModel` namedtuple.
    **kwargs: A dictionary of
        'gradient_penalty_weight': A float of gradient penalty weight for
          wasserstein loss.
        'gradient_penalty_target': A float of gradient norm target for
          wasserstein loss.
        'real_score_penalty_weight': A float of additional penalty to keep
          the scores from drifting too far from zero.

  Returns:
    A `GANLoss` namedtuple.
  """
  base_loss = tfgan.gan_loss(
      gan_model,
      generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
      gradient_penalty_weight=kwargs['gradient_penalty_weight'],
      gradient_penalty_target=kwargs['gradient_penalty_target'],
      gradient_penalty_epsilon=0.0)
  # Penalize the mean squared score on real images so the discriminator's
  # outputs stay near zero.
  drift_penalty = tf.reduce_mean(
      tf.square(gan_model.discriminator_real_outputs))
  tf.summary.scalar('real_score_penalty', drift_penalty)
  penalized_discriminator_loss = (
      base_loss.discriminator_loss +
      kwargs['real_score_penalty_weight'] * drift_penalty)
  return base_loss._replace(discriminator_loss=penalized_discriminator_loss)
def define_train_ops(gan_model, gan_loss, **kwargs):
  """Defines progressive GAN train ops.

  Args:
    gan_model: A `GANModel` namedtuple.
    gan_loss: A `GANLoss` namedtuple.
    **kwargs: A dictionary of
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    A tuple of `GANTrainOps` namedtuple and a list variables tracking the
    state of optimizers.
  """
  with tf.variable_scope('progressive_gan_train_ops') as var_scope:
    beta1 = kwargs['adam_beta1']
    beta2 = kwargs['adam_beta2']
    generator_optimizer = tf.train.AdamOptimizer(
        kwargs['generator_learning_rate'], beta1, beta2)
    discriminator_optimizer = tf.train.AdamOptimizer(
        kwargs['discriminator_learning_rate'], beta1, beta2)
    gan_train_ops = tfgan.gan_train_ops(
        gan_model, gan_loss, generator_optimizer, discriminator_optimizer)
  # Optimizer slot variables live inside the scope above; return them so the
  # scaffold can exclude them from checkpoint restoration.
  optimizer_vars = tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope.name)
  return gan_train_ops, optimizer_vars
def add_generator_smoothing_ops(generator_ema, gan_model, gan_train_ops):
  """Adds generator smoothing ops."""
  # Chain the EMA update after every generator train step.
  with tf.control_dependencies([gan_train_ops.generator_train_op]):
    smoothed_generator_train_op = generator_ema.apply(
        gan_model.generator_variables)
  gan_train_ops = gan_train_ops._replace(
      generator_train_op=smoothed_generator_train_op)
  generator_vars_to_restore = generator_ema.variables_to_restore(
      gan_model.generator_variables)
  return gan_train_ops, generator_vars_to_restore
def build_model(stage_id, batch_size, real_images, **kwargs):
  """Builds progressive GAN model.

  Args:
    stage_id: An integer of training stage index.
    batch_size: Number of training images in each minibatch.
    real_images: A 4D `Tensor` of NHWC format.
    **kwargs: A dictionary of
        'start_height': An integer of start image height.
        'start_width': An integer of start image width.
        'scale_base': An integer of resolution multiplier.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.
        'kernel_size': Convolution kernel size.
        'colors': Number of image channels.
        'to_rgb_use_tanh_activation': Whether to apply tanh activation when
          output rgb.
        'fmap_base': Base number of filters.
        'fmap_decay': Decay of number of filters.
        'fmap_max': Max number of filters.
        'latent_vector_size': An integer of latent vector size.
        'gradient_penalty_weight': A float of gradient penalty weight for
          wasserstein loss.
        'gradient_penalty_target': A float of gradient norm target for
          wasserstein loss.
        'real_score_penalty_weight': A float of additional penalty to keep
          the scores from drifting too far from zero.
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    An internal object that wraps all information about the model.
  """
  kernel_size = kwargs['kernel_size']
  colors = kwargs['colors']
  resolution_schedule = make_resolution_schedule(**kwargs)
  num_blocks, num_images = get_stage_info(stage_id, **kwargs)
  # The global step counts individual training images, not minibatches:
  # each train step below advances it by `batch_size`.
  current_image_id = tf.train.get_or_create_global_step()
  current_image_id_inc_op = current_image_id.assign_add(batch_size)
  tf.summary.scalar('current_image_id', current_image_id)
  progress = networks.compute_progress(
      current_image_id, kwargs['stable_stage_num_images'],
      kwargs['transition_stage_num_images'], num_blocks)
  tf.summary.scalar('progress', progress)
  # Blend real images so their effective resolution tracks the generator's
  # current output resolution while a transition stage is in progress.
  real_images = networks.blend_images(
      real_images, progress, resolution_schedule, num_blocks=num_blocks)

  def _num_filters_fn(block_id):
    """Computes number of filters of block `block_id`."""
    return networks.num_filters(block_id, kwargs['fmap_base'],
                                kwargs['fmap_decay'], kwargs['fmap_max'])

  def _generator_fn(z):
    """Builds generator network."""
    return networks.generator(
        z,
        progress,
        _num_filters_fn,
        resolution_schedule,
        num_blocks=num_blocks,
        kernel_size=kernel_size,
        colors=colors,
        to_rgb_activation=(tf.tanh
                           if kwargs['to_rgb_use_tanh_activation'] else None))

  def _discriminator_fn(x):
    """Builds discriminator network."""
    return networks.discriminator(
        x,
        progress,
        _num_filters_fn,
        resolution_schedule,
        num_blocks=num_blocks,
        kernel_size=kernel_size)

  ########## Define model.
  z = make_latent_vectors(batch_size, **kwargs)
  # The network fns return (output, end_points); tfgan only needs the output.
  gan_model = tfgan.gan_model(
      generator_fn=lambda z: _generator_fn(z)[0],
      discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
      real_data=real_images,
      generator_inputs=z)
  ########## Define loss.
  gan_loss = define_loss(gan_model, **kwargs)
  ########## Define train ops.
  gan_train_ops, optimizer_var_list = define_train_ops(gan_model, gan_loss,
                                                       **kwargs)
  gan_train_ops = gan_train_ops._replace(
      global_step_inc_op=current_image_id_inc_op)
  ########## Generator smoothing.
  generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
  gan_train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
      generator_ema, gan_model, gan_train_ops)

  # Plain namespace object collecting everything downstream code needs.
  class Model(object):
    pass

  model = Model()
  model.stage_id = stage_id
  model.batch_size = batch_size
  model.resolution_schedule = resolution_schedule
  model.num_images = num_images
  model.num_blocks = num_blocks
  model.current_image_id = current_image_id
  model.progress = progress
  model.num_filters_fn = _num_filters_fn
  model.generator_fn = _generator_fn
  model.discriminator_fn = _discriminator_fn
  model.gan_model = gan_model
  model.gan_loss = gan_loss
  model.gan_train_ops = gan_train_ops
  model.optimizer_var_list = optimizer_var_list
  model.generator_ema = generator_ema
  model.generator_vars_to_restore = generator_vars_to_restore
  return model
def make_var_scope_custom_getter_for_ema(ema):
  """Makes a variable scope custom getter that reads EMA-smoothed values.

  Args:
    ema: A `tf.train.ExponentialMovingAverage` object.

  Returns:
    A custom getter that returns the EMA shadow variable for any variable
    tracked by `ema`, and the variable itself otherwise.
  """
  def _custom_getter(getter, name, *args, **kwargs):
    var = getter(name, *args, **kwargs)
    ema_var = ema.average(var)
    # `ema.average` returns None for untracked variables. Compare against
    # None explicitly instead of truth-testing: truth value of a TF variable
    # is not a reliable "is present" signal.
    return ema_var if ema_var is not None else var
  return _custom_getter
def add_model_summaries(model, **kwargs):
  """Adds model summaries.

  This function adds several useful summaries during training:
  - fake_images: A grid of fake images based on random latent vectors.
  - interp_images: A grid of fake images based on interpolated latent vectors.
  - real_images_blend: A grid of real images.
  - summaries for `gan_model` losses, variable distributions etc.

  Args:
    model: An model object having all information of progressive GAN model,
      e.g. the return of build_model().
    **kwargs: A dictionary of
        'fake_grid_size': The fake image grid size for summaries.
        'interp_grid_size': The latent space interpolated image grid size for
          summaries.
        'colors': Number of image channels.
        'latent_vector_size': An integer of latent vector size.
  """
  fake_grid_size = kwargs['fake_grid_size']
  interp_grid_size = kwargs['interp_grid_size']
  colors = kwargs['colors']
  image_shape = list(model.resolution_schedule.final_resolutions)
  # Square grids: grid_size^2 images per summary.
  fake_batch_size = fake_grid_size**2
  fake_images_shape = [fake_batch_size] + image_shape + [colors]
  interp_batch_size = interp_grid_size**2
  interp_images_shape = [interp_batch_size] + image_shape + [colors]
  # When making prediction, use the ema smoothed generator vars.
  with tf.variable_scope(
      model.gan_model.generator_scope,
      reuse=True,
      custom_getter=make_var_scope_custom_getter_for_ema(model.generator_ema)):
    z_fake = make_latent_vectors(fake_batch_size, **kwargs)
    fake_images = model.gan_model.generator_fn(z_fake)
    fake_images.set_shape(fake_images_shape)
    z_interp = make_interpolated_latent_vectors(interp_grid_size,
                                                interp_grid_size, **kwargs)
    interp_images = model.gan_model.generator_fn(z_interp)
    interp_images.set_shape(interp_images_shape)
  tf.summary.image(
      'fake_images',
      tfgan.eval.eval_utils.image_grid(
          fake_images,
          grid_shape=[fake_grid_size] * 2,
          image_shape=image_shape,
          num_channels=colors),
      max_outputs=1)
  tf.summary.image(
      'interp_images',
      tfgan.eval.eval_utils.image_grid(
          interp_images,
          grid_shape=[interp_grid_size] * 2,
          image_shape=image_shape,
          num_channels=colors),
      max_outputs=1)
  # Use as many real images as fit in a square grid within the batch.
  real_grid_size = int(np.sqrt(model.batch_size))
  tf.summary.image(
      'real_images_blend',
      tfgan.eval.eval_utils.image_grid(
          model.gan_model.real_data[:real_grid_size**2],
          grid_shape=(real_grid_size, real_grid_size),
          image_shape=image_shape,
          num_channels=colors),
      max_outputs=1)
  tfgan.eval.add_gan_model_summaries(model.gan_model)
def make_scaffold(stage_id, optimizer_var_list, **kwargs):
  """Makes a custom scaffold.

  The scaffold
  - restores variables from the last training stage.
  - initializes new variables in the new block.

  Args:
    stage_id: An integer of stage id.
    optimizer_var_list: A list of optimizer variables.
    **kwargs: A dictionary of
        'train_root_dir': A string of root directory of training logs.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.

  Returns:
    A `Scaffold` object.
  """
  # Holds variables that from the previous stage and need to be restored.
  restore_var_list = []
  prev_ckpt = None
  curr_ckpt = tf.train.latest_checkpoint(make_train_sub_dir(stage_id, **kwargs))
  # Only restore from the previous stage when the current stage has not yet
  # written a checkpoint of its own (i.e. it is starting fresh).
  if stage_id > 0 and curr_ckpt is None:
    prev_ckpt = tf.train.latest_checkpoint(
        make_train_sub_dir(stage_id - 1, **kwargs))
    num_blocks, _ = get_stage_info(stage_id, **kwargs)
    prev_num_blocks, _ = get_stage_info(stage_id - 1, **kwargs)
    # Holds variables created in the new block of the current stage. If the
    # current stage is a stable stage (except the initial stage), this list
    # will be empty.
    new_block_var_list = []
    for block_id in range(prev_num_blocks + 1, num_blocks + 1):
      new_block_var_list.extend(
          tf.get_collection(
              tf.GraphKeys.GLOBAL_VARIABLES,
              scope='.*/{}/'.format(networks.block_name(block_id))))
    # Every variables that are 1) not for optimizers and 2) from the new block
    # need to be restored.
    restore_var_list = [
        var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if var not in set(optimizer_var_list + new_block_var_list)
    ]
  # Add saver op to graph. This saver is used to restore variables from the
  # previous stage.
  saver_for_restore = tf.train.Saver(
      var_list=restore_var_list, allow_empty=True)
  # Add the op to graph that initializes all global variables.
  init_op = tf.global_variables_initializer()

  def _init_fn(unused_scaffold, sess):
    # First initialize every variables.
    sess.run(init_op)
    logging.info('\n'.join([var.name for var in restore_var_list]))
    # Then overwrite variables saved in previous stage.
    if prev_ckpt is not None:
      saver_for_restore.restore(sess, prev_ckpt)

  # Use a dummy init_op here as all initialization is done in init_fn.
  return tf.train.Scaffold(init_op=tf.constant([]), init_fn=_init_fn)
def make_status_message(model):
  """Makes a string `Tensor` of training status."""
  # num_blocks/batch_size are Python ints and can be formatted eagerly; the
  # image id and progress are tensors and must be stringified in-graph.
  parts = [
      'Starting train step: current_image_id: ',
      tf.as_string(model.current_image_id),
      ', progress: ',
      tf.as_string(model.progress),
      ', num_blocks: {}'.format(model.num_blocks),
      ', batch_size: {}'.format(model.batch_size),
  ]
  return tf.string_join(parts, name='status_message')
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: An model object having all information of progressive GAN model,
      e.g. the return of build_model().
    **kwargs: A dictionary of
        'train_root_dir': A string of root directory of training logs.
        'master': Name of the TensorFlow master to use.
        'task': The Task ID. This value is used when training with multiple
          workers to identify each worker.
        'save_summaries_num_images': Save summaries in this number of images.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)
  stage_logdir = make_train_sub_dir(model.stage_id, **kwargs)
  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)
  training_hooks = [
      # The global step counts images, so stop after `num_images` of them.
      tf.train.StopAtStepHook(last_step=model.num_images),
      tf.train.LoggingTensorHook([make_status_message(model)], every_n_iter=10),
  ]
  tfgan.gan_train(
      model.gan_train_ops,
      logdir=stage_logdir,
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=training_hooks,
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl import flags
from absl import logging
import tensorflow as tf
import data_provider
import train
tfgan = tf.contrib.gan
# Dataset flags.
flags.DEFINE_string('dataset_name', 'cifar10', 'Dataset name.')
flags.DEFINE_string('dataset_file_pattern', '', 'Dataset file pattern.')
# Progressive-resolution schedule flags.
flags.DEFINE_integer('start_height', 4, 'Start image height.')
flags.DEFINE_integer('start_width', 4, 'Start image width.')
flags.DEFINE_integer('scale_base', 2, 'Resolution multiplier.')
flags.DEFINE_integer('num_resolutions', 4, 'Number of progressive resolutions.')
flags.DEFINE_list(
    'batch_size_schedule', [8, 8, 4],
    'A list of batch sizes for each resolution, if '
    'len(batch_size_schedule) < num_resolutions, pad the schedule in the '
    'beginning with the first batch size.')
# Network architecture flags.
flags.DEFINE_integer('kernel_size', 3, 'Convolution kernel size.')
flags.DEFINE_integer('colors', 3, 'Number of image channels.')
flags.DEFINE_bool('to_rgb_use_tanh_activation', False,
                  'Whether to apply tanh activation when output rgb.')
# Training-length flags (all measured in numbers of images, not steps).
flags.DEFINE_integer('stable_stage_num_images', 1000,
                     'Number of images in the stable stage.')
flags.DEFINE_integer('transition_stage_num_images', 1000,
                     'Number of images in the transition stage.')
flags.DEFINE_integer('total_num_images', 10000, 'Total number of images.')
flags.DEFINE_integer('save_summaries_num_images', 100,
                     'Save summaries in this number of images.')
# Model-capacity flags.
flags.DEFINE_integer('latent_vector_size', 128, 'Latent vector size.')
flags.DEFINE_integer('fmap_base', 4096, 'Base number of filters.')
flags.DEFINE_float('fmap_decay', 1.0, 'Decay of number of filters.')
flags.DEFINE_integer('fmap_max', 128, 'Max number of filters.')
# Loss flags.
flags.DEFINE_float('gradient_penalty_target', 1.0,
                   'Gradient norm target for wasserstein loss.')
flags.DEFINE_float('gradient_penalty_weight', 10.0,
                   'Gradient penalty weight for wasserstein loss.')
flags.DEFINE_float('real_score_penalty_weight', 0.001,
                   'Additional penalty to keep the scores from drifting too '
                   'far from zero.')
# Optimizer flags.
flags.DEFINE_float('generator_learning_rate', 0.001, 'Learning rate.')
flags.DEFINE_float('discriminator_learning_rate', 0.001, 'Learning rate.')
flags.DEFINE_float('adam_beta1', 0.0, 'Adam beta 1.')
flags.DEFINE_float('adam_beta2', 0.99, 'Adam beta 2.')
# Summary flags.
flags.DEFINE_integer('fake_grid_size', 8, 'The fake image grid size for eval.')
flags.DEFINE_integer('interp_grid_size', 8,
                     'The interp image grid size for eval.')
# Infrastructure flags.
flags.DEFINE_string('train_root_dir', '/tmp/progressive_gan/',
                    'Directory where to write event logs.')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')
flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')
FLAGS = flags.FLAGS
def _make_config_from_flags():
  """Makes a config dictionary from commandline flags."""
  key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
  return {flag.name: flag.value for flag in key_flags}
def _provide_real_images(batch_size, **kwargs):
  """Provides real images.

  Args:
    batch_size: Number of images in each minibatch.
    **kwargs: The training config dictionary; reads 'dataset_name',
      'dataset_file_pattern', 'colors' and the resolution-schedule keys.

  Returns:
    A batch of real images patched to the final training resolution.

  Raises:
    ValueError: If neither `dataset_name` nor `dataset_file_pattern` is
      specified.
  """
  dataset_name = kwargs.get('dataset_name')
  dataset_file_pattern = kwargs.get('dataset_file_pattern')
  colors = kwargs['colors']
  final_height, final_width = train.make_resolution_schedule(
      **kwargs).final_resolutions
  # Use truthiness so the empty-string flag default for
  # `dataset_file_pattern` is treated as unset, and fail loudly instead of
  # silently returning None when neither input source is configured.
  if dataset_name:
    return data_provider.provide_data(
        dataset_name=dataset_name,
        split_name='train',
        batch_size=batch_size,
        patch_height=final_height,
        patch_width=final_width,
        colors=colors)
  if dataset_file_pattern:
    return data_provider.provide_data_from_image_files(
        file_pattern=dataset_file_pattern,
        batch_size=batch_size,
        patch_height=final_height,
        patch_width=final_width,
        colors=colors)
  raise ValueError(
      'One of `dataset_name` or `dataset_file_pattern` must be specified.')
def main(_):
  """Runs all progressive GAN training stages sequentially."""
  if not tf.gfile.Exists(FLAGS.train_root_dir):
    tf.gfile.MakeDirs(FLAGS.train_root_dir)
  config = _make_config_from_flags()
  # `dict.iteritems` does not exist in Python 3; `items` works in both 2 and 3.
  logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))
  for stage_id in train.get_stage_ids(**config):
    batch_size = train.get_batch_size(stage_id, **config)
    # Each stage builds a fresh graph; variables are carried over through the
    # scaffold's checkpoint restore.
    tf.reset_default_graph()
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      model = train.build_model(stage_id, batch_size, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)
# Parse flags and dispatch to main() when executed as a script.
if __name__ == '__main__':
  tf.app.run()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from absl.testing import absltest
import tensorflow as tf
import train
FLAGS = flags.FLAGS
def provide_random_data(batch_size=2, patch_size=4, colors=1, **unused_kwargs):
  """Returns a batch of NHWC images drawn from a standard normal."""
  image_shape = [batch_size, patch_size, patch_size, colors]
  return tf.random_normal(image_shape)
class TrainTest(absltest.TestCase):
  """Smoke tests for the progressive GAN training loop in `train`."""
  def setUp(self):
    # Minimal two-resolution configuration with tiny image counts and
    # feature-map sizes so the full training loop finishes quickly in tests.
    self._config = {
        'start_height': 2,
        'start_width': 2,
        'scale_base': 2,
        'num_resolutions': 2,
        'batch_size_schedule': [2],
        'colors': 1,
        'to_rgb_use_tanh_activation': True,
        'kernel_size': 3,
        'stable_stage_num_images': 1,
        'transition_stage_num_images': 1,
        'total_num_images': 3,
        'save_summaries_num_images': 2,
        'latent_vector_size': 2,
        'fmap_base': 8,
        'fmap_decay': 1.0,
        'fmap_max': 8,
        'gradient_penalty_target': 1.0,
        'gradient_penalty_weight': 10.0,
        'real_score_penalty_weight': 0.001,
        'generator_learning_rate': 0.001,
        'discriminator_learning_rate': 0.001,
        'adam_beta1': 0.0,
        'adam_beta2': 0.99,
        'fake_grid_size': 2,
        'interp_grid_size': 2,
        'train_root_dir': os.path.join(FLAGS.test_tmpdir, 'progressive_gan'),
        'master': '',
        'task': 0
    }
  def test_train_success(self):
    """Runs every training stage end-to-end on random input images."""
    train_root_dir = self._config['train_root_dir']
    if not tf.gfile.Exists(train_root_dir):
      tf.gfile.MakeDirs(train_root_dir)
    for stage_id in train.get_stage_ids(**self._config):
      batch_size = train.get_batch_size(stage_id, **self._config)
      # Each stage trains in a fresh graph, mirroring the main script.
      tf.reset_default_graph()
      real_images = provide_random_data(batch_size=batch_size)
      model = train.build_model(stage_id, batch_size, real_images,
                                **self._config)
      train.add_model_summaries(model, **self._config)
      train.train(model, **self._config)
  def test_get_batch_size(self):
    """Checks per-stage batch sizes derived from a short schedule."""
    config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
    # batch_size_schedule is expanded to [8, 8, 8, 4, 2]
    # At stage level it is [8, 8, 8, 8, 8, 4, 4, 2, 2]
    for i, expected_batch_size in enumerate([8, 8, 8, 8, 8, 4, 4, 2, 2]):
      self.assertEqual(train.get_batch_size(i, **config), expected_batch_size)
# Run all TrainTest cases when executed as a script.
if __name__ == '__main__':
  absltest.main()
"""StarGAN data provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
def provide_data(image_file_patterns, batch_size, patch_size):
  """Data provider wrapping `data_provider` from gan/cyclegan for StarGAN.

  Loads one image batch per file pattern (one pattern per domain) and pairs
  each with a constant one-hot domain label.

  Args:
    image_file_patterns: A list of file pattern globs.
    batch_size: Python int. Batch size.
    patch_size: Python int. The patch size to extract.

  Returns:
    List of `Tensor` of shape (N, H, W, C) representing the images.
    List of `Tensor` of shape (N, num_domains) representing the labels.
  """
  images = data_provider.provide_custom_data(
      image_file_patterns, batch_size=batch_size, patch_size=patch_size)
  # One domain per file pattern; every example in domain i gets label i.
  num_domains = len(images)
  labels = []
  for domain_idx in range(num_domains):
    labels.append(tf.one_hot([domain_idx] * batch_size, num_domains))
  return images, labels
"""Tests for google3.third_party.tensorflow_models.gan.stargan.data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from google3.testing.pybase import googletest
import data_provider
mock = tf.test.mock
class DataProviderTest(googletest.TestCase):
  """Tests the StarGAN `provide_data` wrapper with a mocked data source."""
  # Patch the underlying cyclegan provider so no real files are read.
  @mock.patch.object(
      data_provider.data_provider, 'provide_custom_data', autospec=True)
  def test_data_provider(self, mock_provide_custom_data):
    """Checks image pass-through and one-hot label shapes per domain."""
    batch_size = 2
    patch_size = 8
    num_domains = 3
    images_shape = [batch_size, patch_size, patch_size, 3]
    # Simulate one image batch per domain.
    mock_provide_custom_data.return_value = [
        tf.zeros(images_shape) for _ in range(num_domains)
    ]
    images, labels = data_provider.provide_data(
        image_file_patterns=None, batch_size=batch_size, patch_size=patch_size)
    self.assertEqual(num_domains, len(images))
    self.assertEqual(num_domains, len(labels))
    # Labels are one-hot over domains; images keep their NHWC shape.
    for label in labels:
      self.assertListEqual([batch_size, num_domains], label.shape.as_list())
    for image in images:
      self.assertListEqual(images_shape, image.shape.as_list())
# Run all DataProviderTest cases when executed as a script.
if __name__ == '__main__':
  googletest.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for a StarGAN model.
This module contains basic layers to build a StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original PyTorch implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import ops
def _check_spatial_dim(dim, dim_index):
  """Checks one spatial dimension is statically defined and divisible by 4.

  Args:
    dim: The dimension value from `tensor.shape`; falsy when undefined at
      graph construction time.
    dim_index: (int) Index of the dimension, used in error messages.

  Raises:
    ValueError: If `dim` is undefined or not divisible by 4.
  """
  if dim:
    if dim % 4 != 0:
      raise ValueError(
          'Dimension {} of the input should be divisible by 4, but is {} '
          'instead.'.format(dim_index, dim))
  else:
    raise ValueError(
        'Dimension {} of the input should be explicitly defined.'.format(
            dim_index))


def generator_down_sample(input_net, final_num_outputs=256):
  """Down-sampling module in Generator.

  Down sampling pathway of the Generator Architecture:

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L32

  Notes:
    We require dimension 1 and dimension 2 of the input_net to be fully
    defined for the correct down sampling.

  Args:
    input_net: Tensor of shape (batch_size, h, w, c + num_class).
    final_num_outputs: (int) Number of hidden unit for the final layer.
      Must be divisible by 4.

  Returns:
    Tensor of shape (batch_size, h / 4, w / 4, final_num_outputs).

  Raises:
    ValueError: If final_num_outputs are not divisible by 4,
      or input_net does not have a rank of 4,
      or dimension 1 and dimension 2 of input_net are not defined at graph
      construction time,
      or dimension 1 and dimension 2 of input_net are not divisible by 4.
  """
  if final_num_outputs % 4 != 0:
    raise ValueError('Final number outputs need to be divisible by 4.')

  # Check the rank of input_net.
  input_net.shape.assert_has_rank(4)

  # Spatial dimensions must be defined and divisible by 4 because the two
  # stride-2 convolutions below shrink them by a total factor of 4.
  _check_spatial_dim(input_net.shape[1], 1)
  _check_spatial_dim(input_net.shape[2], 2)

  with tf.variable_scope('generator_down_sample'):

    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d],
        padding='VALID',
        biases_initializer=None,
        normalizer_fn=tf.contrib.layers.instance_norm,
        activation_fn=tf.nn.relu):

      down_sample = ops.pad(input_net, 3)
      down_sample = tf.contrib.layers.conv2d(
          inputs=down_sample,
          # `//` keeps num_outputs an int. With `from __future__ import
          # division` in effect, `/` would yield a float here.
          num_outputs=final_num_outputs // 4,
          kernel_size=7,
          stride=1,
          scope='conv_0')

      down_sample = ops.pad(down_sample, 1)
      down_sample = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs // 2,
          kernel_size=4,
          stride=2,
          scope='conv_1')

      down_sample = ops.pad(down_sample, 1)
      output_net = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs,
          kernel_size=4,
          stride=2,
          scope='conv_2')

  return output_net
def _residual_block(input_net,
                    num_outputs,
                    kernel_size,
                    stride=1,
                    padding_size=0,
                    activation_fn=tf.nn.relu,
                    normalizer_fn=None,
                    name='residual_block'):
  """Residual Block.

  Computes X -> Conv1 -> norm -> activation -> Conv2 -> norm, then adds X.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L7

  Args:
    input_net: Tensor as input.
    num_outputs: (int) number of output channels for Convolution.
    kernel_size: (int) size of the square kernel for Convolution.
    stride: (int) stride for Convolution. Default to 1.
    padding_size: (int) padding size for Convolution. Default to 0.
    activation_fn: Activation function.
    normalizer_fn: Normalization function.
    name: Name scope

  Returns:
    Residual Tensor with the same shape as the input tensor.
  """
  with tf.variable_scope(name):
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d],
        num_outputs=num_outputs,
        kernel_size=kernel_size,
        stride=stride,
        padding='VALID',
        normalizer_fn=normalizer_fn,
        activation_fn=None):
      net = ops.pad(input_net, padding_size)
      net = tf.contrib.layers.conv2d(inputs=net, scope='conv_0')
      # Activation only between the two convolutions, per the original model.
      net = activation_fn(net, name='activation_0')
      net = ops.pad(net, padding_size)
      net = tf.contrib.layers.conv2d(inputs=net, scope='conv_1')
      # Skip connection.
      output_net = net + input_net
  return output_net
def generator_bottleneck(input_net, residual_block_num=6, num_outputs=256):
  """Bottleneck module in Generator.

  Applies a chain of residual blocks at constant channel width.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L40

  Args:
    input_net: Tensor of shape (batch_size, h / 4, w / 4, num_outputs).
    residual_block_num: (int) Number of residual_block_num. Default to 6 per
      the original implementation.
    num_outputs: (int) Number of hidden unit in the residual bottleneck.
      Default to 256 per the original implementation.

  Returns:
    Tensor of shape (batch_size, h / 4, w / 4, num_outputs).

  Raises:
    ValueError: If the rank of the input tensor is not 4,
      or the last channel of the input_tensor is not explicitly defined,
      or the last channel of the input_tensor is not the same as num_outputs.
  """
  input_net.shape.assert_has_rank(4)

  # The residual additions require the channel count to match num_outputs
  # and to be known at graph construction time.
  last_dim = input_net.shape[-1]
  if not last_dim:
    raise ValueError(
        'The last dimension of the input_net should be explicitly defined.')
  if last_dim != num_outputs:
    raise ValueError(
        'The last dimension of the input_net should be the same as '
        'num_outputs: but {} vs. {} instead.'.format(input_net.shape[-1],
                                                     num_outputs))

  with tf.variable_scope('generator_bottleneck'):
    bottleneck = input_net
    for block_id in range(residual_block_num):
      bottleneck = _residual_block(
          input_net=bottleneck,
          num_outputs=num_outputs,
          kernel_size=3,
          stride=1,
          padding_size=1,
          activation_fn=tf.nn.relu,
          normalizer_fn=tf.contrib.layers.instance_norm,
          name='residual_block_{}'.format(block_id))
  return bottleneck
def generator_up_sample(input_net, num_outputs):
  """Up-sampling module in Generator.

  Two stride-2 transposed convolutions restore the spatial resolution, then
  a 7x7 convolution with tanh produces the output image.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L44

  Args:
    input_net: Tensor of shape (batch_size, h / 4, w / 4, 256).
    num_outputs: (int) Number of channel for the output tensor.

  Returns:
    Tensor of shape (batch_size, h, w, num_outputs).
  """
  with tf.variable_scope('generator_up_sample'):
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d_transpose],
        kernel_size=4,
        stride=2,
        padding='VALID',
        normalizer_fn=tf.contrib.layers.instance_norm,
        activation_fn=tf.nn.relu):
      net = tf.contrib.layers.conv2d_transpose(
          inputs=input_net, num_outputs=128, scope='deconv_0')
      # Trim the one-pixel border introduced by the VALID transposed conv.
      net = net[:, 1:-1, 1:-1, :]
      net = tf.contrib.layers.conv2d_transpose(
          inputs=net, num_outputs=64, scope='deconv_1')
      net = net[:, 1:-1, 1:-1, :]

      net = ops.pad(net, 3)
      output_net = tf.contrib.layers.conv2d(
          inputs=net,
          num_outputs=num_outputs,
          kernel_size=7,
          stride=1,
          padding='VALID',
          activation_fn=tf.nn.tanh,
          normalizer_fn=None,
          biases_initializer=None,
          scope='conv_0')
  return output_net
def discriminator_input_hidden(input_net, hidden_layer=6, init_num_outputs=64):
  """Input Layer + Hidden Layer in the Discriminator.

  Feature extraction pathway in the Discriminator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L68

  Args:
    input_net: Tensor of shape (batch_size, h, w, 3) as batch of images.
    hidden_layer: (int) Number of hidden layers. Default to 6 per the
      original implementation.
    init_num_outputs: (int) Number of hidden unit in the first hidden layer.
      The number of hidden unit double after each layer. Default to 64 per
      the original implementation.

  Returns:
    Tensor of shape (batch_size, h / 64, w / 64, 2048) as features.
  """
  with tf.variable_scope('discriminator_input_hidden'):
    hidden = input_net
    for layer_id in range(hidden_layer):
      # Channel count doubles each layer: 64, 128, 256, ...
      channels = init_num_outputs * (2 ** layer_id)
      hidden = ops.pad(hidden, 1)
      hidden = tf.contrib.layers.conv2d(
          inputs=hidden,
          num_outputs=channels,
          kernel_size=4,
          stride=2,
          padding='VALID',
          activation_fn=None,
          normalizer_fn=None,
          scope='conv_{}'.format(layer_id))
      hidden = tf.nn.leaky_relu(hidden, alpha=0.01)
  return hidden
def discriminator_output_source(input_net):
  """Output Layer for Source in the Discriminator.

  Determine if the image is real/fake based on the feature extracted. We
  follow the original paper design where the output is kept as a spatial
  score map of shape (batch_size, h / 64, w / 64, 1) rather than a simple
  (batch_size) shape Tensor; the correct shape is recovered later when
  pieces are assembled.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L79

  Args:
    input_net: Tensor of shape (batch_size, h / 64, w / 64, 2048) as
      features.

  Returns:
    Tensor of shape (batch_size, h / 64, w / 64, 1) as the score.
  """
  with tf.variable_scope('discriminator_output_source'):
    score_map = ops.pad(input_net, 1)
    score_map = tf.contrib.layers.conv2d(
        inputs=score_map,
        num_outputs=1,
        kernel_size=3,
        stride=1,
        padding='VALID',
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,
        scope='conv')
  return score_map
def discriminator_output_class(input_net, class_num):
  """Output Layer for Domain Classification in the Discriminator.

  The original paper uses a convolution whose kernel size equals the height
  and width of the feature map. We use an equivalent operation here: flatten
  the Tensor to shape (batch_size, K) and apply a fully connected layer.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L80

  Args:
    input_net: Tensor of shape (batch_size, h / 64, w / 64, 2048).
    class_num: Number of output classes to be predicted.

  Returns:
    Tensor of shape (batch_size, class_num).
  """
  with tf.variable_scope('discriminator_output_class'):
    logits = tf.contrib.layers.flatten(input_net, scope='flatten')
    logits = tf.contrib.layers.fully_connected(
        inputs=logits,
        num_outputs=class_num,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,
        scope='fully_connected')
  return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment