Commit 4dcd5116 authored by Jing Li's avatar Jing Li
Browse files

Removed deprecated research/gan. Please visit https://github.com/tensorflow/gan.

parent f673f7a8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import layers
class LayersTest(tf.test.TestCase):
  """Shape tests for the StarGAN building blocks defined in `layers`."""

  def test_residual_block(self):
    """A residual block must preserve the (N, H, W, C) input shape."""
    n = 2
    h = 32
    w = h
    c = 256
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers._residual_block(
        input_net=input_tensor,
        num_outputs=c,
        kernel_size=3,
        stride=1,
        padding_size=1)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, c), output.shape)

  def test_generator_down_sample(self):
    """Down-sampling reduces H and W by 4x and outputs 256 channels."""
    n = 2
    h = 128
    w = h
    # 3 image channels plus 3 label channels appended by the conditioning op.
    c = 3 + 3
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.generator_down_sample(input_tensor)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h // 4, w // 4, 256), output.shape)

  def test_generator_bottleneck(self):
    """The bottleneck stack keeps the input shape unchanged."""
    n = 2
    h = 32
    w = h
    c = 256
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.generator_bottleneck(input_tensor)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, c), output.shape)

  def test_generator_up_sample(self):
    """Up-sampling grows H and W by 4x and maps to `c_out` channels."""
    n = 2
    h = 32
    w = h
    c = 256
    c_out = 3
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.generator_up_sample(input_tensor, c_out)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h * 4, w * 4, c_out), output.shape)

  def test_discriminator_input_hidden(self):
    """The hidden stack maps a 128x128x3 image to a (2, 2, 2048) volume."""
    n = 2
    h = 128
    w = 128
    c = 3
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.discriminator_input_hidden(input_tensor)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, 2, 2, 2048), output.shape)

  def test_discriminator_output_source(self):
    """The source head outputs a single logit per spatial position."""
    n = 2
    h = 2
    w = 2
    c = 2048
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.discriminator_output_source(input_tensor)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, 1), output.shape)

  def test_discriminator_output_class(self):
    """The class head outputs one logit per target domain."""
    n = 2
    h = 2
    w = 2
    c = 2048
    num_domain = 3
    input_tensor = tf.random_uniform((n, h, w, c))
    output_tensor = layers.discriminator_output_class(input_tensor, num_domain)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, num_domain), output.shape)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural network for a StarGAN model.
This module contains the Generator and Discriminator Neural Network to build a
StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original pytorch implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import layers
import ops
def generator(inputs, targets):
  """Builds the full StarGAN generator graph.

  Chains the conditioning, down-sampling, bottleneck and up-sampling stages.
  See https://arxiv.org/abs/1711.09020 for the model and
  https://github.com/yunjey/StarGAN for the reference PyTorch implementation
  (model.py#L22).

  Args:
    inputs: Tensor of shape (batch_size, h, w, c) holding the images /
      information to transform.
    targets: Tensor of shape (batch_size, num_domains) holding the target
      domain the generator should transform the image/information to.

  Returns:
    Tensor with the same shape as `inputs`.
  """
  with tf.variable_scope('generator'):
    net = ops.condition_input_with_pixel_padding(inputs, targets)
    net = layers.generator_down_sample(net)
    net = layers.generator_bottleneck(net)
    net = layers.generator_up_sample(net, inputs.shape[-1])
  return net
def discriminator(input_net, class_num):
  """Builds the full StarGAN discriminator graph.

  Mirrors the reference PyTorch implementation
  (https://github.com/yunjey/StarGAN, model.py#L63), except that the mean
  over the flattened source-logit map is taken here rather than later in the
  solver (solver.py#L245).

  Args:
    input_net: Tensor of shape (batch_size, h, w, c) as batch of images.
    class_num: (int) number of domain to be predicted.

  Returns:
    output_src: Tensor of shape (batch_size,) where each value is a logit
      representing whether the image is real or fake.
    output_cls: Tensor of shape (batch_size, class_num) where each value is a
      logit representing whether the image is in the associated domain.
  """
  with tf.variable_scope('discriminator'):
    hidden = layers.discriminator_input_hidden(input_net)
    source_logits = tf.contrib.layers.flatten(
        layers.discriminator_output_source(hidden))
    source_logits = tf.reduce_mean(source_logits, axis=1)
    class_logits = layers.discriminator_output_class(hidden, class_num)
  return source_logits, class_logits
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import network
class NetworkTest(tf.test.TestCase):
  """Shape tests for the StarGAN generator and discriminator networks."""

  def test_generator(self):
    """generator() must map (N, H, W, C) images back to the same shape."""
    n = 2
    h = 128
    w = h
    c = 4
    class_num = 3
    input_tensor = tf.random_uniform((n, h, w, c))
    target_tensor = tf.random_uniform((n, class_num))
    output_tensor = network.generator(input_tensor, target_tensor)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(output_tensor)
      self.assertTupleEqual((n, h, w, c), output.shape)

  def test_discriminator(self):
    """discriminator() returns per-image source logits and class logits."""
    n = 2
    h = 128
    w = h
    c = 3
    class_num = 3
    input_tensor = tf.random_uniform((n, h, w, c))
    output_src_tensor, output_cls_tensor = network.discriminator(
        input_tensor, class_num)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output_src, output_cls = sess.run([output_src_tensor, output_cls_tensor])
      # Source output is a rank-1 logit vector, one entry per image.
      self.assertEqual(1, len(output_src.shape))
      self.assertEqual(n, output_src.shape[0])
      self.assertTupleEqual((n, class_num), output_cls.shape)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for a StarGAN model.
This module contains basic ops to build a StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original pytorvh implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def _padding_arg(h, w, input_format):
"""Calculate the padding shape for tf.pad().
Args:
h: (int) padding on the height dim.
w: (int) padding on the width dim.
input_format: (string) the input format as in 'NHWC' or 'HWC'.
Raises:
ValueError: If input_format is not 'NHWC' or 'HWC'.
Returns:
A two dimension array representing the padding argument.
"""
if input_format == 'NHWC':
return [[0, 0], [h, h], [w, w], [0, 0]]
elif input_format == 'HWC':
return [[h, h], [w, w], [0, 0]]
else:
raise ValueError('Input Format %s is not supported.' % input_format)
def pad(input_net, padding_size):
  """Pads a tensor by `padding_size` on both the height and width dims.

  Note: the original StarGAN uses zero padding instead of mirror padding.

  Args:
    input_net: Tensor in 3D ('HWC') or 4D ('NHWC').
    padding_size: (int) the size of the padding.

  Raises:
    ValueError: If input_net Tensor is not 3D or 4D.

  Returns:
    Tensor with same rank as input_net but with padding on the height and
    width dim.
  """
  rank_to_format = {3: 'HWC', 4: 'NHWC'}
  rank = len(input_net.shape)
  if rank not in rank_to_format:
    raise ValueError('The input tensor need to be either 3D or 4D.')
  return tf.pad(
      input_net,
      _padding_arg(padding_size, padding_size, rank_to_format[rank]))
def condition_input_with_pixel_padding(input_tensor, condition_tensor):
  """Appends the domain labels to every pixel as extra color channels.

  Args:
    input_tensor: Tensor of shape (batch_size, h, w, c) representing images.
    condition_tensor: Tensor of shape (batch_size, num_domains) representing
      the associated domain for the image in input_tensor.

  Returns:
    Tensor of shape (batch_size, h, w, c + num_domains) representing the
    conditioned data.

  Raises:
    ValueError: If `input_tensor` isn't rank 4.
    ValueError: If `condition_tensor` isn't rank 2.
    ValueError: If dimension 1 of the input_tensor and condition_tensor is
      not the same.
  """
  input_tensor.shape.assert_has_rank(4)
  condition_tensor.shape.assert_has_rank(2)
  input_tensor.shape[:1].assert_is_compatible_with(condition_tensor.shape[:1])

  # Broadcast the (N, num_domains) labels to (N, h, w, num_domains) so they
  # can be concatenated onto the channel axis.
  height = input_tensor.shape[1]
  width = input_tensor.shape[2]
  labels = tf.expand_dims(tf.expand_dims(condition_tensor, 1), 1)
  labels = tf.tile(labels, [1, height, width, 1])
  return tf.concat([input_tensor, labels], -1)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import ops
class OpsTest(tf.test.TestCase):
  """Tests for the padding and conditioning ops in `ops`."""

  def test_padding_arg(self):
    """NHWC padding pads height/width only, with zero batch/channel pad."""
    pad_h = 2
    pad_w = 3
    self.assertListEqual([[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]],
                         ops._padding_arg(pad_h, pad_w, 'NHWC'))

  def test_padding_arg_specify_format(self):
    """HWC padding has no batch dimension entry."""
    pad_h = 2
    pad_w = 3
    self.assertListEqual([[pad_h, pad_h], [pad_w, pad_w], [0, 0]],
                         ops._padding_arg(pad_h, pad_w, 'HWC'))

  def test_padding_arg_invalid_format(self):
    """An unknown input format raises ValueError."""
    pad_h = 2
    pad_w = 3
    with self.assertRaises(ValueError):
      ops._padding_arg(pad_h, pad_w, 'INVALID')

  def test_padding(self):
    """4D input grows by 2 * pad on the height and width dims only."""
    n = 2
    h = 128
    w = 64
    c = 3
    pad = 3
    test_input_tensor = tf.random_uniform((n, h, w, c))
    test_output_tensor = ops.pad(test_input_tensor, padding_size=pad)
    with self.test_session() as sess:
      output = sess.run(test_output_tensor)
      self.assertTupleEqual((n, h + pad * 2, w + pad * 2, c), output.shape)

  def test_padding_with_3D_tensor(self):
    """3D (unbatched) input is padded the same way."""
    h = 128
    w = 64
    c = 3
    pad = 3
    test_input_tensor = tf.random_uniform((h, w, c))
    test_output_tensor = ops.pad(test_input_tensor, padding_size=pad)
    with self.test_session() as sess:
      output = sess.run(test_output_tensor)
      self.assertTupleEqual((h + pad * 2, w + pad * 2, c), output.shape)

  def test_padding_with_tensor_of_invalid_shape(self):
    """Tensors that are neither 3D nor 4D are rejected with ValueError."""
    n = 2
    invalid_rank = 1
    h = 128
    w = 64
    c = 3
    pad = 3
    test_input_tensor = tf.random_uniform((n, invalid_rank, h, w, c))
    with self.assertRaises(ValueError):
      ops.pad(test_input_tensor, padding_size=pad)

  def test_condition_input_with_pixel_padding(self):
    """Labels are appended as constant extra channels at every pixel."""
    n = 2
    h = 128
    w = h
    c = 3
    num_label = 5
    input_tensor = tf.random_uniform((n, h, w, c))
    label_tensor = tf.random_uniform((n, num_label))
    output_tensor = ops.condition_input_with_pixel_padding(
        input_tensor, label_tensor)
    with self.test_session() as sess:
      labels, outputs = sess.run([label_tensor, output_tensor])
      self.assertTupleEqual((n, h, w, c + num_label), outputs.shape)
      # Every pixel of every image must carry that image's label verbatim
      # in the appended channels.
      for label, output in zip(labels, outputs):
        for i in range(output.shape[0]):
          for j in range(output.shape[1]):
            self.assertListEqual(label.tolist(), output[i, j, c:].tolist())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a StarGAN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
import data_provider
import network
# FLAGS for data.
flags.DEFINE_multi_string(
    'image_file_patterns', None,
    'List of file pattern for different domain of images. '
    '(e.g.[\'black_hair\', \'blond_hair\', \'brown_hair\']')

flags.DEFINE_integer('batch_size', 6, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 128, 'The patch size of images.')

flags.DEFINE_string('train_log_dir', '/tmp/stargan/',
                    'Directory where to write event logs.')

# FLAGS for training hyper-parameters.
flags.DEFINE_float('generator_lr', 1e-4, 'The generator learning rate.')

flags.DEFINE_float('discriminator_lr', 1e-4, 'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 1000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_float('adam_beta1', 0.5, 'Adam Beta 1 for the Adam optimizer.')

flags.DEFINE_float('adam_beta2', 0.999, 'Adam Beta 2 for the Adam optimizer.')

flags.DEFINE_float('gen_disc_step_ratio', 0.2,
                   'Generator:Discriminator training step ratio.')

# FLAGS for distributed training.
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

FLAGS = flags.FLAGS

# Shorthand for the TF-GAN library namespace used throughout this file.
tfgan = tf.contrib.gan
def _define_model(images, labels):
  """Create the StarGAN Model.

  Args:
    images: `Tensor` or list of `Tensor` of shape (N, H, W, C).
    labels: `Tensor` or list of `Tensor` of shape (N, num_domains).

  Returns:
    `StarGANModel` namedtuple.
  """
  # The generator/discriminator architectures live in `network`; TF-GAN
  # assembles them into the StarGAN training tuple.
  return tfgan.stargan_model(
      generator_fn=network.generator,
      discriminator_fn=network.discriminator,
      input_data=images,
      input_data_domain_label=labels)
def _get_lr(base_lr):
  """Builds the learning-rate schedule `Tensor`.

  The rate stays at `base_lr` for the first half of training
  (FLAGS.max_number_of_steps / 2), then decays linearly to zero over the
  second half.

  Args:
    base_lr: A scalar float `Tensor` or a Python number. The base learning
      rate.

  Returns:
    A scalar float `Tensor` holding the current learning rate.
  """
  step = tf.train.get_or_create_global_step()
  constant_steps = FLAGS.max_number_of_steps // 2

  def _decayed_lr():
    # Linear (power=1) polynomial decay over the second half of training.
    return tf.train.polynomial_decay(
        learning_rate=base_lr,
        global_step=(step - constant_steps),
        decay_steps=(FLAGS.max_number_of_steps - constant_steps),
        end_learning_rate=0.0)

  return tf.cond(step < constant_steps, lambda: base_lr, _decayed_lr)
def _get_optimizer(gen_lr, dis_lr):
  """Creates the Adam optimizers for the generator and the discriminator.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number. The Generator
      learning rate.
    dis_lr: A scalar float `Tensor` or a Python number. The Discriminator
      learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  def _adam(learning_rate):
    # Both optimizers share the beta settings supplied via flags.
    return tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        use_locking=True)

  return _adam(gen_lr), _adam(dis_lr)
def _define_train_ops(model, loss):
  """Defines train ops that trains `stargan_model` with `stargan_loss`.

  Args:
    model: A `StarGANModel` namedtuple.
    loss: A `StarGANLoss` namedtuple containing all losses for
      `stargan_model`.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  gen_lr = _get_lr(FLAGS.generator_lr)
  dis_lr = _get_lr(FLAGS.discriminator_lr)
  gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
  train_ops = tfgan.gan_train_ops(
      model,
      loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  # Export the learning-rate schedules so they show up in TensorBoard.
  tf.summary.scalar('generator_lr', gen_lr)
  tf.summary.scalar('discriminator_lr', dis_lr)
  return train_ops
def _define_train_step():
  """Get the training step for generator and discriminator for each GAN step.

  Returns:
    GANTrainSteps namedtuple representing the training step configuration.
  """
  ratio = FLAGS.gen_disc_step_ratio
  if ratio > 1:
    # More generator steps than discriminator steps per GAN step.
    return tfgan.GANTrainSteps(int(ratio), 1)
  # More discriminator steps than generator steps per GAN step.
  return tfgan.GANTrainSteps(1, int(1 / ratio))
def main(_):
  """Builds the StarGAN training graph and runs training to completion."""
  # Create the log_dir if not exist.
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  # Shard the model to different parameter servers.
  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):

    # Create the input dataset.
    with tf.name_scope('inputs'):
      images, labels = data_provider.provide_data(
          FLAGS.image_file_patterns, FLAGS.batch_size, FLAGS.patch_size)

    # Define the model.
    with tf.name_scope('model'):
      model = _define_model(images, labels)

    # Add image summary.
    tfgan.eval.add_stargan_image_summaries(
        model,
        num_images=len(FLAGS.image_file_patterns) * FLAGS.batch_size,
        display_diffs=True)

    # Define the model loss.
    loss = tfgan.stargan_loss(model)

    # Define the train ops.
    with tf.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss)

    # Define the train steps.
    train_steps = _define_train_step()

    # Define a status message that is logged periodically during training.
    status_message = tf.string_join(
        [
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
        name='status_message')

    # Train the model.
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)


if __name__ == '__main__':
  tf.flags.mark_flag_as_required('image_file_patterns')
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stargan.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
import train
FLAGS = flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _test_generator(input_images, _):
  """Simple generator function."""
  # A single trainable scalar keeps the graph differentiable but tiny.
  scale = tf.get_variable('dummy_g', initializer=2.0)
  return input_images * scale
def _test_discriminator(inputs, num_domains):
  """Differentiable dummy discriminator for StarGAN."""
  flat = tf.contrib.layers.flatten(inputs)
  # Source logit: mean over the flattened input; class logits: one linear
  # projection with no bias or normalization.
  class_logits = tf.contrib.layers.fully_connected(
      inputs=flat,
      num_outputs=num_domains,
      activation_fn=None,
      normalizer_fn=None,
      biases_initializer=None)
  return tf.reduce_mean(flat, axis=1), class_logits
# Swap the real StarGAN networks for the lightweight differentiable
# stand-ins above so the training graphs built in these tests stay small
# and fast.
train.network.generator = _test_generator
train.network.discriminator = _test_discriminator
class TrainTest(tf.test.TestCase):
  """Tests for the StarGAN training pipeline in `train`."""

  def test_define_model(self):
    """_define_model wires images/labels into a StarGANModel namedtuple."""
    FLAGS.batch_size = 2
    images_shape = [FLAGS.batch_size, 4, 4, 3]
    images_np = np.zeros(shape=images_shape)
    images = tf.constant(images_np, dtype=tf.float32)
    labels = tf.one_hot([0] * FLAGS.batch_size, 2)

    model = train._define_model(images, labels)
    self.assertIsInstance(model, tfgan.StarGANModel)
    self.assertShapeEqual(images_np, model.generated_data)
    self.assertShapeEqual(images_np, model.reconstructed_data)
    self.assertTrue(isinstance(model.discriminator_variables, list))
    self.assertTrue(isinstance(model.generator_variables, list))
    self.assertIsInstance(model.discriminator_scope, tf.VariableScope)
    self.assertTrue(model.generator_scope, tf.VariableScope)
    self.assertTrue(callable(model.discriminator_fn))
    self.assertTrue(callable(model.generator_fn))

  @mock.patch.object(tf.train, 'get_or_create_global_step', autospec=True)
  def test_get_lr(self, mock_get_or_create_global_step):
    """The LR is constant for the first half of training, then decays."""
    FLAGS.max_number_of_steps = 10
    base_lr = 0.01
    with self.test_session(use_gpu=True) as sess:
      # Step 2 is in the constant phase (< max_number_of_steps / 2).
      mock_get_or_create_global_step.return_value = tf.constant(2)
      lr_step2 = sess.run(train._get_lr(base_lr))
      # Step 9 is 4/5 through the linear-decay phase (steps 5..10).
      mock_get_or_create_global_step.return_value = tf.constant(9)
      lr_step9 = sess.run(train._get_lr(base_lr))
      self.assertAlmostEqual(base_lr, lr_step2)
      self.assertAlmostEqual(base_lr * 0.2, lr_step9)

  @mock.patch.object(tf.summary, 'scalar', autospec=True)
  def test_define_train_ops(self, mock_summary_scalar):
    """_define_train_ops builds GANTrainOps and records LR summaries."""
    FLAGS.batch_size = 2
    FLAGS.generator_lr = 0.1
    FLAGS.discriminator_lr = 0.01

    images_shape = [FLAGS.batch_size, 4, 4, 3]
    images = tf.zeros(images_shape, dtype=tf.float32)
    labels = tf.one_hot([0] * FLAGS.batch_size, 2)

    model = train._define_model(images, labels)
    loss = tfgan.stargan_loss(model)
    train_ops = train._define_train_ops(model, loss)

    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
    mock_summary_scalar.assert_has_calls([
        mock.call('generator_lr', mock.ANY),
        mock.call('discriminator_lr', mock.ANY)
    ])

  def test_get_train_step(self):
    """gen_disc_step_ratio maps to integer gen/disc step counts."""
    FLAGS.gen_disc_step_ratio = 0.5
    train_steps = train._define_train_step()
    self.assertEqual(1, train_steps.generator_train_steps)
    self.assertEqual(2, train_steps.discriminator_train_steps)

    FLAGS.gen_disc_step_ratio = 3
    train_steps = train._define_train_step()
    self.assertEqual(3, train_steps.generator_train_steps)
    self.assertEqual(1, train_steps.discriminator_train_steps)

  @mock.patch.object(
      train.data_provider, 'provide_data', autospec=True)
  def test_main(self, mock_provide_data):
    """main() runs a few training steps end-to-end on mocked data."""
    FLAGS.image_file_patterns = ['/tmp/A/*.jpg', '/tmp/B/*.jpg', '/tmp/C/*.jpg']
    FLAGS.max_number_of_steps = 10
    FLAGS.batch_size = 2
    num_domains = 3

    images_shape = [FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3]
    img_list = [tf.zeros(images_shape)] * num_domains
    lbl_list = [tf.one_hot([0] * FLAGS.batch_size, num_domains)] * num_domains
    mock_provide_data.return_value = (img_list, lbl_list)

    train.main(None)


if __name__ == '__main__':
  tf.test.main()
"""StarGAN Estimator data provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import data_provider
from google3.pyglib import resources
provide_data = data_provider.provide_data
def provide_celeba_test_set():
  """Provide one example of every class, and labels.

  Loads the pre-split CelebA test arrays bundled as package resources.
  NOTE(review): depends on the internal `google3.pyglib.resources` module;
  will not run outside that environment.

  Returns:
    An `np.array` of shape (num_domains, H, W, C) representing the images.
      Values are in [-1, 1].
    An `np.array` of shape (num_domains, num_domains) representing the labels.

  Raises:
    ValueError: If test data is inconsistent or malformed.
  """
  base_dir = 'google3/third_party/tensorflow_models/gan/stargan_estimator/data'

  images_fn = os.path.join(base_dir, 'celeba_test_split_images.npy')
  with resources.GetResourceAsFile(images_fn) as f:
    images_np = np.load(f)

  labels_fn = os.path.join(base_dir, 'celeba_test_split_labels.npy')
  with resources.GetResourceAsFile(labels_fn) as f:
    labels_np = np.load(f)

  # Images and labels must pair up one-to-one along the leading axis.
  if images_np.shape[0] != labels_np.shape[0]:
    raise ValueError('Test data is malformed.')

  return images_np, labels_np
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a StarGAN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
from absl import flags
import numpy as np
import scipy.misc
import tensorflow as tf
import network
import data_provider
# FLAGS for data.
flags.DEFINE_multi_string(
    'image_file_patterns', None,
    'List of file pattern for different domain of images. '
    '(e.g.[\'black_hair\', \'blond_hair\', \'brown_hair\']')

flags.DEFINE_integer('batch_size', 6, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 128, 'The patch size of images.')

# Write-to-disk flags.
flags.DEFINE_string('output_dir', '/tmp/stargan/out/',
                    'Directory where to write summary image.')

# FLAGS for training hyper-parameters.
flags.DEFINE_float('generator_lr', 1e-4, 'The generator learning rate.')

flags.DEFINE_float('discriminator_lr', 1e-4, 'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 1000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer('steps_per_eval', 1000,
                     'The number of steps after which we write eval to disk.')

flags.DEFINE_float('adam_beta1', 0.5, 'Adam Beta 1 for the Adam optimizer.')

flags.DEFINE_float('adam_beta2', 0.999, 'Adam Beta 2 for the Adam optimizer.')

flags.DEFINE_float('gen_disc_step_ratio', 0.2,
                   'Generator:Discriminator training step ratio.')

# FLAGS for distributed training.
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

FLAGS = flags.FLAGS

# Shorthand for the TF-GAN library namespace used throughout this file.
tfgan = tf.contrib.gan
def _get_optimizer(gen_lr, dis_lr):
  """Creates the Adam optimizers for the generator and the discriminator.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number. The Generator
      learning rate.
    dis_lr: A scalar float `Tensor` or a Python number. The Discriminator
      learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  def _make_adam(learning_rate):
    # Both optimizers share the beta settings supplied via flags.
    return tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        use_locking=True)

  return _make_adam(gen_lr), _make_adam(dis_lr)
def _define_train_step():
  """Get the training step for generator and discriminator for each GAN step.

  Returns:
    GANTrainSteps namedtuple representing the training step configuration.
  """
  ratio = FLAGS.gen_disc_step_ratio
  if ratio > 1:
    # More generator steps than discriminator steps per GAN step.
    return tfgan.GANTrainSteps(int(ratio), 1)
  # More discriminator steps than generator steps per GAN step.
  return tfgan.GANTrainSteps(1, int(1 / ratio))
def _get_summary_image(estimator, test_images_np):
  """Returns a numpy summary image of generator output on the test images.

  Args:
    estimator: The `StarGANEstimator` used for prediction.
    test_images_np: List of `np.array` images (one per domain), values
      assumed to be in [-1, 1].

  Returns:
    An `np.array` image grid with values in [0, 1]: one row per source
    image, containing the original followed by its translation into each
    domain.
  """
  num_domains = len(test_images_np)

  img_rows = []
  for img_np in test_images_np:

    def test_input_fn():
      dataset_imgs = [img_np] * num_domains  # pylint:disable=cell-var-from-loop
      # NOTE: `range` (not the Python 2-only `xrange`) so this runs under
      # Python 3.
      dataset_lbls = [tf.one_hot([d], num_domains) for d in range(num_domains)]

      # Make into a dataset.
      dataset_imgs = np.stack(dataset_imgs)
      dataset_imgs = np.expand_dims(dataset_imgs, 1)
      dataset_lbls = tf.stack(dataset_lbls)
      unused_tensor = tf.zeros(num_domains)
      return tf.data.Dataset.from_tensor_slices(
          ((dataset_imgs, dataset_lbls), unused_tensor))

    prediction_iterable = estimator.predict(test_input_fn)
    # `next(it)` replaces the Python 2-only `it.next()` method call.
    predictions = [next(prediction_iterable) for _ in range(num_domains)]
    transform_row = np.concatenate([img_np] + predictions, 1)
    img_rows.append(transform_row)

  all_rows = np.concatenate(img_rows, 0)
  # Normalize [-1, 1] to [0, 1].
  normalized_summary = (all_rows + 1.0) / 2.0
  return normalized_summary
def _write_to_disk(summary_image, filename):
  """Encodes `summary_image` as PNG and writes it to `filename`.

  Args:
    summary_image: `np.array` image to encode.
    filename: Path of the output file.
  """
  buf = io.BytesIO()
  scipy.misc.imsave(buf, summary_image, format='png')
  buf.seek(0)
  # The buffer holds raw PNG bytes, so the file must be opened in binary
  # mode; text mode ('w') would attempt to encode them and fail on Python 3.
  with tf.gfile.GFile(filename, 'wb') as f:
    f.write(buf.getvalue())
def main(_, override_generator_fn=None, override_discriminator_fn=None):
  """Trains a StarGAN estimator, periodically writing summary images.

  Args:
    _: Unused positional argument from `tf.app.run`.
    override_generator_fn: Optional generator function (used by tests).
    override_discriminator_fn: Optional discriminator function (used by
      tests).

  Raises:
    ValueError: If `max_number_of_steps` is not divisible by
      `steps_per_eval`.
  """
  # Create directories if not exist.
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # Make sure steps integers are consistent.
  if FLAGS.max_number_of_steps % FLAGS.steps_per_eval != 0:
    raise ValueError('`max_number_of_steps` must be divisible by '
                     '`steps_per_eval`.')

  # Create optimizers.
  gen_opt, dis_opt = _get_optimizer(FLAGS.generator_lr, FLAGS.discriminator_lr)

  # Create estimator.
  # (joelshor): Add optional distribution strategy here.
  stargan_estimator = tfgan.estimator.StarGANEstimator(
      generator_fn=override_generator_fn or network.generator,
      discriminator_fn=override_discriminator_fn or network.discriminator,
      loss_fn=tfgan.stargan_loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      get_hooks_fn=tfgan.get_sequential_train_hooks(_define_train_step()),
      add_summaries=tfgan.estimator.SummaryType.IMAGES)

  # Get input function for training and test images.
  train_input_fn = lambda: data_provider.provide_data(  # pylint:disable=g-long-lambda
      FLAGS.image_file_patterns, FLAGS.batch_size, FLAGS.patch_size)
  test_images_np, _ = data_provider.provide_celeba_test_set()
  filename_str = os.path.join(FLAGS.output_dir, 'summary_image_%i.png')

  # Periodically train and write prediction output to disk.
  cur_step = 0
  while cur_step < FLAGS.max_number_of_steps:
    stargan_estimator.train(train_input_fn, steps=FLAGS.steps_per_eval)
    cur_step += FLAGS.steps_per_eval
    summary_img = _get_summary_image(stargan_estimator, test_images_np)
    _write_to_disk(summary_img, filename_str % cur_step)


if __name__ == '__main__':
  tf.flags.mark_flag_as_required('image_file_patterns')
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stargan_estimator.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import tensorflow as tf
import train
FLAGS = flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
TESTDATA_DIR = 'google3/third_party/tensorflow_models/gan/stargan_estimator/testdata/celeba'
def _test_generator(input_images, _):
  """Simple generator function."""
  # A single trainable scalar keeps the graph differentiable but tiny.
  scale = tf.get_variable('dummy_g', initializer=2.0)
  return input_images * scale
def _test_discriminator(inputs, num_domains):
  """Differentiable dummy discriminator for StarGAN."""
  flat = tf.contrib.layers.flatten(inputs)
  # Source logit: mean over the flattened input; class logits: one linear
  # projection with no bias or normalization.
  class_logits = tf.contrib.layers.fully_connected(
      inputs=flat,
      num_outputs=num_domains,
      activation_fn=None,
      normalizer_fn=None,
      biases_initializer=None)
  return tf.reduce_mean(flat, axis=1), class_logits
class TrainTest(tf.test.TestCase):
  """Smoke test for the StarGAN estimator training pipeline."""

  def test_main(self):
    """main() runs one training step end-to-end on the bundled test data."""
    FLAGS.image_file_patterns = [
        os.path.join(FLAGS.test_srcdir, TESTDATA_DIR, 'black/*.jpg'),
        os.path.join(FLAGS.test_srcdir, TESTDATA_DIR, 'blond/*.jpg'),
        os.path.join(FLAGS.test_srcdir, TESTDATA_DIR, 'brown/*.jpg'),
    ]
    FLAGS.max_number_of_steps = 1
    FLAGS.steps_per_eval = 1
    FLAGS.batch_size = 1
    # Use the dummy networks above so the test builds a tiny, fast graph.
    train.main(None, _test_generator, _test_discriminator)


if __name__ == '__main__':
  tf.test.main()
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment