Unverified commit 2d5e95a3, authored by Joel Shor and committed by GitHub

Merge pull request #4181 from joel-shor/master

Made `generate_cifar10_tfrecords.py` python3 compatible. Fixes #3209. Fixes #3428
parents a1adc50b 4f7074f6
@@ -19,15 +19,15 @@ from __future__ import division
 from __future__ import print_function
-import tensorflow as tf
+from absl.testing import absltest
 import infogan_eval
-class MnistInfoGANEvalTest(tf.test.TestCase):
+class MnistInfoGANEvalTest(absltest.TestCase):
   def test_build_graph(self):
     infogan_eval.main(None, run_eval_loop=False)
 if __name__ == '__main__':
-  tf.test.main()
+  absltest.main()
@@ -21,13 +21,14 @@ from __future__ import print_function
 import functools
+from absl import flags
+from absl import logging
 import tensorflow as tf
 import data_provider
 import networks
 import util
-flags = tf.flags
 tfgan = tf.contrib.gan
@@ -146,5 +147,5 @@ def main(_):
       get_hooks_fn=tfgan.get_joint_train_hooks())
 if __name__ == '__main__':
-  tf.logging.set_verbosity(tf.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   tf.app.run()
@@ -19,16 +19,18 @@ from __future__ import division
 from __future__ import print_function
+from absl import flags
+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 import train
-FLAGS = tf.flags.FLAGS
+FLAGS = flags.FLAGS
 mock = tf.test.mock
-class TrainTest(tf.test.TestCase):
+class TrainTest(tf.test.TestCase, parameterized.TestCase):
   @mock.patch.object(train, 'data_provider', autospec=True)
   def test_run_one_train_step(self, mock_data_provider):
@@ -47,7 +49,11 @@ class TrainTest(tf.test.TestCase):
     train.main(None)
-  def _test_build_graph_helper(self, gan_type):
+  @parameterized.named_parameters(
+      ('Unconditional', 'unconditional'),
+      ('Conditional', 'conditional'),
+      ('InfoGAN', 'infogan'))
+  def test_build_graph(self, gan_type):
     FLAGS.max_number_of_steps = 0
     FLAGS.gan_type = gan_type
@@ -61,14 +67,5 @@ class TrainTest(tf.test.TestCase):
         mock_imgs, mock_lbls, None)
     train.main(None)
-  def test_build_graph_unconditional(self):
-    self._test_build_graph_helper('unconditional')
-  def test_build_graph_conditional(self):
-    self._test_build_graph_helper('conditional')
-  def test_build_graph_infogan(self):
-    self._test_build_graph_helper('infogan')
 if __name__ == '__main__':
   tf.test.main()
@@ -20,6 +20,7 @@ from __future__ import print_function
 import os
+from absl import flags
 import numpy as np
 import scipy.misc
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -29,7 +30,6 @@ from mnist import data_provider
 from mnist import networks
 tfgan = tf.contrib.gan
-flags = tf.flags
 flags.DEFINE_integer('batch_size', 32,
                      'The number of images in each train batch.')
...
@@ -19,13 +19,13 @@ from __future__ import division
 from __future__ import print_function
+from absl import flags
 import numpy as np
 import tensorflow as tf
 import train
-FLAGS = tf.flags.FLAGS
+FLAGS = flags.FLAGS
 mock = tf.test.mock
...
@@ -33,10 +33,19 @@ def generator(input_images):
   Returns:
     Returns generated image batch.
+  Raises:
+    ValueError: If shape of last dimension (channels) is not defined.
   """
   input_images.shape.assert_has_rank(4)
+  input_size = input_images.shape.as_list()
+  channels = input_size[-1]
+  if channels is None:
+    raise ValueError(
+        'Last dimension shape must be known but is None: %s' % input_size)
   with tf.contrib.framework.arg_scope(cyclegan.cyclegan_arg_scope()):
-    output_images, _ = cyclegan.cyclegan_generator_resnet(input_images)
+    output_images, _ = cyclegan.cyclegan_generator_resnet(input_images,
+                                                          num_outputs=channels)
   return output_images
...
@@ -49,6 +49,19 @@ class Pix2PixTest(tf.test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
       networks.generator(tf.zeros([28, 28, 3]))
+  def test_generator_run_multi_channel(self):
+    img_batch = tf.zeros([3, 128, 128, 5])
+    model_output = networks.generator(img_batch)
+    with self.test_session() as sess:
+      sess.run(tf.global_variables_initializer())
+      sess.run(model_output)
+  def test_generator_invalid_channels(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'Last dimension shape must be known but is None'):
+      img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
+      networks.generator(img)
   def test_discriminator_run(self):
     img_batch = tf.zeros([3, 70, 70, 3])
     disc_output = networks.discriminator(img_batch)
...
@@ -19,13 +19,12 @@ from __future__ import division
 from __future__ import print_function
+from absl import flags
 import tensorflow as tf
 import data_provider
 import networks
-flags = tf.flags
 tfgan = tf.contrib.gan
...
@@ -19,17 +19,22 @@ from __future__ import division
 from __future__ import print_function
+from absl import flags
+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 import train
-FLAGS = tf.flags.FLAGS
+FLAGS = flags.FLAGS
 mock = tf.test.mock
-class TrainTest(tf.test.TestCase):
+class TrainTest(tf.test.TestCase, parameterized.TestCase):
-  def _test_build_graph_helper(self, weight_factor):
+  @parameterized.named_parameters(
+      ('NoAdversarialLoss', 0.0),
+      ('AdversarialLoss', 1.0))
+  def test_build_graph(self, weight_factor):
     FLAGS.max_number_of_steps = 0
     FLAGS.weight_factor = weight_factor
     FLAGS.batch_size = 9
@@ -42,12 +47,6 @@ class TrainTest(tf.test.TestCase):
     mock_data_provider.provide_data.return_value = mock_imgs
     train.main(None)
-  def test_build_graph_noadversarialloss(self):
-    self._test_build_graph_helper(0.0)
-  def test_build_graph_adversarialloss(self):
-    self._test_build_graph_helper(1.0)
 if __name__ == '__main__':
   tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loading and preprocessing image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
def normalize_image(image):
"""Rescales image from range [0, 255] to [-1, 1]."""
return (tf.to_float(image) - 127.5) / 127.5
def sample_patch(image, patch_height, patch_width, colors):
"""Crops image to the desired aspect ratio shape and resizes it.
If the image has shape H x W, crops a square in the center of
shape min(H,W) x min(H,W).
Args:
image: A 3D `Tensor` of HWC format.
patch_height: A Python integer. The output images height.
patch_width: A Python integer. The output images width.
colors: Number of output image channels. Defaults to 3.
Returns:
A 3D `Tensor` of HWC format with shape [patch_height, patch_width, colors].
"""
image_shape = tf.shape(image)
h, w = image_shape[0], image_shape[1]
h_major_target_h = h
h_major_target_w = tf.maximum(1, tf.to_int32(
(h * patch_width) / patch_height))
w_major_target_h = tf.maximum(1, tf.to_int32(
(w * patch_height) / patch_width))
w_major_target_w = w
target_hw = tf.cond(
h_major_target_w <= w,
lambda: tf.convert_to_tensor([h_major_target_h, h_major_target_w]),
lambda: tf.convert_to_tensor([w_major_target_h, w_major_target_w]))
# Cut a patch of shape (target_h, target_w).
image = tf.image.resize_image_with_crop_or_pad(image, target_hw[0],
target_hw[1])
# Resize the cropped image to (patch_h, patch_w).
image = tf.image.resize_images([image], [patch_height, patch_width])[0]
# Force number of channels: repeat the channel dimension enough
# number of times and then slice the first `colors` channels.
num_repeats = tf.to_int32(tf.ceil(colors / image_shape[2]))
image = tf.tile(image, [1, 1, num_repeats])
image = tf.slice(image, [0, 0, 0], [-1, -1, colors])
image.set_shape([patch_height, patch_width, colors])
return image
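# Editorial sketch, not part of the original file: the crop-size arithmetic
# above restated in plain Python for one concrete case (a 100 x 60 image and a
# requested 32 x 32 patch). The helper name is hypothetical.
def _example_center_crop_size(h=100, w=60, patch_height=32, patch_width=32):
  """Returns the (height, width) of the center crop `sample_patch` takes."""
  h_major = (h, max(1, (h * patch_width) // patch_height))  # (100, 100)
  w_major = (max(1, (w * patch_height) // patch_width), w)  # (60, 60)
  # Pick the candidate that fits inside the image, mirroring the tf.cond above.
  return h_major if h_major[1] <= w else w_major  # -> (60, 60)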
def batch_images(image, patch_height, patch_width, colors, batch_size, shuffle,
num_threads):
"""Creates a batch of images.
Args:
image: A 3D `Tensor` of HWC format.
patch_height: A Python integer. The output images height.
patch_width: A Python integer. The output images width.
colors: Number of channels.
    batch_size: The number of images in each minibatch.
shuffle: Whether to shuffle the read images.
num_threads: Number of prefetching threads.
Returns:
    A float `Tensor` with shape [batch_size, patch_height, patch_width, colors]
representing a batch of images.
"""
image = sample_patch(image, patch_height, patch_width, colors)
images = None
if shuffle:
images = tf.train.shuffle_batch(
[image],
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size,
min_after_dequeue=batch_size)
else:
images = tf.train.batch(
[image],
batch_size=batch_size,
num_threads=1, # no threads so it's deterministic
capacity=5 * batch_size)
images.set_shape([batch_size, patch_height, patch_width, colors])
return images
def provide_data(dataset_name='cifar10',
split_name='train',
                 dataset_dir=None,
batch_size=32,
shuffle=True,
num_threads=1,
patch_height=32,
patch_width=32,
colors=3):
"""Provides a batch of image data from predefined dataset.
Args:
dataset_name: A string of dataset name. Defaults to 'cifar10'.
split_name: Either 'train' or 'validation'. Defaults to 'train'.
dataset_dir: The directory where the data can be found. If `None`, use
default.
batch_size: The number of images in each minibatch. Defaults to 32.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of prefetching threads. Defaults to 1.
patch_height: A Python integer. The read images height. Defaults to 32.
patch_width: A Python integer. The read images width. Defaults to 32.
colors: Number of channels. Defaults to 3.
Returns:
    A float `Tensor` with shape [batch_size, patch_height, patch_width, colors]
representing a batch of images.
"""
dataset = datasets.get_dataset(
dataset_name, split_name, dataset_dir=dataset_dir)
provider = tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=1,
common_queue_capacity=5 * batch_size,
common_queue_min=batch_size,
shuffle=shuffle)
return batch_images(
image=normalize_image(provider.get(['image'])[0]),
patch_height=patch_height,
patch_width=patch_width,
colors=colors,
batch_size=batch_size,
shuffle=shuffle,
num_threads=num_threads)
def provide_data_from_image_files(file_pattern,
batch_size=32,
shuffle=True,
num_threads=1,
patch_height=32,
patch_width=32,
colors=3):
"""Provides a batch of image data from image files.
Args:
file_pattern: A file pattern (glob), or 1D `Tensor` of file patterns.
batch_size: The number of images in each minibatch. Defaults to 32.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of prefetching threads. Defaults to 1.
patch_height: A Python integer. The read images height. Defaults to 32.
patch_width: A Python integer. The read images width. Defaults to 32.
colors: Number of channels. Defaults to 3.
Returns:
    A float `Tensor` of shape [batch_size, patch_height, patch_width, colors]
representing a batch of images.
"""
filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once(file_pattern),
shuffle=shuffle,
capacity=5 * batch_size)
_, image_bytes = tf.WholeFileReader().read(filename_queue)
return batch_images(
image=normalize_image(tf.image.decode_image(image_bytes)),
patch_height=patch_height,
patch_width=patch_width,
colors=colors,
batch_size=batch_size,
shuffle=shuffle,
num_threads=num_threads)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):
def setUp(self):
super(DataProviderTest, self).setUp()
self.testdata_dir = os.path.join(
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/progressive_gan/testdata/')
def test_normalize_image(self):
image_np = np.asarray([0, 255, 210], dtype=np.uint8)
normalized_image = data_provider.normalize_image(tf.constant(image_np))
self.assertEqual(normalized_image.dtype, tf.float32)
self.assertEqual(normalized_image.shape.as_list(), [3])
with self.test_session(use_gpu=True) as sess:
normalized_image_np = sess.run(normalized_image)
self.assertNDArrayNear(normalized_image_np, [-1, 1, 0.6470588235], 1.0e-6)
def test_sample_patch_large_patch_returns_upscaled_image(self):
image_np = np.reshape(np.arange(2 * 2), [2, 2, 1])
image = tf.constant(image_np, dtype=tf.float32)
image_patch = data_provider.sample_patch(
image, patch_height=3, patch_width=3, colors=1)
with self.test_session(use_gpu=True) as sess:
image_patch_np = sess.run(image_patch)
expected_np = np.asarray([[[0.], [0.66666669], [1.]], [[1.33333337], [2.],
[2.33333349]],
[[2.], [2.66666675], [3.]]])
self.assertNDArrayNear(image_patch_np, expected_np, 1.0e-6)
def test_sample_patch_small_patch_returns_downscaled_image(self):
image_np = np.reshape(np.arange(3 * 3), [3, 3, 1])
image = tf.constant(image_np, dtype=tf.float32)
image_patch = data_provider.sample_patch(
image, patch_height=2, patch_width=2, colors=1)
with self.test_session(use_gpu=True) as sess:
image_patch_np = sess.run(image_patch)
expected_np = np.asarray([[[0.], [1.5]], [[4.5], [6.]]])
self.assertNDArrayNear(image_patch_np, expected_np, 1.0e-6)
def test_batch_images(self):
image_np = np.reshape(np.arange(3 * 3), [3, 3, 1])
image = tf.constant(image_np, dtype=tf.float32)
images = data_provider.batch_images(
image,
patch_height=2,
patch_width=2,
colors=1,
batch_size=2,
shuffle=False,
num_threads=1)
with self.test_session(use_gpu=True) as sess:
with tf.contrib.slim.queues.QueueRunners(sess):
images_np = sess.run(images)
expected_np = np.asarray([[[[0.], [1.5]], [[4.5], [6.]]], [[[0.], [1.5]],
[[4.5], [6.]]]])
self.assertNDArrayNear(images_np, expected_np, 1.0e-6)
def test_provide_data(self):
images = data_provider.provide_data(
'mnist',
'train',
dataset_dir=self.testdata_dir,
batch_size=2,
shuffle=False,
patch_height=3,
patch_width=3,
colors=1)
self.assertEqual(images.shape.as_list(), [2, 3, 3, 1])
with self.test_session(use_gpu=True) as sess:
with tf.contrib.slim.queues.QueueRunners(sess):
images_np = sess.run(images)
self.assertEqual(images_np.shape, (2, 3, 3, 1))
def test_provide_data_from_image_files_a_single_pattern(self):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
images = data_provider.provide_data_from_image_files(
file_pattern,
batch_size=2,
shuffle=False,
patch_height=3,
patch_width=3,
colors=1)
self.assertEqual(images.shape.as_list(), [2, 3, 3, 1])
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
with tf.contrib.slim.queues.QueueRunners(sess):
images_np = sess.run(images)
self.assertEqual(images_np.shape, (2, 3, 3, 1))
def test_provide_data_from_image_files_a_list_of_patterns(self):
file_pattern = [os.path.join(self.testdata_dir, '*.jpg')]
images = data_provider.provide_data_from_image_files(
file_pattern,
batch_size=2,
shuffle=False,
patch_height=3,
patch_width=3,
colors=1)
self.assertEqual(images.shape.as_list(), [2, 3, 3, 1])
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
with tf.contrib.slim.queues.QueueRunners(sess):
images_np = sess.run(images)
self.assertEqual(images_np.shape, (2, 3, 3, 1))
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for a progressive GAN model.
This module contains basic building blocks to build a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def pixel_norm(images, epsilon=1.0e-8):
"""Pixel normalization.
For each pixel a[i,j,k] of image in HWC format, normalize its value to
b[i,j,k] = a[i,j,k] / SQRT(SUM_k(a[i,j,k]^2) / C + eps).
Args:
images: A 4D `Tensor` of NHWC format.
epsilon: A small positive number to avoid division by zero.
Returns:
A 4D `Tensor` with pixel-wise normalized channels.
"""
return images * tf.rsqrt(
tf.reduce_mean(tf.square(images), axis=3, keepdims=True) + epsilon)
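# Editorial sketch, not part of the original file: a NumPy reference for
# `pixel_norm`, useful for sanity-checking the formula in the docstring above.
def _pixel_norm_np_reference(images, epsilon=1.0e-8):
  """NumPy equivalent of `pixel_norm`; `images` is a float NHWC array."""
  return images / np.sqrt(
      np.mean(np.square(images), axis=3, keepdims=True) + epsilon)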
def _get_validated_scale(scale):
"""Returns the scale guaranteed to be a positive integer."""
scale = int(scale)
if scale <= 0:
raise ValueError('`scale` must be a positive integer.')
return scale
def downscale(images, scale):
"""Box downscaling of images.
Args:
images: A 4D `Tensor` in NHWC format.
scale: A positive integer scale.
Returns:
A 4D `Tensor` of `images` down scaled by a factor `scale`.
Raises:
ValueError: If `scale` is not a positive integer.
"""
scale = _get_validated_scale(scale)
if scale == 1:
return images
return tf.nn.avg_pool(
images,
ksize=[1, scale, scale, 1],
strides=[1, scale, scale, 1],
padding='VALID')
def upscale(images, scale):
"""Box upscaling (also called nearest neighbors) of images.
Args:
images: A 4D `Tensor` in NHWC format.
scale: A positive integer scale.
Returns:
A 4D `Tensor` of `images` up scaled by a factor `scale`.
Raises:
ValueError: If `scale` is not a positive integer.
"""
scale = _get_validated_scale(scale)
if scale == 1:
return images
return tf.batch_to_space(
tf.tile(images, [scale**2, 1, 1, 1]),
crops=[[0, 0], [0, 0]],
block_size=scale)
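# Editorial note, not part of the original file: `downscale` followed by
# `upscale` with the same scale replaces every scale x scale block with its
# mean. A minimal NumPy sketch of the 2x2 case:
def _box_scale_round_trip_example():
  """Returns the 2x-downscaled-then-upscaled version of a 2x2 image."""
  x = np.arange(4, dtype=np.float32).reshape([1, 2, 2, 1])  # [[0, 1], [2, 3]]
  block_mean = x.mean(axis=(1, 2), keepdims=True)  # 1.5
  return np.tile(block_mean, [1, 2, 2, 1])  # every entry is 1.5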
def minibatch_mean_stddev(x):
"""Computes the standard deviation average.
This is used by the discriminator as a form of batch discrimination.
Args:
x: A `Tensor` for which to compute the standard deviation average. The first
dimension must be batch size.
Returns:
    A scalar `Tensor` which is the mean standard deviation of `x` across the
    batch dimension.
"""
mean, var = tf.nn.moments(x, axes=[0])
del mean
return tf.reduce_mean(tf.sqrt(var))
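# Editorial sketch, not part of the original file: a NumPy reference for
# `minibatch_mean_stddev` (per-feature standard deviation across the batch,
# averaged over all features).
def _minibatch_mean_stddev_np_reference(x):
  """NumPy equivalent of `minibatch_mean_stddev`; batch is axis 0 of `x`."""
  return np.mean(np.std(x, axis=0))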
def scalar_concat(tensor, scalar):
"""Concatenates a scalar to the last dimension of a tensor.
Args:
tensor: A `Tensor`.
scalar: a scalar `Tensor` to concatenate to tensor `tensor`.
Returns:
A `Tensor`. If `tensor` has shape [...,N], the result R has shape
[...,N+1] and R[...,N] = scalar.
Raises:
ValueError: If `tensor` is a scalar `Tensor`.
"""
ndims = tensor.shape.ndims
if ndims < 1:
raise ValueError('`tensor` must have number of dimensions >= 1.')
shape = tf.shape(tensor)
return tf.concat(
[tensor, tf.ones([shape[i] for i in range(ndims - 1)] + [1]) * scalar],
axis=ndims - 1)
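# Editorial example, not part of the original file: concatenating the scalar
# 7.0 to a [2, 3] tensor yields a [2, 4] tensor whose last column is 7.0.
def _scalar_concat_example():
  """Returns a [2, 4] `Tensor` whose last column is 7.0 (illustrative only)."""
  return scalar_concat(tf.zeros([2, 3]), 7.0)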
def he_initializer_scale(shape, slope=1.0):
"""The scale of He neural network initializer.
Args:
shape: A list of ints representing the dimensions of a tensor.
slope: A float representing the slope of the ReLu following the layer.
Returns:
A float of he initializer scale.
"""
fan_in = np.prod(shape[:-1])
return np.sqrt(2. / ((1. + slope**2) * fan_in))
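# Editorial check, not part of the original file: for a [3, 4, 5, 6] kernel the
# fan-in is 3 * 4 * 5 = 60, so with slope 1.0 the scale is
# sqrt(2 / ((1 + 1) * 60)) = sqrt(1 / 60) ~= 0.1291, and with slope 0.0 it is
# sqrt(2 / 60) ~= 0.1826 (the values asserted in layers_test.py below).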
def _custom_layer_impl(apply_kernel, kernel_shape, bias_shape, activation,
he_initializer_slope, use_weight_scaling):
"""Helper function to implement custom_xxx layer.
Args:
apply_kernel: A function that transforms kernel to output.
kernel_shape: An integer tuple or list of the kernel shape.
bias_shape: An integer tuple or list of the bias shape.
activation: An activation function to be applied. None means no
activation.
he_initializer_slope: A float slope for the He initializer.
use_weight_scaling: Whether to apply weight scaling.
Returns:
A `Tensor` computed as apply_kernel(kernel) + bias where kernel is a
`Tensor` variable with shape `kernel_shape`, bias is a `Tensor` variable
with shape `bias_shape`.
"""
kernel_scale = he_initializer_scale(kernel_shape, he_initializer_slope)
init_scale, post_scale = kernel_scale, 1.0
if use_weight_scaling:
init_scale, post_scale = post_scale, init_scale
kernel_initializer = tf.random_normal_initializer(stddev=init_scale)
bias = tf.get_variable(
'bias', shape=bias_shape, initializer=tf.zeros_initializer())
output = post_scale * apply_kernel(kernel_shape, kernel_initializer) + bias
if activation is not None:
output = activation(output)
return output
def custom_conv2d(x,
filters,
kernel_size,
strides=(1, 1),
padding='SAME',
activation=None,
he_initializer_slope=1.0,
use_weight_scaling=True,
scope='custom_conv2d',
reuse=None):
"""Custom conv2d layer.
  In comparison with tf.layers.conv2d, this implementation uses the He
  initializer for the convolutional kernel and the weight scaling trick (if
`use_weight_scaling` is True) to equalize learning rates. See
https://arxiv.org/abs/1710.10196 for more details.
Args:
x: A `Tensor` of NHWC format.
filters: An int of output channels.
    kernel_size: An integer or an integer tuple [kernel_height, kernel_width].
strides: A list of strides.
padding: One of "VALID" or "SAME".
activation: An activation function to be applied. None means no
activation. Defaults to None.
he_initializer_slope: A float slope for the He initializer. Defaults to 1.0.
use_weight_scaling: Whether to apply weight scaling. Defaults to True.
scope: A string or variable scope.
reuse: Whether to reuse the weights. Defaults to None.
Returns:
A `Tensor` of NHWC format where the last dimension has size `filters`.
"""
if not isinstance(kernel_size, (list, tuple)):
kernel_size = [kernel_size] * 2
kernel_size = list(kernel_size)
def _apply_kernel(kernel_shape, kernel_initializer):
return tf.layers.conv2d(
x,
filters=filters,
kernel_size=kernel_shape[0:2],
strides=strides,
padding=padding,
use_bias=False,
kernel_initializer=kernel_initializer)
with tf.variable_scope(scope, reuse=reuse):
return _custom_layer_impl(
_apply_kernel,
kernel_shape=kernel_size + [x.shape.as_list()[3], filters],
bias_shape=(filters,),
activation=activation,
he_initializer_slope=he_initializer_slope,
use_weight_scaling=use_weight_scaling)
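# Editorial usage sketch, not part of the original file: a 3x3 convolution with
# 8 output channels and the pixel-norm'ed leaky-ReLU activation used by the
# progressive GAN generator blocks in networks.py. The helper name and scope
# string are hypothetical.
def _custom_conv2d_example(images):
  """`images` is a float NHWC `Tensor`; returns an NHWC tensor with 8 channels."""
  return custom_conv2d(
      x=images,
      filters=8,
      kernel_size=3,
      activation=lambda x: pixel_norm(tf.nn.leaky_relu(x)),
      scope='example_conv')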
def custom_dense(x,
units,
activation=None,
he_initializer_slope=1.0,
use_weight_scaling=True,
scope='custom_dense',
reuse=None):
"""Custom dense layer.
  In comparison with tf.layers.dense, this implementation uses the He
  initializer to initialize weights and the weight scaling trick
(if `use_weight_scaling` is True) to equalize learning rates. See
https://arxiv.org/abs/1710.10196 for more details.
Args:
x: A `Tensor`.
units: An int of the last dimension size of output.
activation: An activation function to be applied. None means no
activation. Defaults to None.
he_initializer_slope: A float slope for the He initializer. Defaults to 1.0.
use_weight_scaling: Whether to apply weight scaling. Defaults to True.
scope: A string or variable scope.
reuse: Whether to reuse the weights. Defaults to None.
Returns:
A `Tensor` where the last dimension has size `units`.
"""
x = tf.contrib.layers.flatten(x)
def _apply_kernel(kernel_shape, kernel_initializer):
return tf.layers.dense(
x,
kernel_shape[1],
use_bias=False,
kernel_initializer=kernel_initializer)
with tf.variable_scope(scope, reuse=reuse):
return _custom_layer_impl(
_apply_kernel,
kernel_shape=(x.shape.as_list()[-1], units),
bias_shape=(units,),
activation=activation,
he_initializer_slope=he_initializer_slope,
use_weight_scaling=use_weight_scaling)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import layers
mock = tf.test.mock
def dummy_apply_kernel(kernel_shape, kernel_initializer):
kernel = tf.get_variable(
'kernel', shape=kernel_shape, initializer=kernel_initializer)
return tf.reduce_sum(kernel) + 1.5
class LayersTest(tf.test.TestCase):
def test_pixel_norm_4d_images_returns_channel_normalized_images(self):
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
with self.test_session(use_gpu=True) as sess:
output_np = sess.run(layers.pixel_norm(x))
expected_np = [[[[0.46291006, 0.92582011, 1.38873017],
[0.78954202, 0.98692751, 1.18431306]],
[[0.87047803, 0.99483204, 1.11918604],
[0.90659684, 0.99725652, 1.08791625]]],
[[[0., 0., 0.], [-0.46291006, -0.92582011, -1.38873017]],
[[0.57735026, -1.15470052, 1.15470052],
[0.56195146, 1.40487862, 0.84292722]]]]
self.assertNDArrayNear(output_np, expected_np, 1.0e-5)
def test_get_validated_scale_invalid_scale_throws_exception(self):
with self.assertRaises(ValueError):
layers._get_validated_scale(0)
def test_get_validated_scale_float_scale_returns_integer(self):
self.assertEqual(layers._get_validated_scale(5.5), 5)
def test_downscale_invalid_scale_throws_exception(self):
with self.assertRaises(ValueError):
layers.downscale(tf.constant([]), -1)
def test_downscale_4d_images_returns_downscaled_images(self):
x_np = np.array(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=np.float32)
with self.test_session(use_gpu=True) as sess:
x1_np, x2_np = sess.run(
[layers.downscale(tf.constant(x_np), n) for n in [1, 2]])
expected2_np = [[[[5.5, 6.5, 7.5]]], [[[0.5, 0.25, 0.5]]]]
self.assertNDArrayNear(x1_np, x_np, 1.0e-5)
self.assertNDArrayNear(x2_np, expected2_np, 1.0e-5)
def test_upscale_invalid_scale_throws_exception(self):
with self.assertRaises(ValueError):
self.assertRaises(layers.upscale(tf.constant([]), -1))
def test_upscale_4d_images_returns_upscaled_images(self):
x_np = np.array([[[[1, 2, 3]]], [[[4, 5, 6]]]], dtype=np.float32)
with self.test_session(use_gpu=True) as sess:
x1_np, x2_np = sess.run(
[layers.upscale(tf.constant(x_np), n) for n in [1, 2]])
expected2_np = [[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
[[[4, 5, 6], [4, 5, 6]], [[4, 5, 6], [4, 5, 6]]]]
self.assertNDArrayNear(x1_np, x_np, 1.0e-5)
self.assertNDArrayNear(x2_np, expected2_np, 1.0e-5)
def test_minibatch_mean_stddev_4d_images_returns_scalar(self):
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
with self.test_session(use_gpu=True) as sess:
output_np = sess.run(layers.minibatch_mean_stddev(x))
self.assertAlmostEqual(output_np, 3.0416667, 5)
def test_scalar_concat_invalid_input_throws_exception(self):
with self.assertRaises(ValueError):
layers.scalar_concat(tf.constant(1.2), 2.0)
def test_scalar_concat_4d_images_and_scalar(self):
x = tf.constant(
[[[[1, 2], [4, 5]], [[7, 8], [10, 11]]], [[[0, 0], [-1, -2]],
[[1, -2], [2, 5]]]],
dtype=tf.float32)
with self.test_session(use_gpu=True) as sess:
output_np = sess.run(layers.scalar_concat(x, 7))
expected_np = [[[[1, 2, 7], [4, 5, 7]], [[7, 8, 7], [10, 11, 7]]],
[[[0, 0, 7], [-1, -2, 7]], [[1, -2, 7], [2, 5, 7]]]]
self.assertNDArrayNear(output_np, expected_np, 1.0e-5)
def test_he_initializer_scale_slope_linear(self):
self.assertAlmostEqual(
layers.he_initializer_scale([3, 4, 5, 6], 1.0), 0.1290994, 5)
def test_he_initializer_scale_slope_relu(self):
self.assertAlmostEqual(
layers.he_initializer_scale([3, 4, 5, 6], 0.0), 0.1825742, 5)
@mock.patch.object(tf, 'random_normal_initializer', autospec=True)
@mock.patch.object(tf, 'zeros_initializer', autospec=True)
def test_custom_layer_impl_with_weight_scaling(
self, mock_zeros_initializer, mock_random_normal_initializer):
mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)
output = layers._custom_layer_impl(
apply_kernel=dummy_apply_kernel,
kernel_shape=(25, 6),
bias_shape=(),
activation=lambda x: 2.0 * x,
he_initializer_slope=1.0,
use_weight_scaling=True)
mock_zeros_initializer.assert_called_once_with()
mock_random_normal_initializer.assert_called_once_with(stddev=1.0)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
output_np = sess.run(output)
self.assertAlmostEqual(output_np, 182.6, 3)
@mock.patch.object(tf, 'random_normal_initializer', autospec=True)
@mock.patch.object(tf, 'zeros_initializer', autospec=True)
def test_custom_layer_impl_no_weight_scaling(self, mock_zeros_initializer,
mock_random_normal_initializer):
mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)
output = layers._custom_layer_impl(
apply_kernel=dummy_apply_kernel,
kernel_shape=(25, 6),
bias_shape=(),
activation=lambda x: 2.0 * x,
he_initializer_slope=1.0,
use_weight_scaling=False)
mock_zeros_initializer.assert_called_once_with()
mock_random_normal_initializer.assert_called_once_with(stddev=0.2)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
output_np = sess.run(output)
self.assertAlmostEqual(output_np, 905.0, 3)
@mock.patch.object(tf.layers, 'conv2d', autospec=True)
def test_custom_conv2d_passes_conv2d_options(self, mock_conv2d):
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
layers.custom_conv2d(x, 1, 2)
mock_conv2d.assert_called_once_with(
x,
filters=1,
kernel_size=[2, 2],
strides=(1, 1),
padding='SAME',
use_bias=False,
kernel_initializer=mock.ANY)
@mock.patch.object(layers, '_custom_layer_impl', autospec=True)
def test_custom_conv2d_passes_custom_layer_options(self,
mock_custom_layer_impl):
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
layers.custom_conv2d(x, 1, 2)
mock_custom_layer_impl.assert_called_once_with(
mock.ANY,
kernel_shape=[2, 2, 3, 1],
bias_shape=(1,),
activation=None,
he_initializer_slope=1.0,
use_weight_scaling=True)
@mock.patch.object(tf, 'random_normal_initializer', autospec=True)
@mock.patch.object(tf, 'zeros_initializer', autospec=True)
def test_custom_conv2d_scalar_kernel_size(self, mock_zeros_initializer,
mock_random_normal_initializer):
mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
output = layers.custom_conv2d(x, 1, 2)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
output_np = sess.run(output)
expected_np = [[[[68.54998016], [42.56921768]], [[50.36344528],
[29.57883835]]],
[[[5.33012676], [4.46410179]], [[10.52627945],
[9.66025352]]]]
self.assertNDArrayNear(output_np, expected_np, 1.0e-5)
@mock.patch.object(tf, 'random_normal_initializer', autospec=True)
@mock.patch.object(tf, 'zeros_initializer', autospec=True)
def test_custom_conv2d_list_kernel_size(self, mock_zeros_initializer,
mock_random_normal_initializer):
mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
output = layers.custom_conv2d(x, 1, [2, 3])
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
output_np = sess.run(output)
expected_np = [[
[[56.15432739], [56.15432739]],
[[41.30508804], [41.30508804]],
], [[[4.53553391], [4.53553391]], [[8.7781744], [8.7781744]]]]
self.assertNDArrayNear(output_np, expected_np, 1.0e-5)
@mock.patch.object(layers, '_custom_layer_impl', autospec=True)
def test_custom_dense_passes_custom_layer_options(self,
mock_custom_layer_impl):
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
layers.custom_dense(x, 3)
mock_custom_layer_impl.assert_called_once_with(
mock.ANY,
kernel_shape=(12, 3),
bias_shape=(3,),
activation=None,
he_initializer_slope=1.0,
use_weight_scaling=True)
@mock.patch.object(tf, 'random_normal_initializer', autospec=True)
@mock.patch.object(tf, 'zeros_initializer', autospec=True)
def test_custom_dense_output_is_correct(self, mock_zeros_initializer,
mock_random_normal_initializer):
mock_zeros_initializer.return_value = tf.constant_initializer(1.0)
mock_random_normal_initializer.return_value = tf.constant_initializer(3.0)
x = tf.constant(
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
[[[0, 0, 0], [-1, -2, -3]], [[1, -2, 2], [2, 5, 3]]]],
dtype=tf.float32)
output = layers.custom_dense(x, 3)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
output_np = sess.run(output)
expected_np = [[68.54998016, 68.54998016, 68.54998016],
[5.33012676, 5.33012676, 5.33012676]]
self.assertNDArrayNear(output_np, expected_np, 1.0e-5)
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generator and discriminator for a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
import layers
class ResolutionSchedule(object):
"""Image resolution upscaling schedule."""
def __init__(self, start_resolutions=(4, 4), scale_base=2, num_resolutions=4):
"""Initializer.
Args:
    start_resolutions: A tuple of integers in (H, W) format for the start image
resolutions. Defaults to (4, 4).
scale_base: An integer of resolution base multiplier. Defaults to 2.
num_resolutions: An integer of how many progressive resolutions (including
`start_resolutions`). Defaults to 4.
"""
self._start_resolutions = start_resolutions
self._scale_base = scale_base
self._num_resolutions = num_resolutions
@property
def start_resolutions(self):
return tuple(self._start_resolutions)
@property
def scale_base(self):
return self._scale_base
@property
def num_resolutions(self):
return self._num_resolutions
@property
def final_resolutions(self):
"""Returns the final resolutions."""
return tuple([
r * self._scale_base**(self._num_resolutions - 1)
for r in self._start_resolutions
])
def scale_factor(self, block_id):
"""Returns the scale factor for network block `block_id`."""
if block_id < 1 or block_id > self._num_resolutions:
raise ValueError('`block_id` must be in [1, {}]'.format(
self._num_resolutions))
return self._scale_base**(self._num_resolutions - block_id)
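# Editorial example, not part of the original file: with the defaults
# (start_resolutions=(4, 4), scale_base=2, num_resolutions=4) the schedule
# grows 4x4 -> 8x8 -> 16x16 -> 32x32, so final_resolutions == (32, 32) and
# scale_factor(1), ..., scale_factor(4) are 8, 4, 2, 1.
def _example_default_schedule():
  """Returns the default schedule described above (illustrative only)."""
  return ResolutionSchedule(
      start_resolutions=(4, 4), scale_base=2, num_resolutions=4)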
def block_name(block_id):
"""Returns the scope name for the network block `block_id`."""
return 'progressive_gan_block_{}'.format(block_id)
def min_total_num_images(stable_stage_num_images, transition_stage_num_images,
num_blocks):
"""Returns the minimum total number of images.
  Computes the minimum total number of images required to reach the final
  resolution.
Args:
stable_stage_num_images: Number of images in the stable stage.
transition_stage_num_images: Number of images in the transition stage.
num_blocks: Number of network blocks.
Returns:
An integer of the minimum total number of images.
"""
return (num_blocks * stable_stage_num_images +
(num_blocks - 1) * transition_stage_num_images)
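# Editorial check, not part of the original file: with 7 stable-stage images,
# 8 transition-stage images and 4 blocks this is 4 * 7 + 3 * 8 = 52, the value
# asserted in networks_test.py.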
def compute_progress(current_image_id, stable_stage_num_images,
transition_stage_num_images, num_blocks):
"""Computes the training progress.
  The training alternates between stable phases and transition phases.
  The `progress` indicates the training progress, i.e. the training is at
  - a stable phase p if progress = p,
  - a transition phase between p and p + 1 if progress = p + fraction,
  where p = 0, 1, 2, ...
Note the max value of progress is `num_blocks` - 1.
In terms of LOD (of the original implementation):
progress = `num_blocks` - 1 - LOD
Args:
    current_image_id: A scalar integer `Tensor` of the current image id, counted
from 0.
stable_stage_num_images: An integer representing the number of images in
each stable stage.
transition_stage_num_images: An integer representing the number of images in
each transition stage.
num_blocks: Number of network blocks.
Returns:
A scalar float `Tensor` of the training progress.
"""
# Note when current_image_id >= min_total_num_images - 1 (which means we
# are already at the highest resolution), we want to keep progress constant.
# Therefore, cap current_image_id here.
capped_current_image_id = tf.minimum(
current_image_id,
min_total_num_images(stable_stage_num_images, transition_stage_num_images,
num_blocks) - 1)
stage_num_images = stable_stage_num_images + transition_stage_num_images
progress_integer = tf.floordiv(capped_current_image_id, stage_num_images)
progress_fraction = tf.maximum(
0.0,
tf.to_float(
tf.mod(capped_current_image_id, stage_num_images) -
stable_stage_num_images) / tf.to_float(transition_stage_num_images))
return tf.to_float(progress_integer) + progress_fraction
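# Editorial worked example, not part of the original file: with
# stable_stage_num_images=7, transition_stage_num_images=8 and num_blocks=2
# (the case exercised in networks_test.py), one stable + transition stage spans
# 15 images and the image id is capped at 2 * 7 + 1 * 8 - 1 = 21. Image ids
# 0..7 give progress 0.0, id 8 gives (8 - 7) / 8 = 0.125, id 10 gives 0.375,
# and ids >= 15 give 1.0.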
def _generator_alpha(block_id, progress):
"""Returns the block output parameter for the generator network.
The generator has N blocks with `block_id` = 1,2,...,N. Each block
block_id outputs a fake data output(block_id). The generator output is a
linear combination of all block outputs, i.e.
SUM_block_id(output(block_id) * alpha(block_id, progress)) where
alpha(block_id, progress) = _generator_alpha(block_id, progress). Note it
  guarantees that SUM_block_id(alpha(block_id, progress)) = 1 for any progress.
With a fixed block_id, the plot of alpha(block_id, progress) against progress
is a 'triangle' with its peak at (block_id - 1, 1).
Args:
block_id: An integer of generator block id.
progress: A scalar float `Tensor` of training progress.
Returns:
A scalar float `Tensor` of block output parameter.
"""
return tf.maximum(0.0,
tf.minimum(progress - (block_id - 2), block_id - progress))
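# Editorial example, not part of the original file: for block_id=2 this is
# max(0, min(progress, 2 - progress)), a triangle that is 0 at progress 0,
# peaks at 1 when progress == 1, and returns to 0 at progress 2
# (cf. test_generator_alpha in networks_test.py).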
def _discriminator_alpha(block_id, progress):
"""Returns the block input parameter for discriminator network.
The discriminator has N blocks with `block_id` = 1,2,...,N. Each block
block_id accepts an
- input(block_id) transformed from the real data and
- the output of block block_id + 1, i.e. output(block_id + 1)
The final input is a linear combination of them,
i.e. alpha * input(block_id) + (1 - alpha) * output(block_id + 1)
where alpha = _discriminator_alpha(block_id, progress).
  With a fixed block_id, alpha(block_id, progress) stays at 1
  when progress <= block_id - 1, then decays linearly to 0 when
block_id - 1 < progress <= block_id, and finally stays at 0
when progress > block_id.
Args:
block_id: An integer of generator block id.
progress: A scalar float `Tensor` of training progress.
Returns:
A scalar float `Tensor` of block input parameter.
"""
return tf.clip_by_value(block_id - progress, 0.0, 1.0)
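# Editorial example, not part of the original file: for block_id=2 this is
# clip(2 - progress, 0, 1), so it stays at 1 until progress reaches 1, decays
# linearly to 0 between progress 1 and 2, and stays at 0 afterwards
# (cf. test_discriminator_alpha in networks_test.py).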
def blend_images(x, progress, resolution_schedule, num_blocks):
"""Blends images of different resolutions according to `progress`.
When training `progress` is at a stable stage for resolution r, returns
image `x` downscaled to resolution r and then upscaled to `final_resolutions`,
call it x'(r).
Otherwise when training `progress` is at a transition stage from resolution
r to 2r, returns a linear combination of x'(r) and x'(2r).
Args:
x: An image `Tensor` of NHWC format with resolution `final_resolutions`.
progress: A scalar float `Tensor` of training progress.
resolution_schedule: An object of `ResolutionSchedule`.
num_blocks: An integer of number of blocks.
Returns:
An image `Tensor` which is a blend of images of different resolutions.
"""
x_blend = []
for block_id in range(1, num_blocks + 1):
alpha = _generator_alpha(block_id, progress)
scale = resolution_schedule.scale_factor(block_id)
x_blend.append(alpha * layers.upscale(layers.downscale(x, scale), scale))
return tf.add_n(x_blend)
def num_filters(block_id, fmap_base=4096, fmap_decay=1.0, fmap_max=256):
"""Computes number of filters of block `block_id`."""
return int(min(fmap_base / math.pow(2.0, block_id * fmap_decay), fmap_max))
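# Editorial check, not part of the original file: with the defaults,
# num_filters(1) = min(4096 / 2, 256) = 256 and
# num_filters(5) = min(4096 / 32, 256) = 128, matching networks_test.py.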
def generator(z,
progress,
num_filters_fn,
resolution_schedule,
num_blocks=None,
kernel_size=3,
colors=3,
to_rgb_activation=None,
scope='progressive_gan_generator',
reuse=None):
"""Generator network for the progressive GAN model.
Args:
z: A `Tensor` of latent vector. The first dimension must be batch size.
progress: A scalar float `Tensor` of training progress.
num_filters_fn: A function that maps `block_id` to # of filters for the
block.
resolution_schedule: An object of `ResolutionSchedule`.
num_blocks: An integer of number of blocks. None means maximum number of
      blocks, i.e. `resolution_schedule.num_resolutions`. Defaults to None.
kernel_size: An integer of convolution kernel size.
colors: Number of output color channels. Defaults to 3.
to_rgb_activation: Activation function applied when output rgb.
scope: A string or variable scope.
reuse: Whether to reuse `scope`. Defaults to None which means to inherit
the reuse option of the parent scope.
Returns:
A `Tensor` of model output and a dictionary of model end points.
"""
if num_blocks is None:
num_blocks = resolution_schedule.num_resolutions
start_h, start_w = resolution_schedule.start_resolutions
final_h, final_w = resolution_schedule.final_resolutions
def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
return layers.custom_conv2d(
x=x,
filters=filters,
kernel_size=kernel_size,
padding=padding,
activation=lambda x: layers.pixel_norm(tf.nn.leaky_relu(x)),
he_initializer_slope=0.0,
scope=scope)
def _to_rgb(x):
return layers.custom_conv2d(
x=x,
filters=colors,
kernel_size=1,
padding='SAME',
activation=to_rgb_activation,
scope='to_rgb')
end_points = {}
with tf.variable_scope(scope, reuse=reuse):
with tf.name_scope('input'):
x = tf.contrib.layers.flatten(z)
end_points['latent_vector'] = x
with tf.variable_scope(block_name(1)):
x = tf.expand_dims(tf.expand_dims(x, 1), 1)
x = layers.pixel_norm(x)
      # Pad the 1 x 1 image to (2 * start_h - 1) x (2 * start_w - 1)
# with zeros for the next conv.
x = tf.pad(x, [[0] * 2, [start_h - 1] * 2, [start_w - 1] * 2, [0] * 2])
# The output is start_h x start_w x num_filters_fn(1).
x = _conv2d('conv0', x, (start_h, start_w), num_filters_fn(1), 'VALID')
x = _conv2d('conv1', x, kernel_size, num_filters_fn(1))
lods = [x]
for block_id in range(2, num_blocks + 1):
with tf.variable_scope(block_name(block_id)):
x = layers.upscale(x, resolution_schedule.scale_base)
x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id))
lods.append(x)
outputs = []
for block_id in range(1, num_blocks + 1):
with tf.variable_scope(block_name(block_id)):
lod = _to_rgb(lods[block_id - 1])
scale = resolution_schedule.scale_factor(block_id)
lod = layers.upscale(lod, scale)
end_points['upscaled_rgb_{}'.format(block_id)] = lod
# alpha_i is used to replace lod_select. Note sum(alpha_i) is
        # guaranteed to be 1.
alpha = _generator_alpha(block_id, progress)
end_points['alpha_{}'.format(block_id)] = alpha
outputs.append(lod * alpha)
predictions = tf.add_n(outputs)
batch_size = z.shape[0].value
predictions.set_shape([batch_size, final_h, final_w, colors])
end_points['predictions'] = predictions
return predictions, end_points
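# Editorial usage sketch, not part of the original file: building the generator
# for a 3-resolution schedule on a batch of 2 latent vectors, mirroring
# test_generator_grad_norm_progress in networks_test.py. `progress` would
# normally come from compute_progress(); the helper name is hypothetical.
def _example_build_generator(progress):
  """Returns (images, end_points); images have shape [2, 16, 16, 3]."""
  schedule = ResolutionSchedule(
      start_resolutions=(4, 4), scale_base=2, num_resolutions=3)
  z = tf.random_normal([2, 10], dtype=tf.float32)
  return generator(
      z, progress, lambda block_id: num_filters(block_id, 8, 1, 8), schedule)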
def discriminator(x,
progress,
num_filters_fn,
resolution_schedule,
num_blocks=None,
kernel_size=3,
scope='progressive_gan_discriminator',
reuse=None):
"""Discriminator network for the progressive GAN model.
Args:
    x: A `Tensor` of NHWC format representing images of size `resolution`.
progress: A scalar float `Tensor` of training progress.
num_filters_fn: A function that maps `block_id` to # of filters for the
block.
resolution_schedule: An object of `ResolutionSchedule`.
num_blocks: An integer of number of blocks. None means maximum number of
      blocks, i.e. `resolution_schedule.num_resolutions`. Defaults to None.
kernel_size: An integer of convolution kernel size.
scope: A string or variable scope.
reuse: Whether to reuse `scope`. Defaults to None which means to inherit
the reuse option of the parent scope.
Returns:
A `Tensor` of model output and a dictionary of model end points.
"""
if num_blocks is None:
num_blocks = resolution_schedule.num_resolutions
def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
return layers.custom_conv2d(
x=x,
filters=filters,
kernel_size=kernel_size,
padding=padding,
activation=tf.nn.leaky_relu,
he_initializer_slope=0.0,
scope=scope)
def _from_rgb(x, block_id):
return _conv2d('from_rgb', x, 1, num_filters_fn(block_id))
end_points = {}
with tf.variable_scope(scope, reuse=reuse):
x0 = x
end_points['rgb'] = x0
lods = []
for block_id in range(num_blocks, 0, -1):
with tf.variable_scope(block_name(block_id)):
scale = resolution_schedule.scale_factor(block_id)
lod = layers.downscale(x0, scale)
end_points['downscaled_rgb_{}'.format(block_id)] = lod
lod = _from_rgb(lod, block_id)
# alpha_i is used to replace lod_select.
alpha = _discriminator_alpha(block_id, progress)
end_points['alpha_{}'.format(block_id)] = alpha
lods.append((lod, alpha))
lods_iter = iter(lods)
    x, _ = next(lods_iter)
for block_id in range(num_blocks, 1, -1):
with tf.variable_scope(block_name(block_id)):
x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id - 1))
x = layers.downscale(x, resolution_schedule.scale_base)
        lod, alpha = next(lods_iter)
x = alpha * lod + (1.0 - alpha) * x
with tf.variable_scope(block_name(1)):
x = layers.scalar_concat(x, layers.minibatch_mean_stddev(x))
x = _conv2d('conv0', x, kernel_size, num_filters_fn(1))
x = _conv2d('conv1', x, resolution_schedule.start_resolutions,
num_filters_fn(0), 'VALID')
end_points['last_conv'] = x
logits = layers.custom_dense(x=x, units=1, scope='logits')
end_points['logits'] = logits
return logits, end_points
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import layers
import networks
def _get_grad_norm(ys, xs):
"""Compute 2-norm of dys / dxs."""
return tf.sqrt(
tf.add_n([tf.reduce_sum(tf.square(g)) for g in tf.gradients(ys, xs)]))
def _num_filters_stub(block_id):
return networks.num_filters(block_id, 8, 1, 8)
class NetworksTest(tf.test.TestCase):
def test_resolution_schedule_correct(self):
rs = networks.ResolutionSchedule(
start_resolutions=[5, 3], scale_base=2, num_resolutions=3)
self.assertEqual(rs.start_resolutions, (5, 3))
self.assertEqual(rs.scale_base, 2)
self.assertEqual(rs.num_resolutions, 3)
self.assertEqual(rs.final_resolutions, (20, 12))
self.assertEqual(rs.scale_factor(1), 4)
self.assertEqual(rs.scale_factor(2), 2)
self.assertEqual(rs.scale_factor(3), 1)
with self.assertRaises(ValueError):
rs.scale_factor(0)
with self.assertRaises(ValueError):
rs.scale_factor(4)
def test_block_name(self):
self.assertEqual(networks.block_name(10), 'progressive_gan_block_10')
def test_min_total_num_images(self):
self.assertEqual(networks.min_total_num_images(7, 8, 4), 52)
def test_compute_progress(self):
current_image_id_ph = tf.placeholder(tf.int32, [])
progress = networks.compute_progress(
current_image_id_ph,
stable_stage_num_images=7,
transition_stage_num_images=8,
num_blocks=2)
with self.test_session(use_gpu=True) as sess:
progress_output = [
sess.run(progress, feed_dict={current_image_id_ph: current_image_id})
for current_image_id in [0, 3, 6, 7, 8, 10, 15, 29, 100]
]
self.assertArrayNear(progress_output,
[0.0, 0.0, 0.0, 0.0, 0.125, 0.375, 1.0, 1.0, 1.0],
1.0e-6)
def test_generator_alpha(self):
with self.test_session(use_gpu=True) as sess:
alpha_fixed_block_id = [
sess.run(
networks._generator_alpha(2, tf.constant(progress, tf.float32)))
for progress in [0, 0.2, 1, 1.2, 2, 2.2, 3]
]
alpha_fixed_progress = [
sess.run(
networks._generator_alpha(block_id, tf.constant(1.2, tf.float32)))
for block_id in range(1, 5)
]
self.assertArrayNear(alpha_fixed_block_id, [0, 0.2, 1, 0.8, 0, 0, 0],
1.0e-6)
self.assertArrayNear(alpha_fixed_progress, [0, 0.8, 0.2, 0], 1.0e-6)
def test_discriminator_alpha(self):
with self.test_session(use_gpu=True) as sess:
alpha_fixed_block_id = [
sess.run(
networks._discriminator_alpha(2, tf.constant(
progress, tf.float32)))
for progress in [0, 0.2, 1, 1.2, 2, 2.2, 3]
]
alpha_fixed_progress = [
sess.run(
networks._discriminator_alpha(block_id,
tf.constant(1.2, tf.float32)))
for block_id in range(1, 5)
]
self.assertArrayNear(alpha_fixed_block_id, [1, 1, 1, 0.8, 0, 0, 0], 1.0e-6)
self.assertArrayNear(alpha_fixed_progress, [0, 0.8, 1, 1], 1.0e-6)
def test_blend_images_in_stable_stage(self):
x_np = np.random.normal(size=[2, 8, 8, 3])
x = tf.constant(x_np, tf.float32)
x_blend = networks.blend_images(
x,
progress=tf.constant(0.0),
resolution_schedule=networks.ResolutionSchedule(
scale_base=2, num_resolutions=2),
num_blocks=2)
with self.test_session(use_gpu=True) as sess:
x_blend_np = sess.run(x_blend)
x_blend_expected_np = sess.run(layers.upscale(layers.downscale(x, 2), 2))
self.assertNDArrayNear(x_blend_np, x_blend_expected_np, 1.0e-6)
def test_blend_images_in_transition_stage(self):
x_np = np.random.normal(size=[2, 8, 8, 3])
x = tf.constant(x_np, tf.float32)
x_blend = networks.blend_images(
x,
tf.constant(0.2),
resolution_schedule=networks.ResolutionSchedule(
scale_base=2, num_resolutions=2),
num_blocks=2)
with self.test_session(use_gpu=True) as sess:
x_blend_np = sess.run(x_blend)
x_blend_expected_np = 0.8 * sess.run(
layers.upscale(layers.downscale(x, 2), 2)) + 0.2 * x_np
self.assertNDArrayNear(x_blend_np, x_blend_expected_np, 1.0e-6)
def test_num_filters(self):
self.assertEqual(networks.num_filters(1, 4096, 1, 256), 256)
self.assertEqual(networks.num_filters(5, 4096, 1, 256), 128)
def test_generator_grad_norm_progress(self):
stable_stage_num_images = 2
transition_stage_num_images = 3
current_image_id_ph = tf.placeholder(tf.int32, [])
progress = networks.compute_progress(
current_image_id_ph,
stable_stage_num_images,
transition_stage_num_images,
num_blocks=3)
z = tf.random_normal([2, 10], dtype=tf.float32)
x, _ = networks.generator(
z, progress, _num_filters_stub,
networks.ResolutionSchedule(
start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
fake_loss = tf.reduce_sum(tf.square(x))
grad_norms = [
_get_grad_norm(
fake_loss, tf.trainable_variables('.*/progressive_gan_block_1/.*')),
_get_grad_norm(
fake_loss, tf.trainable_variables('.*/progressive_gan_block_2/.*')),
_get_grad_norm(
fake_loss, tf.trainable_variables('.*/progressive_gan_block_3/.*'))
]
grad_norms_output = None
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
x1_np = sess.run(x, feed_dict={current_image_id_ph: 0.12})
x2_np = sess.run(x, feed_dict={current_image_id_ph: 1.8})
grad_norms_output = np.array([
sess.run(grad_norms, feed_dict={current_image_id_ph: i})
for i in range(15) # total num of images
])
self.assertEqual((2, 16, 16, 3), x1_np.shape)
self.assertEqual((2, 16, 16, 3), x2_np.shape)
# The gradient of block_1 is always on.
self.assertEqual(
np.argmax(grad_norms_output[:, 0] > 0), 0,
'gradient norms {} for block 1 is not always on'.format(
grad_norms_output[:, 0]))
# The gradient of block_2 is on after 1 stable stage.
self.assertEqual(
np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 are not on at step 3'.format(
grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stages + 1 transition stage.
self.assertEqual(
np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 are not on at step 8'.format(
grad_norms_output[:, 2]))
def test_discriminator_grad_norm_progress(self):
stable_stage_num_images = 2
transition_stage_num_images = 3
current_image_id_ph = tf.placeholder(tf.int32, [])
progress = networks.compute_progress(
current_image_id_ph,
stable_stage_num_images,
transition_stage_num_images,
num_blocks=3)
x = tf.random_normal([2, 16, 16, 3])
logits, _ = networks.discriminator(
x, progress, _num_filters_stub,
networks.ResolutionSchedule(
start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
fake_loss = tf.reduce_sum(tf.square(logits))
grad_norms = [
_get_grad_norm(
fake_loss, tf.trainable_variables('.*/progressive_gan_block_1/.*')),
_get_grad_norm(
fake_loss, tf.trainable_variables('.*/progressive_gan_block_2/.*')),
_get_grad_norm(
fake_loss, tf.trainable_variables('.*/progressive_gan_block_3/.*'))
]
grad_norms_output = None
with self.test_session(use_gpu=True) as sess:
sess.run(tf.global_variables_initializer())
grad_norms_output = np.array([
sess.run(grad_norms, feed_dict={current_image_id_ph: i})
for i in range(15) # total num of images
])
# The gradient of block_1 is always on.
self.assertEqual(
np.argmax(grad_norms_output[:, 0] > 0), 0,
        'gradient norms {} for block 1 are not always on'.format(
grad_norms_output[:, 0]))
# The gradient of block_2 is on after 1 stable stage.
self.assertEqual(
np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 are not on at step 3'.format(
grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stages + 1 transition stage.
self.assertEqual(
np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 are not on at step 8'.format(
grad_norms_output[:, 2]))
if __name__ == '__main__':
tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import numpy as np
import tensorflow as tf
import networks
tfgan = tf.contrib.gan
def make_train_sub_dir(stage_id, **kwargs):
"""Returns the log directory for training stage `stage_id`."""
return os.path.join(kwargs['train_root_dir'], 'stage_{:05d}'.format(stage_id))
def make_resolution_schedule(**kwargs):
"""Returns an object of `ResolutionSchedule`."""
return networks.ResolutionSchedule(
start_resolutions=(kwargs['start_height'], kwargs['start_width']),
scale_base=kwargs['scale_base'],
num_resolutions=kwargs['num_resolutions'])
def get_stage_ids(**kwargs):
"""Returns a list of stage ids.
Args:
**kwargs: A dictionary of
'train_root_dir': A string of root directory of training logs.
'num_resolutions': An integer of number of progressive resolutions.
"""
train_sub_dirs = [
sub_dir for sub_dir in tf.gfile.ListDirectory(kwargs['train_root_dir'])
if sub_dir.startswith('stage_')
]
  # For a fresh start, begin with start_stage_id = 0.
  # If training already ran for n = len(train_sub_dirs) stages, resume from the
  # last stage, i.e. start_stage_id = n - 1.
start_stage_id = max(0, len(train_sub_dirs) - 1)
return range(start_stage_id, get_total_num_stages(**kwargs))
def get_total_num_stages(**kwargs):
"""Returns total number of training stages."""
return 2 * kwargs['num_resolutions'] - 1
def get_stage_info(stage_id, **kwargs):
"""Returns information for a training stage.
Args:
stage_id: An integer of training stage index.
**kwargs: A dictionary of
'num_resolutions': An integer of number of progressive resolutions.
'stable_stage_num_images': An integer of number of training images in
the stable stage.
'transition_stage_num_images': An integer of number of training images
in the transition stage.
'total_num_images': An integer of total number of training images.
Returns:
A tuple of integers. The first entry is the number of blocks. The second
entry is the accumulated total number of training images when stage
`stage_id` is finished.
Raises:
ValueError: If `stage_id` is not in [0, total number of stages).
"""
total_num_stages = get_total_num_stages(**kwargs)
if not (stage_id >= 0 and stage_id < total_num_stages):
raise ValueError(
'`stage_id` must be in [0, {0}), but instead was {1}'.format(
total_num_stages, stage_id))
# Even stage_id: stable training stage.
# Odd stage_id: transition training stage.
num_blocks = (stage_id + 1) // 2 + 1
num_images = ((stage_id // 2 + 1) * kwargs['stable_stage_num_images'] + (
(stage_id + 1) // 2) * kwargs['transition_stage_num_images'])
total_num_images = kwargs['total_num_images']
if stage_id >= total_num_stages - 1:
num_images = total_num_images
num_images = min(num_images, total_num_images)
return num_blocks, num_images
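# Illustrative sketch (not part of the original file; `_example_stage_schedule`
# and its config values are made up for demonstration): with num_resolutions=3,
# stable_stage_num_images=1000, transition_stage_num_images=1000 and
# total_num_images=10000 there are 2 * 3 - 1 = 5 stages, alternating stable and
# transition:
#   stage_id=0 (stable)     -> num_blocks=1, num_images=1000
#   stage_id=1 (transition) -> num_blocks=2, num_images=2000
#   stage_id=2 (stable)     -> num_blocks=2, num_images=3000
#   stage_id=3 (transition) -> num_blocks=3, num_images=4000
#   stage_id=4 (stable)     -> num_blocks=3, num_images=10000 (the final stage
#                              uses the remaining image budget)
def _example_stage_schedule():
  """Prints the (num_blocks, num_images) schedule for the settings above."""
  config = dict(
      num_resolutions=3,
      stable_stage_num_images=1000,
      transition_stage_num_images=1000,
      total_num_images=10000)
  for stage_id in range(get_total_num_stages(**config)):
    print(stage_id, get_stage_info(stage_id, **config))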
def make_latent_vectors(num, **kwargs):
"""Returns a batch of `num` random latent vectors."""
return tf.random_normal([num, kwargs['latent_vector_size']], dtype=tf.float32)
def make_interpolated_latent_vectors(num_rows, num_columns, **kwargs):
"""Returns a batch of linearly interpolated latent vectors.
  Given two randomly generated latent vectors za and zb, it generates
a row of `num_columns` interpolated latent vectors, i.e.
[..., za + (zb - za) * i / (num_columns - 1), ...] where
i = 0, 1, ..., `num_columns` - 1.
This function produces `num_rows` such rows and returns a (flattened)
batch of latent vectors with batch size `num_rows * num_columns`.
Args:
num_rows: An integer. Number of rows of interpolated latent vectors.
num_columns: An integer. Number of interpolated latent vectors in each row.
**kwargs: A dictionary of
'latent_vector_size': An integer of latent vector size.
Returns:
A `Tensor` of shape `[num_rows * num_columns, latent_vector_size]`.
"""
ans = []
for _ in range(num_rows):
z = tf.random_normal([2, kwargs['latent_vector_size']])
r = tf.reshape(
tf.to_float(tf.range(num_columns)) / (num_columns - 1), [-1, 1])
dz = z[1] - z[0]
ans.append(z[0] + tf.stack([dz] * num_columns) * r)
return tf.concat(ans, axis=0)
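# Illustrative sketch (not part of the original file; the `_example_*` name is
# made up): a 3 x 4 grid of interpolated latent vectors of size 8 yields a
# single [12, 8] batch, where each group of 4 rows runs linearly between two
# random endpoints.
def _example_interpolated_latents():
  """Returns a [12, 8] batch of interpolated latent vectors."""
  return make_interpolated_latent_vectors(3, 4, latent_vector_size=8)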
def define_loss(gan_model, **kwargs):
"""Defines progressive GAN losses.
The generator and discriminator both use wasserstein loss. In addition,
a small penalty term is added to the discriminator loss to prevent it getting
too large.
Args:
gan_model: A `GANModel` namedtuple.
**kwargs: A dictionary of
      'gradient_penalty_weight': A float of gradient penalty weight for
        wasserstein loss.
      'gradient_penalty_target': A float of gradient norm target for
        wasserstein loss.
      'real_score_penalty_weight': A float of additional penalty weight that
        keeps the real scores from drifting too far from zero.
Returns:
A `GANLoss` namedtuple.
"""
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
gradient_penalty_weight=kwargs['gradient_penalty_weight'],
gradient_penalty_target=kwargs['gradient_penalty_target'],
gradient_penalty_epsilon=0.0)
real_score_penalty = tf.reduce_mean(
tf.square(gan_model.discriminator_real_outputs))
tf.summary.scalar('real_score_penalty', real_score_penalty)
return gan_loss._replace(
discriminator_loss=(
gan_loss.discriminator_loss +
kwargs['real_score_penalty_weight'] * real_score_penalty))
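# Schematically (a restatement of the code above, not new behavior), the
# discriminator objective becomes:
#   D_loss = wasserstein_D_loss
#            + gradient_penalty_weight * gradient_penalty
#            + real_score_penalty_weight * mean(D(x_real) ** 2)
# while the generator keeps the plain wasserstein generator loss.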
def define_train_ops(gan_model, gan_loss, **kwargs):
"""Defines progressive GAN train ops.
Args:
gan_model: A `GANModel` namedtuple.
gan_loss: A `GANLoss` namedtuple.
**kwargs: A dictionary of
'adam_beta1': A float of Adam optimizer beta1.
'adam_beta2': A float of Adam optimizer beta2.
'generator_learning_rate': A float of generator learning rate.
'discriminator_learning_rate': A float of discriminator learning rate.
Returns:
    A tuple of a `GANTrainOps` namedtuple and a list of variables tracking the
    optimizer state.
"""
with tf.variable_scope('progressive_gan_train_ops') as var_scope:
beta1, beta2 = kwargs['adam_beta1'], kwargs['adam_beta2']
gen_opt = tf.train.AdamOptimizer(kwargs['generator_learning_rate'], beta1,
beta2)
dis_opt = tf.train.AdamOptimizer(kwargs['discriminator_learning_rate'],
beta1, beta2)
gan_train_ops = tfgan.gan_train_ops(gan_model, gan_loss, gen_opt, dis_opt)
return gan_train_ops, tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope.name)
def add_generator_smoothing_ops(generator_ema, gan_model, gan_train_ops):
"""Adds generator smoothing ops."""
with tf.control_dependencies([gan_train_ops.generator_train_op]):
new_generator_train_op = generator_ema.apply(gan_model.generator_variables)
gan_train_ops = gan_train_ops._replace(
generator_train_op=new_generator_train_op)
generator_vars_to_restore = generator_ema.variables_to_restore(
gan_model.generator_variables)
return gan_train_ops, generator_vars_to_restore
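# Note (descriptive comment, not original): chaining the EMA update onto the
# generator train op means every generator step also refreshes a smoothed
# (exponential moving average) copy of the generator weights; these smoothed
# values are later swapped in through a variable-scope custom getter when
# rendering summary images (see add_model_summaries below).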
def build_model(stage_id, real_images, **kwargs):
"""Builds progressive GAN model.
Args:
stage_id: An integer of training stage index.
real_images: A 4D `Tensor` of NHWC format.
**kwargs: A dictionary of
'batch_size': Number of training images in each minibatch.
'start_height': An integer of start image height.
'start_width': An integer of start image width.
'scale_base': An integer of resolution multiplier.
'num_resolutions': An integer of number of progressive resolutions.
'stable_stage_num_images': An integer of number of training images in
the stable stage.
'transition_stage_num_images': An integer of number of training images
in the transition stage.
'total_num_images': An integer of total number of training images.
'kernel_size': Convolution kernel size.
'colors': Number of image channels.
'to_rgb_use_tanh_activation': Whether to apply tanh activation when
output rgb.
'fmap_base': Base number of filters.
'fmap_decay': Decay of number of filters.
'fmap_max': Max number of filters.
'latent_vector_size': An integer of latent vector size.
    'gradient_penalty_weight': A float of gradient penalty weight for
      wasserstein loss.
    'gradient_penalty_target': A float of gradient norm target for
      wasserstein loss.
    'real_score_penalty_weight': A float of additional penalty weight that
      keeps the real scores from drifting too far from zero.
'adam_beta1': A float of Adam optimizer beta1.
'adam_beta2': A float of Adam optimizer beta2.
'generator_learning_rate': A float of generator learning rate.
'discriminator_learning_rate': A float of discriminator learning rate.
Returns:
    An internal object that wraps all information about the model.
"""
batch_size = kwargs['batch_size']
kernel_size = kwargs['kernel_size']
colors = kwargs['colors']
resolution_schedule = make_resolution_schedule(**kwargs)
num_blocks, num_images = get_stage_info(stage_id, **kwargs)
global_step = tf.train.get_or_create_global_step()
current_image_id = global_step * batch_size
tf.summary.scalar('current_image_id', current_image_id)
progress = networks.compute_progress(
current_image_id, kwargs['stable_stage_num_images'],
kwargs['transition_stage_num_images'], num_blocks)
tf.summary.scalar('progress', progress)
real_images = networks.blend_images(
real_images, progress, resolution_schedule, num_blocks=num_blocks)
def _num_filters_fn(block_id):
"""Computes number of filters of block `block_id`."""
return networks.num_filters(block_id, kwargs['fmap_base'],
kwargs['fmap_decay'], kwargs['fmap_max'])
def _generator_fn(z):
"""Builds generator network."""
return networks.generator(
z,
progress,
_num_filters_fn,
resolution_schedule,
num_blocks=num_blocks,
kernel_size=kernel_size,
colors=colors,
to_rgb_activation=(tf.tanh
if kwargs['to_rgb_use_tanh_activation'] else None))
def _discriminator_fn(x):
"""Builds discriminator network."""
return networks.discriminator(
x,
progress,
_num_filters_fn,
resolution_schedule,
num_blocks=num_blocks,
kernel_size=kernel_size)
########## Define model.
z = make_latent_vectors(batch_size, **kwargs)
gan_model = tfgan.gan_model(
generator_fn=lambda z: _generator_fn(z)[0],
discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
real_data=real_images,
generator_inputs=z)
########## Define loss.
gan_loss = define_loss(gan_model, **kwargs)
########## Define train ops.
gan_train_ops, optimizer_var_list = define_train_ops(gan_model, gan_loss,
**kwargs)
########## Generator smoothing.
generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
gan_train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
generator_ema, gan_model, gan_train_ops)
class Model(object):
pass
model = Model()
model.resolution_schedule = resolution_schedule
model.stage_id = stage_id
model.num_images = num_images
model.num_blocks = num_blocks
model.global_step = global_step
model.current_image_id = current_image_id
model.progress = progress
model.num_filters_fn = _num_filters_fn
model.generator_fn = _generator_fn
model.discriminator_fn = _discriminator_fn
model.gan_model = gan_model
model.gan_loss = gan_loss
model.gan_train_ops = gan_train_ops
model.optimizer_var_list = optimizer_var_list
model.generator_ema = generator_ema
model.generator_vars_to_restore = generator_vars_to_restore
return model
def make_var_scope_custom_getter_for_ema(ema):
"""Makes variable scope custom getter."""
def _custom_getter(getter, name, *args, **kwargs):
var = getter(name, *args, **kwargs)
ema_var = ema.average(var)
    return ema_var if ema_var is not None else var
return _custom_getter
def add_model_summaries(model, **kwargs):
"""Adds model summaries.
This function adds several useful summaries during training:
- fake_images: A grid of fake images based on random latent vectors.
- interp_images: A grid of fake images based on interpolated latent vectors.
- real_images_blend: A grid of real images.
- summaries for `gan_model` losses, variable distributions etc.
Args:
    model: A model object holding all information about the progressive GAN
      model, e.g. the return value of build_model().
**kwargs: A dictionary of
'batch_size': Number of training images in each minibatch.
'fake_grid_size': The fake image grid size for summaries.
'interp_grid_size': The latent space interpolated image grid size for
summaries.
'colors': Number of image channels.
'latent_vector_size': An integer of latent vector size.
"""
fake_grid_size = kwargs['fake_grid_size']
interp_grid_size = kwargs['interp_grid_size']
colors = kwargs['colors']
image_shape = list(model.resolution_schedule.final_resolutions)
fake_batch_size = fake_grid_size**2
fake_images_shape = [fake_batch_size] + image_shape + [colors]
interp_batch_size = interp_grid_size**2
interp_images_shape = [interp_batch_size] + image_shape + [colors]
# When making prediction, use the ema smoothed generator vars.
with tf.variable_scope(
model.gan_model.generator_scope,
reuse=True,
custom_getter=make_var_scope_custom_getter_for_ema(model.generator_ema)):
z_fake = make_latent_vectors(fake_batch_size, **kwargs)
fake_images = model.gan_model.generator_fn(z_fake)
fake_images.set_shape(fake_images_shape)
z_interp = make_interpolated_latent_vectors(interp_grid_size,
interp_grid_size, **kwargs)
interp_images = model.gan_model.generator_fn(z_interp)
interp_images.set_shape(interp_images_shape)
tf.summary.image(
'fake_images',
tfgan.eval.eval_utils.image_grid(
fake_images,
grid_shape=[fake_grid_size] * 2,
image_shape=image_shape,
num_channels=colors),
max_outputs=1)
tf.summary.image(
'interp_images',
tfgan.eval.eval_utils.image_grid(
interp_images,
grid_shape=[interp_grid_size] * 2,
image_shape=image_shape,
num_channels=colors),
max_outputs=1)
real_grid_size = int(np.sqrt(kwargs['batch_size']))
tf.summary.image(
'real_images_blend',
tfgan.eval.eval_utils.image_grid(
model.gan_model.real_data[:real_grid_size**2],
grid_shape=(real_grid_size, real_grid_size),
image_shape=image_shape,
num_channels=colors),
max_outputs=1)
tfgan.eval.add_gan_model_summaries(model.gan_model)
def make_scaffold(stage_id, optimizer_var_list, **kwargs):
"""Makes a custom scaffold.
The scaffold
- restores variables from the last training stage.
- initializes new variables in the new block.
Args:
stage_id: An integer of stage id.
optimizer_var_list: A list of optimizer variables.
**kwargs: A dictionary of
'train_root_dir': A string of root directory of training logs.
'num_resolutions': An integer of number of progressive resolutions.
'stable_stage_num_images': An integer of number of training images in
the stable stage.
'transition_stage_num_images': An integer of number of training images
in the transition stage.
'total_num_images': An integer of total number of training images.
Returns:
A `Scaffold` object.
"""
  # Holds variables from the previous stage that need to be restored.
restore_var_list = []
prev_ckpt = None
curr_ckpt = tf.train.latest_checkpoint(make_train_sub_dir(stage_id, **kwargs))
if stage_id > 0 and curr_ckpt is None:
prev_ckpt = tf.train.latest_checkpoint(
make_train_sub_dir(stage_id - 1, **kwargs))
num_blocks, _ = get_stage_info(stage_id, **kwargs)
prev_num_blocks, _ = get_stage_info(stage_id - 1, **kwargs)
# Holds variables created in the new block of the current stage. If the
# current stage is a stable stage (except the initial stage), this list
# will be empty.
new_block_var_list = []
for block_id in range(prev_num_blocks + 1, num_blocks + 1):
new_block_var_list.extend(
tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES,
scope='.*/{}/'.format(networks.block_name(block_id))))
    # Every variable that is 1) not an optimizer variable and 2) not part of
    # the new block needs to be restored from the previous stage.
restore_var_list = [
var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if var not in set(optimizer_var_list + new_block_var_list)
]
# Add saver op to graph. This saver is used to restore variables from the
# previous stage.
saver_for_restore = tf.train.Saver(
var_list=restore_var_list, allow_empty=True)
# Add the op to graph that initializes all global variables.
init_op = tf.global_variables_initializer()
def _init_fn(unused_scaffold, sess):
    # First initialize all variables.
sess.run(init_op)
logging.info('\n'.join([var.name for var in restore_var_list]))
# Then overwrite variables saved in previous stage.
if prev_ckpt is not None:
saver_for_restore.restore(sess, prev_ckpt)
# Use a dummy init_op here as all initialization is done in init_fn.
return tf.train.Scaffold(init_op=tf.constant([]), init_fn=_init_fn)
def make_status_message(model):
"""Makes a string `Tensor` of training status."""
return tf.string_join(
[
'Starting train step: ',
tf.as_string(model.global_step), ', current_image_id: ',
tf.as_string(model.current_image_id), ', progress: ',
tf.as_string(model.progress), ', num_blocks: {}'.format(
model.num_blocks)
],
name='status_message')
def train(model, **kwargs):
"""Trains progressive GAN for stage `stage_id`.
Args:
    model: A model object holding all information about the progressive GAN
      model, e.g. the return value of build_model().
**kwargs: A dictionary of
'train_root_dir': A string of root directory of training logs.
'master': Name of the TensorFlow master to use.
'task': The Task ID. This value is used when training with multiple
workers to identify each worker.
'save_summaries_num_images': Save summaries in this number of images.
Returns:
None.
"""
batch_size = kwargs['batch_size']
logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
model.num_blocks, model.num_images)
scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)
tfgan.gan_train(
model.gan_train_ops,
logdir=make_train_sub_dir(model.stage_id, **kwargs),
get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
hooks=[
tf.train.StopAtStepHook(last_step=model.num_images // batch_size),
tf.train.LoggingTensorHook(
[make_status_message(model)], every_n_iter=10)
],
master=kwargs['master'],
is_chief=(kwargs['task'] == 0),
scaffold=scaffold,
save_checkpoint_secs=600,
save_summaries_steps=(kwargs['save_summaries_num_images'] // batch_size))
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train a progressive GAN model.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl import flags
from absl import logging
import tensorflow as tf
import data_provider
import train
tfgan = tf.contrib.gan
flags.DEFINE_string('dataset_name', 'cifar10', 'Dataset name.')
flags.DEFINE_string('dataset_file_pattern', '', 'Dataset file pattern.')
flags.DEFINE_integer('start_height', 4, 'Start image height.')
flags.DEFINE_integer('start_width', 4, 'Start image width.')
flags.DEFINE_integer('scale_base', 2, 'Resolution multiplier.')
flags.DEFINE_integer('num_resolutions', 4, 'Number of progressive resolutions.')
flags.DEFINE_integer('kernel_size', 3, 'Convolution kernel size.')
flags.DEFINE_integer('colors', 3, 'Number of image channels.')
flags.DEFINE_bool('to_rgb_use_tanh_activation', False,
'Whether to apply tanh activation when output rgb.')
flags.DEFINE_integer('batch_size', 8, 'Number of images in each batch.')
flags.DEFINE_integer('stable_stage_num_images', 1000,
'Number of images in the stable stage.')
flags.DEFINE_integer('transition_stage_num_images', 1000,
'Number of images in the transition stage.')
flags.DEFINE_integer('total_num_images', 10000, 'Total number of images.')
flags.DEFINE_integer('save_summaries_num_images', 100,
'Save summaries in this number of images.')
flags.DEFINE_integer('latent_vector_size', 128, 'Latent vector size.')
flags.DEFINE_integer('fmap_base', 4096, 'Base number of filters.')
flags.DEFINE_float('fmap_decay', 1.0, 'Decay of number of filters.')
flags.DEFINE_integer('fmap_max', 128, 'Max number of filters.')
flags.DEFINE_float('gradient_penalty_target', 1.0,
'Gradient norm target for wasserstein loss.')
flags.DEFINE_float('gradient_penalty_weight', 10.0,
'Gradient penalty weight for wasserstein loss.')
flags.DEFINE_float('real_score_penalty_weight', 0.001,
'Additional penalty to keep the scores from drifting too '
'far from zero.')
flags.DEFINE_float('generator_learning_rate', 0.001, 'Learning rate.')
flags.DEFINE_float('discriminator_learning_rate', 0.001, 'Learning rate.')
flags.DEFINE_float('adam_beta1', 0.0, 'Adam beta 1.')
flags.DEFINE_float('adam_beta2', 0.99, 'Adam beta 2.')
flags.DEFINE_integer('fake_grid_size', 8, 'The fake image grid size for eval.')
flags.DEFINE_integer('interp_grid_size', 8,
'The interp image grid size for eval.')
flags.DEFINE_string('train_root_dir', '/tmp/progressive_gan/',
'Directory where to write event logs.')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
FLAGS = flags.FLAGS
def _make_config_from_flags():
"""Makes a config dictionary from commandline flags."""
return dict([(flag.name, flag.value)
for flag in FLAGS.get_key_flags_for_module(sys.argv[0])])
def _provide_real_images(**kwargs):
"""Provides real images."""
dataset_name = kwargs.get('dataset_name')
dataset_file_pattern = kwargs.get('dataset_file_pattern')
batch_size = kwargs['batch_size']
colors = kwargs['colors']
final_height, final_width = train.make_resolution_schedule(
**kwargs).final_resolutions
  # Use truthiness rather than `is not None` so that an empty flag value does
  # not accidentally select a branch.
  if dataset_name:
return data_provider.provide_data(
dataset_name=dataset_name,
split_name='train',
batch_size=batch_size,
patch_height=final_height,
patch_width=final_width,
colors=colors)
  elif dataset_file_pattern:
return data_provider.provide_data_from_image_files(
file_pattern=dataset_file_pattern,
batch_size=batch_size,
patch_height=final_height,
patch_width=final_width,
colors=colors)
def main(_):
if not tf.gfile.Exists(FLAGS.train_root_dir):
tf.gfile.MakeDirs(FLAGS.train_root_dir)
config = _make_config_from_flags()
  logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))
for stage_id in train.get_stage_ids(**config):
tf.reset_default_graph()
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
real_images = None
with tf.device('/cpu:0'), tf.name_scope('inputs'):
real_images = _provide_real_images(**config)
model = train.build_model(stage_id, real_images, **config)
train.add_model_summaries(model, **config)
train.train(model, **config)
if __name__ == '__main__':
tf.app.run()
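# Example invocation (illustrative only; the script filename and flag values
# below are placeholders, not taken from this change):
#   python train_main.py \
#     --dataset_name=cifar10 \
#     --num_resolutions=4 \
#     --stable_stage_num_images=100000 \
#     --transition_stage_num_images=100000 \
#     --total_num_images=1000000 \
#     --train_root_dir=/tmp/progressive_gan
# Flags not listed keep the defaults defined above.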
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from absl.testing import absltest
import tensorflow as tf
import train
FLAGS = flags.FLAGS
def provide_random_data(batch_size=2, patch_size=8, colors=1, **unused_kwargs):
return tf.random_normal([batch_size, patch_size, patch_size, colors])
class TrainTest(absltest.TestCase):
def setUp(self):
self._config = {
'start_height': 4,
'start_width': 4,
'scale_base': 2,
'num_resolutions': 2,
'colors': 1,
'to_rgb_use_tanh_activation': True,
'kernel_size': 3,
'batch_size': 2,
'stable_stage_num_images': 4,
'transition_stage_num_images': 4,
'total_num_images': 12,
'save_summaries_num_images': 4,
'latent_vector_size': 8,
'fmap_base': 8,
'fmap_decay': 1.0,
'fmap_max': 8,
'gradient_penalty_target': 1.0,
'gradient_penalty_weight': 10.0,
'real_score_penalty_weight': 0.001,
'generator_learning_rate': 0.001,
'discriminator_learning_rate': 0.001,
'adam_beta1': 0.0,
'adam_beta2': 0.99,
'fake_grid_size': 2,
'interp_grid_size': 2,
'train_root_dir': os.path.join(FLAGS.test_tmpdir, 'progressive_gan'),
'master': '',
'task': 0
}
def test_train_success(self):
train_root_dir = self._config['train_root_dir']
if not tf.gfile.Exists(train_root_dir):
tf.gfile.MakeDirs(train_root_dir)
for stage_id in train.get_stage_ids(**self._config):
tf.reset_default_graph()
real_images = provide_random_data()
model = train.build_model(stage_id, real_images, **self._config)
train.add_model_summaries(model, **self._config)
train.train(model, **self._config)
if __name__ == '__main__':
absltest.main()
...@@ -25,6 +25,7 @@ from __future__ import print_function ...@@ -25,6 +25,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys
import tarfile import tarfile
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
...@@ -63,7 +64,10 @@ def _get_file_names(): ...@@ -63,7 +64,10 @@ def _get_file_names():
def read_pickle_from_file(filename): def read_pickle_from_file(filename):
with tf.gfile.Open(filename, 'rb') as f: with tf.gfile.Open(filename, 'rb') as f:
data_dict = pickle.load(f) if sys.version_info >= (3, 0):
data_dict = pickle.load(f, encoding='bytes')
else:
data_dict = pickle.load(f)
return data_dict return data_dict
...@@ -73,8 +77,8 @@ def convert_to_tfrecord(input_files, output_file): ...@@ -73,8 +77,8 @@ def convert_to_tfrecord(input_files, output_file):
with tf.python_io.TFRecordWriter(output_file) as record_writer: with tf.python_io.TFRecordWriter(output_file) as record_writer:
for input_file in input_files: for input_file in input_files:
data_dict = read_pickle_from_file(input_file) data_dict = read_pickle_from_file(input_file)
data = data_dict['data'] data = data_dict[b'data']
labels = data_dict['labels'] labels = data_dict[b'labels']
num_entries_in_batch = len(labels) num_entries_in_batch = len(labels)
for i in range(num_entries_in_batch): for i in range(num_entries_in_batch):
example = tf.train.Example(features=tf.train.Features( example = tf.train.Example(features=tf.train.Features(
......