"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "7eacbc8713c301e2a761131124b7718319392ce4"
Commit 54a5a577 authored by Joel Shor's avatar Joel Shor Committed by Joel Shor
Browse files

Project import generated by Copybara.

PiperOrigin-RevId: 215004158
parent 4f7074f6
......@@ -35,7 +35,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
('Unconditional', False, False),
('Conditional', True, False),
('SyncReplicas', False, True))
def test_build_graph_helper(self, conditional, use_sync_replicas):
def test_build_graph(self, conditional, use_sync_replicas):
FLAGS.max_number_of_steps = 0
FLAGS.conditional = conditional
FLAGS.use_sync_replicas = use_sync_replicas
......
......@@ -80,36 +80,34 @@ def _provide_custom_dataset(image_file_pattern,
image_file_pattern: A string of glob pattern of image files.
batch_size: The number of images in each batch.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of prefetching threads. Defaults to 1.
num_threads: Number of mapping threads. Defaults to 1.
patch_size: Size of the path to extract from the image. Defaults to 128.
Returns:
A float `Tensor` of shape [batch_size, patch_size, patch_size, 3]
representing a batch of images.
"""
filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once(image_file_pattern),
shuffle=shuffle,
capacity=5 * batch_size)
image_reader = tf.WholeFileReader()
A tf.data.Dataset with Tensors of shape
[batch_size, patch_size, patch_size, 3] representing a batch of images.
_, image_bytes = image_reader.read(filename_queue)
image = tf.image.decode_image(image_bytes)
image_patch = full_image_to_patch(image, patch_size)
Raises:
ValueError: If no files match `image_file_pattern`.
"""
if not tf.gfile.Glob(image_file_pattern):
raise ValueError('No file patterns found.')
filenames_ds = tf.data.Dataset.list_files(image_file_pattern)
bytes_ds = filenames_ds.map(tf.io.read_file, num_parallel_calls=num_threads)
images_ds = bytes_ds.map(
tf.image.decode_image, num_parallel_calls=num_threads)
patches_ds = images_ds.map(
lambda img: full_image_to_patch(img, patch_size),
num_parallel_calls=num_threads)
patches_ds = patches_ds.repeat()
if shuffle:
return tf.train.shuffle_batch(
[image_patch],
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size,
min_after_dequeue=batch_size)
else:
return tf.train.batch(
[image_patch],
batch_size=batch_size,
num_threads=1, # no threads so it's deterministic
capacity=5 * batch_size)
patches_ds = patches_ds.shuffle(5 * batch_size)
patches_ds = patches_ds.prefetch(5 * batch_size)
patches_ds = patches_ds.batch(batch_size)
return patches_ds
def provide_custom_datasets(image_file_patterns,
......@@ -127,8 +125,8 @@ def provide_custom_datasets(image_file_patterns,
patch_size: Size of the patch to extract from the image. Defaults to 128.
Returns:
A list of float `Tensor`s with the same size of `image_file_patterns`.
Each of the `Tensor` in the list has a shape of
A list of tf.data.Datasets, one per entry in `image_file_patterns`. Each
dataset yields `Tensor`s of shape
[batch_size, patch_size, patch_size, 3] representing a batch of images.
Raises:
......@@ -147,4 +145,41 @@ def provide_custom_datasets(image_file_patterns,
shuffle=shuffle,
num_threads=num_threads,
patch_size=patch_size))
return custom_datasets
def provide_custom_data(image_file_patterns,
                        batch_size,
                        shuffle=True,
                        num_threads=1,
                        patch_size=128):
  """Provides image batch `Tensor`s for a list of file patterns.

  Wraps `provide_custom_datasets` and converts each returned dataset into a
  batch tensor via an initializable iterator.

  Args:
    image_file_patterns: A list of glob patterns of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images. Defaults to True.
    num_threads: Number of mapping threads. Defaults to 1.
    patch_size: Size of the patch to extract from the image. Defaults to 128.

  Returns:
    A list of float `Tensor`s, one per pattern in `image_file_patterns`, each
    of shape [batch_size, patch_size, patch_size, 3] representing a batch of
    images. As a side effect, each iterator initializer is added to the
    tf.GraphKeys.TABLE_INITIALIZERS collection, so running
    `tf.tables_initializer()` initializes the data pipelines.

  Raises:
    ValueError: If image_file_patterns is not a list or tuple.
  """
  image_tensors = []
  for dataset in provide_custom_datasets(
      image_file_patterns, batch_size, shuffle, num_threads, patch_size):
    data_iterator = dataset.make_initializable_iterator()
    # Register the initializer so a single `tf.tables_initializer()` run
    # brings up every dataset pipeline.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                         data_iterator.initializer)
    image_tensors.append(data_iterator.get_next())
  return image_tensors
......@@ -62,41 +62,67 @@ class DataProviderTest(tf.test.TestCase):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images = data_provider._provide_custom_dataset(
images_ds = data_provider._provide_custom_dataset(
file_pattern, batch_size=batch_size, patch_size=patch_size)
self.assertListEqual([batch_size, patch_size, patch_size, 3],
images.shape.as_list())
self.assertEqual(tf.float32, images.dtype)
self.assertListEqual([None, patch_size, patch_size, 3],
images_ds.output_shapes.as_list())
self.assertEqual(tf.float32, images_ds.output_types)
iterator = images_ds.make_initializable_iterator()
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
with tf.contrib.slim.queues.QueueRunners(sess):
images_out = sess.run(images)
sess.run(iterator.initializer)
images_out = sess.run(iterator.get_next())
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_datasets_provider(self):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images_ds_list = data_provider.provide_custom_datasets(
[file_pattern, file_pattern],
batch_size=batch_size,
patch_size=patch_size)
for images_ds in images_ds_list:
self.assertListEqual([None, patch_size, patch_size, 3],
images_ds.output_shapes.as_list())
self.assertEqual(tf.float32, images_ds.output_types)
iterators = [x.make_initializable_iterator() for x in images_ds_list]
initialiers = [x.initializer for x in iterators]
img_tensors = [x.get_next() for x in iterators]
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
sess.run(initialiers)
images_out_list = sess.run(img_tensors)
for images_out in images_out_list:
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_datasets_provider(self):
def test_custom_data_provider(self):
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images_list = data_provider.provide_custom_datasets(
images_list = data_provider.provide_custom_data(
[file_pattern, file_pattern],
batch_size=batch_size,
patch_size=patch_size)
for images in images_list:
self.assertListEqual([batch_size, patch_size, patch_size, 3],
self.assertListEqual([None, patch_size, patch_size, 3],
images.shape.as_list())
self.assertEqual(tf.float32, images.dtype)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
with tf.contrib.slim.queues.QueueRunners(sess):
images_out_list = sess.run(images_list)
for images_out in images_out_list:
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
sess.run(tf.tables_initializer())
images_out_list = sess.run(images_list)
for images_out in images_out_list:
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
if __name__ == '__main__':
......
......@@ -35,7 +35,10 @@ class InferenceDemoTest(tf.test.TestCase):
self._geny_dir = os.path.join(FLAGS.test_tmpdir, 'geny')
@mock.patch.object(tfgan, 'gan_train', autospec=True)
def testTrainingAndInferenceGraphsAreCompatible(self, unused_mock_gan_train):
@mock.patch.object(
train.data_provider, 'provide_custom_data', autospec=True)
def testTrainingAndInferenceGraphsAreCompatible(
self, mock_provide_custom_data, unused_mock_gan_train):
# Training and inference graphs can get out of sync if changes are made
# to one but not the other. This test will keep them in sync.
......@@ -52,6 +55,8 @@ class InferenceDemoTest(tf.test.TestCase):
FLAGS.task = 0
FLAGS.cycle_consistency_loss_weight = 2.0
FLAGS.max_number_of_steps = 1
mock_provide_custom_data.return_value = (
tf.zeros([3, 4, 4, 3,]), tf.zeros([3, 4, 4, 3]))
train.main(None)
init_op = tf.global_variables_initializer()
train_sess.run(init_op)
......
......@@ -169,10 +169,13 @@ def main(_):
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
with tf.name_scope('inputs'):
images_x, images_y = data_provider.provide_custom_datasets(
images_x, images_y = data_provider.provide_custom_data(
[FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
batch_size=FLAGS.batch_size,
patch_size=FLAGS.patch_size)
# Set batch size for summaries.
images_x.set_shape([FLAGS.batch_size, None, None, None])
images_y.set_shape([FLAGS.batch_size, None, None, None])
# Define CycleGAN model.
cyclegan_model = _define_model(images_x, images_y)
......
......@@ -128,11 +128,12 @@ class TrainTest(tf.test.TestCase):
FLAGS.cycle_consistency_loss_weight = 2.0
FLAGS.max_number_of_steps = 1
mock_data_provider.provide_custom_datasets.return_value = (tf.zeros(
[1, 2], dtype=tf.float32), tf.zeros([1, 2], dtype=tf.float32))
mock_data_provider.provide_custom_data.return_value = (
tf.zeros([3, 2, 2, 3], dtype=tf.float32),
tf.zeros([3, 2, 2, 3], dtype=tf.float32))
train.main(None)
mock_data_provider.provide_custom_datasets.assert_called_once_with(
mock_data_provider.provide_custom_data.assert_called_once_with(
['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
mock_cyclegan_loss.assert_called_once_with(
......
......@@ -75,6 +75,34 @@ def get_total_num_stages(**kwargs):
return 2 * kwargs['num_resolutions'] - 1
def get_batch_size(stage_id, **kwargs):
  """Returns the batch size to use for a given training stage.

  Every resolution has its own batch size, taken from
  `batch_size_schedule`. When the schedule is shorter than
  `num_resolutions`, it is left-padded with its first entry so that each
  resolution is covered.

  Args:
    stage_id: An integer of training stage index.
    **kwargs: A dictionary of
      'batch_size_schedule': A list of integer, each element is the batch size
        for the current training image resolution.
      'num_resolutions': An integer of number of progressive resolutions.

  Returns:
    An integer batch size for the `stage_id`.
  """
  schedule = list(kwargs['batch_size_schedule'])
  num_resolutions = kwargs['num_resolutions']
  num_missing = num_resolutions - len(schedule)
  if num_missing > 0:
    schedule = [schedule[0]] * num_missing + schedule
  # Two consecutive stages (transition + stable) share one resolution, hence
  # the same batch size.
  return int(schedule[(stage_id + 1) // 2])
def get_stage_info(stage_id, **kwargs):
"""Returns information for a training stage.
......@@ -228,14 +256,14 @@ def add_generator_smoothing_ops(generator_ema, gan_model, gan_train_ops):
return gan_train_ops, generator_vars_to_restore
def build_model(stage_id, real_images, **kwargs):
def build_model(stage_id, batch_size, real_images, **kwargs):
"""Builds progressive GAN model.
Args:
stage_id: An integer of training stage index.
batch_size: Number of training images in each minibatch.
real_images: A 4D `Tensor` of NHWC format.
**kwargs: A dictionary of
'batch_size': Number of training images in each minibatch.
'start_height': An integer of start image height.
'start_width': An integer of start image width.
'scale_base': An integer of resolution multiplier.
......@@ -267,15 +295,14 @@ def build_model(stage_id, real_images, **kwargs):
Returns:
An internal object that wraps all information about the model.
"""
batch_size = kwargs['batch_size']
kernel_size = kwargs['kernel_size']
colors = kwargs['colors']
resolution_schedule = make_resolution_schedule(**kwargs)
num_blocks, num_images = get_stage_info(stage_id, **kwargs)
global_step = tf.train.get_or_create_global_step()
current_image_id = global_step * batch_size
current_image_id = tf.train.get_or_create_global_step()
current_image_id_inc_op = current_image_id.assign_add(batch_size)
tf.summary.scalar('current_image_id', current_image_id)
progress = networks.compute_progress(
......@@ -329,6 +356,8 @@ def build_model(stage_id, real_images, **kwargs):
########## Define train ops.
gan_train_ops, optimizer_var_list = define_train_ops(gan_model, gan_loss,
**kwargs)
gan_train_ops = gan_train_ops._replace(
global_step_inc_op=current_image_id_inc_op)
########## Generator smoothing.
generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
......@@ -339,11 +368,11 @@ def build_model(stage_id, real_images, **kwargs):
pass
model = Model()
model.resolution_schedule = resolution_schedule
model.stage_id = stage_id
model.batch_size = batch_size
model.resolution_schedule = resolution_schedule
model.num_images = num_images
model.num_blocks = num_blocks
model.global_step = global_step
model.current_image_id = current_image_id
model.progress = progress
model.num_filters_fn = _num_filters_fn
......@@ -380,7 +409,6 @@ def add_model_summaries(model, **kwargs):
model: An model object having all information of progressive GAN model,
e.g. the return of build_model().
**kwargs: A dictionary of
'batch_size': Number of training images in each minibatch.
'fake_grid_size': The fake image grid size for summaries.
'interp_grid_size': The latent space interpolated image grid size for
summaries.
......@@ -431,7 +459,7 @@ def add_model_summaries(model, **kwargs):
num_channels=colors),
max_outputs=1)
real_grid_size = int(np.sqrt(kwargs['batch_size']))
real_grid_size = int(np.sqrt(model.batch_size))
tf.summary.image(
'real_images_blend',
tfgan.eval.eval_utils.image_grid(
......@@ -517,11 +545,10 @@ def make_status_message(model):
"""Makes a string `Tensor` of training status."""
return tf.string_join(
[
'Starting train step: ',
tf.as_string(model.global_step), ', current_image_id: ',
'Starting train step: current_image_id: ',
tf.as_string(model.current_image_id), ', progress: ',
tf.as_string(model.progress), ', num_blocks: {}'.format(
model.num_blocks)
model.num_blocks), ', batch_size: {}'.format(model.batch_size)
],
name='status_message')
......@@ -541,8 +568,6 @@ def train(model, **kwargs):
Returns:
None.
"""
batch_size = kwargs['batch_size']
logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
model.num_blocks, model.num_images)
......@@ -553,7 +578,7 @@ def train(model, **kwargs):
logdir=make_train_sub_dir(model.stage_id, **kwargs),
get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
hooks=[
tf.train.StopAtStepHook(last_step=model.num_images // batch_size),
tf.train.StopAtStepHook(last_step=model.num_images),
tf.train.LoggingTensorHook(
[make_status_message(model)], every_n_iter=10)
],
......@@ -561,4 +586,4 @@ def train(model, **kwargs):
is_chief=(kwargs['task'] == 0),
scaffold=scaffold,
save_checkpoint_secs=600,
save_summaries_steps=(kwargs['save_summaries_num_images'] // batch_size))
save_summaries_steps=(kwargs['save_summaries_num_images']))
......@@ -48,6 +48,12 @@ flags.DEFINE_integer('scale_base', 2, 'Resolution multiplier.')
flags.DEFINE_integer('num_resolutions', 4, 'Number of progressive resolutions.')
flags.DEFINE_list(
'batch_size_schedule', [8, 8, 4],
'A list of batch sizes for each resolution, if '
'len(batch_size_schedule) < num_resolutions, pad the schedule in the '
'beginning with the first batch size.')
flags.DEFINE_integer('kernel_size', 3, 'Convolution kernel size.')
flags.DEFINE_integer('colors', 3, 'Number of image channels.')
......@@ -55,8 +61,6 @@ flags.DEFINE_integer('colors', 3, 'Number of image channels.')
flags.DEFINE_bool('to_rgb_use_tanh_activation', False,
'Whether to apply tanh activation when output rgb.')
flags.DEFINE_integer('batch_size', 8, 'Number of images in each batch.')
flags.DEFINE_integer('stable_stage_num_images', 1000,
'Number of images in the stable stage.')
......@@ -123,11 +127,10 @@ def _make_config_from_flags():
for flag in FLAGS.get_key_flags_for_module(sys.argv[0])])
def _provide_real_images(**kwargs):
def _provide_real_images(batch_size, **kwargs):
"""Provides real images."""
dataset_name = kwargs.get('dataset_name')
dataset_file_pattern = kwargs.get('dataset_file_pattern')
batch_size = kwargs['batch_size']
colors = kwargs['colors']
final_height, final_width = train.make_resolution_schedule(
**kwargs).final_resolutions
......@@ -156,12 +159,13 @@ def main(_):
logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.iteritems()]))
for stage_id in train.get_stage_ids(**config):
batch_size = train.get_batch_size(stage_id, **config)
tf.reset_default_graph()
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
real_images = None
with tf.device('/cpu:0'), tf.name_scope('inputs'):
real_images = _provide_real_images(**config)
model = train.build_model(stage_id, real_images, **config)
real_images = _provide_real_images(batch_size, **config)
model = train.build_model(stage_id, batch_size, real_images, **config)
train.add_model_summaries(model, **config)
train.train(model, **config)
......
......@@ -29,7 +29,7 @@ import train
FLAGS = flags.FLAGS
def provide_random_data(batch_size=2, patch_size=8, colors=1, **unused_kwargs):
def provide_random_data(batch_size=2, patch_size=4, colors=1, **unused_kwargs):
return tf.random_normal([batch_size, patch_size, patch_size, colors])
......@@ -37,19 +37,19 @@ class TrainTest(absltest.TestCase):
def setUp(self):
self._config = {
'start_height': 4,
'start_width': 4,
'start_height': 2,
'start_width': 2,
'scale_base': 2,
'num_resolutions': 2,
'batch_size_schedule': [2],
'colors': 1,
'to_rgb_use_tanh_activation': True,
'kernel_size': 3,
'batch_size': 2,
'stable_stage_num_images': 4,
'transition_stage_num_images': 4,
'total_num_images': 12,
'save_summaries_num_images': 4,
'latent_vector_size': 8,
'stable_stage_num_images': 1,
'transition_stage_num_images': 1,
'total_num_images': 3,
'save_summaries_num_images': 2,
'latent_vector_size': 2,
'fmap_base': 8,
'fmap_decay': 1.0,
'fmap_max': 8,
......@@ -73,12 +73,21 @@ class TrainTest(absltest.TestCase):
tf.gfile.MakeDirs(train_root_dir)
for stage_id in train.get_stage_ids(**self._config):
batch_size = train.get_batch_size(stage_id, **self._config)
tf.reset_default_graph()
real_images = provide_random_data()
model = train.build_model(stage_id, real_images, **self._config)
real_images = provide_random_data(batch_size=batch_size)
model = train.build_model(stage_id, batch_size, real_images,
**self._config)
train.add_model_summaries(model, **self._config)
train.train(model, **self._config)
def test_get_batch_size(self):
config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
# batch_size_schedule is expanded to [8, 8, 8, 4, 2]
# At stage level it is [8, 8, 8, 8, 8, 4, 4, 2, 2]
for i, expected_batch_size in enumerate([8, 8, 8, 8, 8, 4, 4, 2, 2]):
self.assertEqual(train.get_batch_size(i, **config), expected_batch_size)
if __name__ == '__main__':
absltest.main()
"""StarGAN data provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
def provide_data(image_file_patterns, batch_size, patch_size):
  """Data provider wrapper on for the data_provider in gan/cyclegan.

  Loads one batch of images per file pattern and pairs each batch with a
  one-hot domain label (domain index = position in the pattern list).

  Args:
    image_file_patterns: A list of file pattern globs.
    batch_size: Python int. Batch size.
    patch_size: Python int. The patch size to extract.

  Returns:
    List of `Tensor` of shape (N, H, W, C) representing the images.
    List of `Tensor` of shape (N, num_domains) representing the labels.
  """
  images = data_provider.provide_custom_data(
      image_file_patterns,
      batch_size=batch_size,
      patch_size=patch_size)
  num_domains = len(images)
  # Every image in a batch comes from the same domain, so the whole batch
  # shares one one-hot label row repeated `batch_size` times.
  labels = []
  for domain_idx in range(num_domains):
    labels.append(tf.one_hot([domain_idx] * batch_size, num_domains))
  return images, labels
"""Tests for google3.third_party.tensorflow_models.gan.stargan.data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from google3.testing.pybase import googletest
import data_provider
mock = tf.test.mock
class DataProviderTest(googletest.TestCase):
  """Tests for the StarGAN `provide_data` wrapper."""

  @mock.patch.object(
      data_provider.data_provider, 'provide_custom_data', autospec=True)
  def test_data_provider(self, mock_provide_custom_data):
    batch_size = 2
    patch_size = 8
    num_domains = 3
    expected_image_shape = [batch_size, patch_size, patch_size, 3]
    # Stub out the underlying provider: one zero image batch per domain.
    mock_provide_custom_data.return_value = [
        tf.zeros(expected_image_shape) for _ in range(num_domains)
    ]

    images, labels = data_provider.provide_data(
        image_file_patterns=None, batch_size=batch_size, patch_size=patch_size)

    # One image batch and one label batch per domain.
    self.assertEqual(num_domains, len(images))
    self.assertEqual(num_domains, len(labels))
    for label in labels:
      self.assertListEqual([batch_size, num_domains], label.shape.as_list())
    for image in images:
      self.assertListEqual(expected_image_shape, image.shape.as_list())
if __name__ == '__main__':
googletest.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for a StarGAN model.
This module contains basic layers to build a StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original PyTorch implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import ops
def generator_down_sample(input_net, final_num_outputs=256):
  """Down-sampling module in Generator.

  Down sampling pathway of the Generator Architecture:

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L32

  Notes:
    We require dimension 1 and dimension 2 of the input_net to be fully defined
    for the correct down sampling.

  Args:
    input_net: Tensor of shape (batch_size, h, w, c + num_class).
    final_num_outputs: (int) Number of hidden unit for the final layer. Must be
      divisible by 4.

  Returns:
    Tensor of shape (batch_size, h / 4, w / 4, final_num_outputs).

  Raises:
    ValueError: If final_num_outputs are not divisible by 4,
      or input_net does not have a rank of 4,
      or dimension 1 and dimension 2 of input_net are not defined at graph
      construction time,
      or dimension 1 and dimension 2 of input_net are not divisible by 4.
  """
  if final_num_outputs % 4 != 0:
    raise ValueError('Final number outputs need to be divisible by 4.')

  # Check the rank of input_net.
  input_net.shape.assert_has_rank(4)

  # The two strided convolutions below shrink H and W by exactly 4x, so both
  # spatial dimensions must be statically known and divisible by 4.
  for dim in (1, 2):
    if input_net.shape[dim]:
      if input_net.shape[dim] % 4 != 0:
        raise ValueError(
            'Dimension {} of the input should be divisible by 4, but is {} '
            'instead.'.format(dim, input_net.shape[dim]))
    else:
      raise ValueError(
          'Dimension {} of the input should be explicitly defined.'.format(
              dim))

  with tf.variable_scope('generator_down_sample'):
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.conv2d],
        padding='VALID',
        biases_initializer=None,
        normalizer_fn=tf.contrib.layers.instance_norm,
        activation_fn=tf.nn.relu):

      down_sample = ops.pad(input_net, 3)
      down_sample = tf.contrib.layers.conv2d(
          inputs=down_sample,
          # Use integer division: under `from __future__ import division`,
          # `final_num_outputs / 4` is a float, which is not a valid channel
          # count for conv2d.
          num_outputs=final_num_outputs // 4,
          kernel_size=7,
          stride=1,
          scope='conv_0')

      down_sample = ops.pad(down_sample, 1)
      down_sample = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs // 2,
          kernel_size=4,
          stride=2,
          scope='conv_1')

      down_sample = ops.pad(down_sample, 1)
      output_net = tf.contrib.layers.conv2d(
          inputs=down_sample,
          num_outputs=final_num_outputs,
          kernel_size=4,
          stride=2,
          scope='conv_2')

  return output_net
def _residual_block(input_net,
                    num_outputs,
                    kernel_size,
                    stride=1,
                    padding_size=0,
                    activation_fn=tf.nn.relu,
                    normalizer_fn=None,
                    name='residual_block'):
  """Residual Block.

  Input Tensor X - > Conv1 -> IN -> ReLU -> Conv2 -> IN + X

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L7

  Args:
    input_net: Tensor as input.
    num_outputs: (int) number of output channels for Convolution.
    kernel_size: (int) size of the square kernel for Convolution.
    stride: (int) stride for Convolution. Default to 1.
    padding_size: (int) padding size for Convolution. Default to 0.
    activation_fn: Activation function.
    normalizer_fn: Normalization function.
    name: Name scope

  Returns:
    Residual Tensor with the same shape as the input tensor.
  """
  # Both convolutions share the same hyperparameters.
  conv_kwargs = dict(
      num_outputs=num_outputs,
      kernel_size=kernel_size,
      stride=stride,
      padding='VALID',
      normalizer_fn=normalizer_fn,
      activation_fn=None)
  with tf.variable_scope(name):
    net = tf.contrib.layers.conv2d(
        inputs=ops.pad(input_net, padding_size), scope='conv_0', **conv_kwargs)
    net = activation_fn(net, name='activation_0')
    net = tf.contrib.layers.conv2d(
        inputs=ops.pad(net, padding_size), scope='conv_1', **conv_kwargs)
    # Skip connection: add the block input back to its output.
    output_net = net + input_net
  return output_net
def generator_bottleneck(input_net, residual_block_num=6, num_outputs=256):
  """Bottleneck module in Generator.

  Residual bottleneck pathway in Generator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L40

  Args:
    input_net: Tensor of shape (batch_size, h / 4, w / 4, 256).
    residual_block_num: (int) Number of residual_block_num. Default to 6 per the
      original implementation.
    num_outputs: (int) Number of hidden unit in the residual bottleneck. Default
      to 256 per the original implementation.

  Returns:
    Tensor of shape (batch_size, h / 4, w / 4, 256).

  Raises:
    ValueError: If the rank of the input tensor is not 4,
      or the last channel of the input_tensor is not explicitly defined,
      or the last channel of the input_tensor is not the same as num_outputs.
  """
  input_net.shape.assert_has_rank(4)

  # Guard clauses: the channel dimension must be statically known and match
  # `num_outputs`, because the residual add requires identical shapes.
  if not input_net.shape[-1]:
    raise ValueError(
        'The last dimension of the input_net should be explicitly defined.')
  if input_net.shape[-1] != num_outputs:
    raise ValueError(
        'The last dimension of the input_net should be the same as '
        'num_outputs: but {} vs. {} instead.'.format(input_net.shape[-1],
                                                     num_outputs))

  with tf.variable_scope('generator_bottleneck'):
    net = input_net
    for block_id in range(residual_block_num):
      net = _residual_block(
          input_net=net,
          num_outputs=num_outputs,
          kernel_size=3,
          stride=1,
          padding_size=1,
          activation_fn=tf.nn.relu,
          normalizer_fn=tf.contrib.layers.instance_norm,
          name='residual_block_{}'.format(block_id))
  return net
def generator_up_sample(input_net, num_outputs):
  """Up-sampling module in Generator.

  Up sampling path for image generation in the Generator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L44

  Args:
    input_net: Tensor of shape (batch_size, h / 4, w / 4, 256).
    num_outputs: (int) Number of channel for the output tensor.

  Returns:
    Tensor of shape (batch_size, h, w, num_outputs).
  """
  # The two transposed convolutions share everything but the channel count.
  deconv_kwargs = dict(
      kernel_size=4,
      stride=2,
      padding='VALID',
      normalizer_fn=tf.contrib.layers.instance_norm,
      activation_fn=tf.nn.relu)
  with tf.variable_scope('generator_up_sample'):
    net = tf.contrib.layers.conv2d_transpose(
        inputs=input_net, num_outputs=128, scope='deconv_0', **deconv_kwargs)
    # Crop one pixel per side to undo the VALID deconv overshoot.
    net = net[:, 1:-1, 1:-1, :]
    net = tf.contrib.layers.conv2d_transpose(
        inputs=net, num_outputs=64, scope='deconv_1', **deconv_kwargs)
    net = net[:, 1:-1, 1:-1, :]

    net = ops.pad(net, 3)
    output_net = tf.contrib.layers.conv2d(
        inputs=net,
        num_outputs=num_outputs,
        kernel_size=7,
        stride=1,
        padding='VALID',
        activation_fn=tf.nn.tanh,
        normalizer_fn=None,
        biases_initializer=None,
        scope='conv_0')
  return output_net
def discriminator_input_hidden(input_net, hidden_layer=6, init_num_outputs=64):
  """Input Layer + Hidden Layer in the Discriminator.

  Feature extraction pathway in the Discriminator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L68

  Args:
    input_net: Tensor of shape (batch_size, h, w, 3) as batch of images.
    hidden_layer: (int) Number of hidden layers. Default to 6 per the original
      implementation.
    init_num_outputs: (int) Number of hidden unit in the first hidden layer. The
      number of hidden unit double after each layer. Default to 64 per the
      original implementation.

  Returns:
    Tensor of shape (batch_size, h / 64, w / 64, 2048) as features.
  """
  with tf.variable_scope('discriminator_input_hidden'):
    hidden = input_net
    for layer_id in range(hidden_layer):
      hidden = tf.contrib.layers.conv2d(
          inputs=ops.pad(hidden, 1),
          # Channel count doubles each layer: 64, 128, ..., 2048.
          num_outputs=init_num_outputs * (2 ** layer_id),
          kernel_size=4,
          stride=2,
          padding='VALID',
          activation_fn=None,
          normalizer_fn=None,
          scope='conv_{}'.format(layer_id))
      hidden = tf.nn.leaky_relu(hidden, alpha=0.01)
  return hidden
def discriminator_output_source(input_net):
  """Output Layer for Source in the Discriminator.

  Determine if the image is real/fake based on the feature extracted. We follow
  the original paper design where the output is not a simple (batch_size) shape
  Tensor but rather a (batch_size, 2, 2, 2048) shape Tensor. We will get the
  correct shape later when we piece things together.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L79

  Args:
    input_net: Tensor of shape (batch_size, h / 64, w / 64, 2048) as features.

  Returns:
    Tensor of shape (batch_size, h / 64, w / 64, 1) as the score.
  """
  with tf.variable_scope('discriminator_output_source'):
    padded = ops.pad(input_net, 1)
    # Single-channel score map; no bias, no normalization, no activation.
    output_src = tf.contrib.layers.conv2d(
        inputs=padded,
        num_outputs=1,
        kernel_size=3,
        stride=1,
        padding='VALID',
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,
        scope='conv')
  return output_src
def discriminator_output_class(input_net, class_num):
  """Output Layer for Domain Classification in the Discriminator.

  The original paper use convolution layer where the kernel size is the height
  and width of the Tensor. We use an equivalent operation here where we first
  flatten the Tensor to shape (batch_size, K) and a fully connected layer.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L80https

  Args:
    input_net: Tensor of shape (batch_size, h / 64, w / 64, 2028).
    class_num: Number of output classes to be predicted.

  Returns:
    Tensor of shape (batch_size, class_num).
  """
  with tf.variable_scope('discriminator_output_class'):
    flattened = tf.contrib.layers.flatten(input_net, scope='flatten')
    # Raw class logits: no bias, no normalization, no activation.
    logits = tf.contrib.layers.fully_connected(
        inputs=flattened,
        num_outputs=class_num,
        activation_fn=None,
        normalizer_fn=None,
        biases_initializer=None,
        scope='fully_connected')
  return logits
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import layers
class LayersTest(tf.test.TestCase):
  """Shape tests for the StarGAN generator/discriminator layer helpers."""

  def _evaluate(self, tensor):
    """Initializes all variables, then returns `tensor` as a numpy array."""
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      return sess.run(tensor)

  def test_residual_block(self):
    batch, height, width, channels = 2, 32, 32, 256
    block_output = layers._residual_block(
        input_net=tf.random_uniform((batch, height, width, channels)),
        num_outputs=channels,
        kernel_size=3,
        stride=1,
        padding_size=1)
    # A residual block must preserve the input shape.
    self.assertTupleEqual((batch, height, width, channels),
                          self._evaluate(block_output).shape)

  def test_generator_down_sample(self):
    batch, height, width, channels = 2, 128, 128, 3 + 3
    down = layers.generator_down_sample(
        tf.random_uniform((batch, height, width, channels)))
    # Spatial dims shrink by 4x while channels expand to 256.
    self.assertTupleEqual((batch, height // 4, width // 4, 256),
                          self._evaluate(down).shape)

  def test_generator_bottleneck(self):
    batch, height, width, channels = 2, 32, 32, 256
    bottleneck = layers.generator_bottleneck(
        tf.random_uniform((batch, height, width, channels)))
    # The bottleneck keeps the feature-map shape unchanged.
    self.assertTupleEqual((batch, height, width, channels),
                          self._evaluate(bottleneck).shape)

  def test_generator_up_sample(self):
    batch, height, width, channels = 2, 32, 32, 256
    out_channels = 3
    up = layers.generator_up_sample(
        tf.random_uniform((batch, height, width, channels)), out_channels)
    # Spatial dims grow by 4x and channels map to `out_channels`.
    self.assertTupleEqual((batch, height * 4, width * 4, out_channels),
                          self._evaluate(up).shape)

  def test_discriminator_input_hidden(self):
    batch, height, width, channels = 2, 128, 128, 3
    hidden = layers.discriminator_input_hidden(
        tf.random_uniform((batch, height, width, channels)))
    self.assertTupleEqual((batch, 2, 2, 2048), self._evaluate(hidden).shape)

  def test_discriminator_output_source(self):
    batch, height, width, channels = 2, 2, 2, 2048
    src = layers.discriminator_output_source(
        tf.random_uniform((batch, height, width, channels)))
    # One source logit per spatial location.
    self.assertTupleEqual((batch, height, width, 1),
                          self._evaluate(src).shape)

  def test_discriminator_output_class(self):
    batch, height, width, channels = 2, 2, 2, 2048
    num_domain = 3
    cls = layers.discriminator_output_class(
        tf.random_uniform((batch, height, width, channels)), num_domain)
    # One logit per domain class.
    self.assertTupleEqual((batch, num_domain), self._evaluate(cls).shape)
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural network for a StarGAN model.
This module contains the Generator and Discriminator Neural Network to build a
StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original pytorch implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import layers
import ops
def generator(inputs, targets):
  """Generator module.

  Piece everything together for the Generator.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

  Args:
    inputs: Tensor of shape (batch_size, h, w, c) representing the
      images/information that we want to transform.
    targets: Tensor of shape (batch_size, num_domains) representing the target
      domain the generator should transform the image/information to.

  Returns:
    Tensor with the same shape (batch_size, h, w, c) as `inputs`.
  """
  with tf.variable_scope('generator'):
    # Condition the image on the target domain by appending one constant
    # channel per domain label.
    conditioned = ops.condition_input_with_pixel_padding(inputs, targets)
    # Encode, transform, then decode back to the input channel count.
    encoded = layers.generator_down_sample(conditioned)
    transformed = layers.generator_bottleneck(encoded)
    return layers.generator_up_sample(transformed, inputs.shape[-1])
def discriminator(input_net, class_num):
  """Discriminator Module.

  Pieces everything together and reshapes the output source tensor.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L63

  Notes:
    The PyTorch version runs the reduce_mean operation later, in its solver:
    https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/solver.py#L245

  Args:
    input_net: Tensor of shape (batch_size, h, w, c) as batch of images.
    class_num: (int) number of domains to be predicted.

  Returns:
    output_src: Tensor of shape (batch_size) where each value is a logit
      representing whether the image is real or fake.
    output_cls: Tensor of shape (batch_size, class_num) where each value is a
      logit representing whether the image is in the associated domain.
  """
  with tf.variable_scope('discriminator'):
    hidden = layers.discriminator_input_hidden(input_net)
    # Collapse the spatial source map into a single per-image logit.
    src = layers.discriminator_output_source(hidden)
    src = tf.reduce_mean(tf.contrib.layers.flatten(src), axis=1)
    cls = layers.discriminator_output_class(hidden, class_num)
    return src, cls
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import network
class NetworkTest(tf.test.TestCase):
  """Shape tests for the full StarGAN generator and discriminator."""

  def test_generator(self):
    batch, height, width, channels = 2, 128, 128, 4
    class_num = 3
    generated = network.generator(
        tf.random_uniform((batch, height, width, channels)),
        tf.random_uniform((batch, class_num)))
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      output = sess.run(generated)
      # The generator must return an image of the same shape as its input.
      self.assertTupleEqual((batch, height, width, channels), output.shape)

  def test_discriminator(self):
    batch, height, width, channels = 2, 128, 128, 3
    class_num = 3
    src_tensor, cls_tensor = network.discriminator(
        tf.random_uniform((batch, height, width, channels)), class_num)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      src, cls = sess.run([src_tensor, cls_tensor])
      # One realness logit per image.
      self.assertEqual(1, len(src.shape))
      self.assertEqual(batch, src.shape[0])
      # One logit per domain class per image.
      self.assertTupleEqual((batch, class_num), cls.shape)
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for a StarGAN model.
This module contains basic ops to build a StarGAN model.
See https://arxiv.org/abs/1711.09020 for details about the model.
See https://github.com/yunjey/StarGAN for the original PyTorch implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def _padding_arg(h, w, input_format):
"""Calculate the padding shape for tf.pad().
Args:
h: (int) padding on the height dim.
w: (int) padding on the width dim.
input_format: (string) the input format as in 'NHWC' or 'HWC'.
Raises:
ValueError: If input_format is not 'NHWC' or 'HWC'.
Returns:
A two dimension array representing the padding argument.
"""
if input_format == 'NHWC':
return [[0, 0], [h, h], [w, w], [0, 0]]
elif input_format == 'HWC':
return [[h, h], [w, w], [0, 0]]
else:
raise ValueError('Input Format %s is not supported.' % input_format)
def pad(input_net, padding_size):
  """Pad the tensor with padding_size on both the height and width dims.

  Notes:
    Original StarGAN uses zero padding instead of mirror padding.

  Args:
    input_net: Tensor in 3D ('HWC') or 4D ('NHWC').
    padding_size: (int) the size of the padding.

  Raises:
    ValueError: If input_net Tensor is not 3D or 4D.

  Returns:
    Tensor with same rank as input_net but with padding on the height and
    width dim.
  """
  rank = len(input_net.shape)
  if rank == 3:
    data_format = 'HWC'
  elif rank == 4:
    data_format = 'NHWC'
  else:
    raise ValueError('The input tensor need to be either 3D or 4D.')
  return tf.pad(input_net,
                _padding_arg(padding_size, padding_size, data_format))
def condition_input_with_pixel_padding(input_tensor, condition_tensor):
  """Pad image tensor with condition tensor as additional color channel.

  Args:
    input_tensor: Tensor of shape (batch_size, h, w, c) representing images.
    condition_tensor: Tensor of shape (batch_size, num_domains) representing
      the associated domain for the image in input_tensor.

  Returns:
    Tensor of shape (batch_size, h, w, c + num_domains) representing the
    conditioned data.

  Raises:
    ValueError: If `input_tensor` isn't rank 4.
    ValueError: If `condition_tensor` isn't rank 2.
    ValueError: If dimension 1 of the input_tensor and condition_tensor is
      not the same.
  """
  input_tensor.shape.assert_has_rank(4)
  condition_tensor.shape.assert_has_rank(2)
  input_tensor.shape[:1].assert_is_compatible_with(condition_tensor.shape[:1])

  # Broadcast each label vector to an (h, w, num_domains) block of constant
  # pixels, then append it to the image as extra color channels.
  height = input_tensor.shape[1]
  width = input_tensor.shape[2]
  label_pixels = tf.tile(
      condition_tensor[:, tf.newaxis, tf.newaxis, :], [1, height, width, 1])
  return tf.concat([input_tensor, label_pixels], -1)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
import ops
class OpsTest(tf.test.TestCase):
  """Tests for the StarGAN padding and conditioning ops."""

  def test_padding_arg(self):
    height_pad, width_pad = 2, 3
    self.assertListEqual(
        [[0, 0], [height_pad, height_pad], [width_pad, width_pad], [0, 0]],
        ops._padding_arg(height_pad, width_pad, 'NHWC'))

  def test_padding_arg_specify_format(self):
    height_pad, width_pad = 2, 3
    self.assertListEqual(
        [[height_pad, height_pad], [width_pad, width_pad], [0, 0]],
        ops._padding_arg(height_pad, width_pad, 'HWC'))

  def test_padding_arg_invalid_format(self):
    with self.assertRaises(ValueError):
      ops._padding_arg(2, 3, 'INVALID')

  def test_padding(self):
    batch, height, width, channels = 2, 128, 64, 3
    padding = 3
    padded = ops.pad(
        tf.random_uniform((batch, height, width, channels)),
        padding_size=padding)
    with self.test_session() as sess:
      result = sess.run(padded)
    # Each spatial dim grows by padding on both sides.
    self.assertTupleEqual(
        (batch, height + padding * 2, width + padding * 2, channels),
        result.shape)

  def test_padding_with_3D_tensor(self):
    height, width, channels = 128, 64, 3
    padding = 3
    padded = ops.pad(
        tf.random_uniform((height, width, channels)), padding_size=padding)
    with self.test_session() as sess:
      result = sess.run(padded)
    self.assertTupleEqual(
        (height + padding * 2, width + padding * 2, channels), result.shape)

  def test_padding_with_tensor_of_invalid_shape(self):
    # Rank-5 input is neither 'HWC' nor 'NHWC' and must be rejected.
    bad_input = tf.random_uniform((2, 1, 128, 64, 3))
    with self.assertRaises(ValueError):
      ops.pad(bad_input, padding_size=3)

  def test_condition_input_with_pixel_padding(self):
    batch, height, width, channels = 2, 128, 128, 3
    num_label = 5
    images = tf.random_uniform((batch, height, width, channels))
    domains = tf.random_uniform((batch, num_label))
    conditioned = ops.condition_input_with_pixel_padding(images, domains)
    with self.test_session() as sess:
      labels, outputs = sess.run([domains, conditioned])
    self.assertTupleEqual(
        (batch, height, width, channels + num_label), outputs.shape)
    # Every pixel of the appended channels should equal the label vector.
    for label, output in zip(labels, outputs):
      for row in range(output.shape[0]):
        for col in range(output.shape[1]):
          self.assertListEqual(
              label.tolist(), output[row, col, channels:].tolist())
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a StarGAN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
import data_provider
import network
# FLAGS for data.
flags.DEFINE_multi_string(
    'image_file_patterns', None,
    'List of file pattern for different domain of images. '
    '(e.g.[\'black_hair\', \'blond_hair\', \'brown_hair\']')
flags.DEFINE_integer('batch_size', 6, 'The number of images in each batch.')
flags.DEFINE_integer('patch_size', 128, 'The patch size of images.')
flags.DEFINE_string('train_log_dir', '/tmp/stargan/',
                    'Directory where to write event logs.')
# FLAGS for training hyper-parameters.
flags.DEFINE_float('generator_lr', 1e-4, 'The generator learning rate.')
flags.DEFINE_float('discriminator_lr', 1e-4, 'The discriminator learning rate.')
flags.DEFINE_integer('max_number_of_steps', 1000000,
                     'The maximum number of gradient steps.')
flags.DEFINE_float('adam_beta1', 0.5, 'Adam Beta 1 for the Adam optimizer.')
flags.DEFINE_float('adam_beta2', 0.999, 'Adam Beta 2 for the Adam optimizer.')
# Values <= 1 train the discriminator more often than the generator;
# values > 1 do the opposite (see _define_train_step).
flags.DEFINE_float('gen_disc_step_ratio', 0.2,
                   'Generator:Discriminator training step ratio.')
# FLAGS for distributed training.
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')
flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')
# Module-level handles to the parsed flags and the TF-GAN namespace.
FLAGS = flags.FLAGS
tfgan = tf.contrib.gan
def _define_model(images, labels):
  """Create the StarGAN Model.

  Args:
    images: `Tensor` or list of `Tensor` of shape (N, H, W, C).
    labels: `Tensor` or list of `Tensor` of shape (N, num_domains).

  Returns:
    `StarGANModel` namedtuple.
  """
  # Delegate graph construction to TF-GAN, wiring in our networks.
  model = tfgan.stargan_model(
      generator_fn=network.generator,
      discriminator_fn=network.discriminator,
      input_data=images,
      input_data_domain_label=labels)
  return model
def _get_lr(base_lr):
  """Returns a learning rate `Tensor`.

  Args:
    base_lr: A scalar float `Tensor` or a Python number. The base learning
      rate.

  Returns:
    A scalar float `Tensor` of learning rate which equals `base_lr` when the
    global training step is less than FLAGS.max_number_of_steps / 2,
    afterwards it linearly decays to zero.
  """
  global_step = tf.train.get_or_create_global_step()
  constant_steps = FLAGS.max_number_of_steps // 2

  def _decayed_lr():
    # Linear (degree-1 polynomial) decay from base_lr down to 0 over the
    # second half of training.
    return tf.train.polynomial_decay(
        learning_rate=base_lr,
        global_step=(global_step - constant_steps),
        decay_steps=(FLAGS.max_number_of_steps - constant_steps),
        end_learning_rate=0.0)

  return tf.cond(global_step < constant_steps, lambda: base_lr, _decayed_lr)
def _get_optimizer(gen_lr, dis_lr):
  """Returns generator optimizer and discriminator optimizer.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number. The Generator learning
      rate.
    dis_lr: A scalar float `Tensor` or a Python number. The Discriminator
      learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  def _adam(lr):
    # use_locking guards the shared variables when training with replicas.
    return tf.train.AdamOptimizer(
        lr, beta1=FLAGS.adam_beta1, beta2=FLAGS.adam_beta2, use_locking=True)

  return _adam(gen_lr), _adam(dis_lr)
def _define_train_ops(model, loss):
  """Defines train ops that trains `stargan_model` with `stargan_loss`.

  Args:
    model: A `StarGANModel` namedtuple.
    loss: A `StarGANLoss` namedtuple containing all losses for
      `stargan_model`.

  Returns:
    A `GANTrainOps` namedtuple.
  """
  generator_lr = _get_lr(FLAGS.generator_lr)
  discriminator_lr = _get_lr(FLAGS.discriminator_lr)
  # Log both (possibly decaying) learning rates.
  tf.summary.scalar('generator_lr', generator_lr)
  tf.summary.scalar('discriminator_lr', discriminator_lr)

  generator_opt, discriminator_opt = _get_optimizer(generator_lr,
                                                    discriminator_lr)
  return tfgan.gan_train_ops(
      model,
      loss,
      generator_optimizer=generator_opt,
      discriminator_optimizer=discriminator_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
def _define_train_step():
  """Get the training step for generator and discriminator for each GAN step.

  Returns:
    GANTrainSteps namedtuple representing the training step configuration.
  """
  ratio = FLAGS.gen_disc_step_ratio
  if ratio <= 1:
    # Train the discriminator more often than the generator.
    return tfgan.GANTrainSteps(1, int(1 / ratio))
  # Train the generator more often than the discriminator.
  return tfgan.GANTrainSteps(int(ratio), 1)
def main(_):
  """Builds the StarGAN training graph and runs training.

  Reads input images via `data_provider`, constructs the StarGAN model,
  loss and train ops under a replica device setter, then runs
  `tfgan.gan_train` until `FLAGS.max_number_of_steps` global steps.
  """
  # Create the log_dir if not exist.
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)
  # Shard the model to different parameter servers.
  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    # Create the input dataset.
    with tf.name_scope('inputs'):
      images, labels = data_provider.provide_data(
          FLAGS.image_file_patterns, FLAGS.batch_size, FLAGS.patch_size)
    # Define the model.
    with tf.name_scope('model'):
      model = _define_model(images, labels)
    # Add image summary (one row of images per domain, plus diffs).
    tfgan.eval.add_stargan_image_summaries(
        model,
        num_images=len(FLAGS.image_file_patterns) * FLAGS.batch_size,
        display_diffs=True)
    # Define the model loss.
    loss = tfgan.stargan_loss(model)
    # Define the train ops.
    with tf.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss)
    # Define the train steps (generator/discriminator step ratio).
    train_steps = _define_train_step()
    # Define a status message logged periodically during training.
    status_message = tf.string_join(
        [
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
        name='status_message')
    # Train the model.
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            # Stop after the configured number of global steps.
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            # Log the status message every 10 iterations.
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
# Require the input file patterns before handing control to the app runner.
if __name__ == '__main__':
  tf.flags.mark_flag_as_required('image_file_patterns')
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stargan.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
import train
# Shorthands for the parsed flags, TF's bundled mock module, and TF-GAN.
FLAGS = flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _test_generator(input_images, _):
  """Simple generator function."""
  # A single trainable scale keeps the stand-in generator differentiable.
  scale = tf.get_variable('dummy_g', initializer=2.0)
  return input_images * scale
def _test_discriminator(inputs, num_domains):
  """Differentiable dummy discriminator for StarGAN."""
  flattened = tf.contrib.layers.flatten(inputs)
  # Source logit: mean over the flattened features.
  output_src = tf.reduce_mean(flattened, axis=1)
  # Class logits: a single bias-free linear layer.
  output_cls = tf.contrib.layers.fully_connected(
      inputs=flattened,
      num_outputs=num_domains,
      activation_fn=None,
      normalizer_fn=None,
      biases_initializer=None)
  return output_src, output_cls
# Replace the real StarGAN networks with lightweight differentiable stand-ins
# so the training tests below build small graphs and run quickly.
train.network.generator = _test_generator
train.network.discriminator = _test_discriminator
class TrainTest(tf.test.TestCase):
  """Tests for the StarGAN training script (`train.py`)."""

  def test_define_model(self):
    """`_define_model` should produce a well-formed `StarGANModel`."""
    FLAGS.batch_size = 2
    images_shape = [FLAGS.batch_size, 4, 4, 3]
    images_np = np.zeros(shape=images_shape)
    images = tf.constant(images_np, dtype=tf.float32)
    labels = tf.one_hot([0] * FLAGS.batch_size, 2)
    model = train._define_model(images, labels)
    self.assertIsInstance(model, tfgan.StarGANModel)
    # Generated and reconstructed images must keep the input shape.
    self.assertShapeEqual(images_np, model.generated_data)
    self.assertShapeEqual(images_np, model.reconstructed_data)
    self.assertIsInstance(model.discriminator_variables, list)
    self.assertIsInstance(model.generator_variables, list)
    self.assertIsInstance(model.discriminator_scope, tf.VariableScope)
    # Bug fix: this previously used assertTrue(a, b), which treats the second
    # argument as a failure message and never checks the type.
    self.assertIsInstance(model.generator_scope, tf.VariableScope)
    self.assertTrue(callable(model.discriminator_fn))
    self.assertTrue(callable(model.generator_fn))

  @mock.patch.object(tf.train, 'get_or_create_global_step', autospec=True)
  def test_get_lr(self, mock_get_or_create_global_step):
    """Learning rate is constant for the first half, then decays linearly."""
    FLAGS.max_number_of_steps = 10
    base_lr = 0.01
    with self.test_session(use_gpu=True) as sess:
      # Step 2 is in the constant phase (before step 5).
      mock_get_or_create_global_step.return_value = tf.constant(2)
      lr_step2 = sess.run(train._get_lr(base_lr))
      # Step 9 is 80% of the way through the decay phase (steps 5..10).
      mock_get_or_create_global_step.return_value = tf.constant(9)
      lr_step9 = sess.run(train._get_lr(base_lr))
      self.assertAlmostEqual(base_lr, lr_step2)
      self.assertAlmostEqual(base_lr * 0.2, lr_step9)

  @mock.patch.object(tf.summary, 'scalar', autospec=True)
  def test_define_train_ops(self, mock_summary_scalar):
    """`_define_train_ops` builds GANTrainOps and logs both learning rates."""
    FLAGS.batch_size = 2
    FLAGS.generator_lr = 0.1
    FLAGS.discriminator_lr = 0.01
    images_shape = [FLAGS.batch_size, 4, 4, 3]
    images = tf.zeros(images_shape, dtype=tf.float32)
    labels = tf.one_hot([0] * FLAGS.batch_size, 2)
    model = train._define_model(images, labels)
    loss = tfgan.stargan_loss(model)
    train_ops = train._define_train_ops(model, loss)
    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
    mock_summary_scalar.assert_has_calls([
        mock.call('generator_lr', mock.ANY),
        mock.call('discriminator_lr', mock.ANY)
    ])

  def test_get_train_step(self):
    """Ratio <= 1 trains the discriminator more; ratio > 1 the generator."""
    FLAGS.gen_disc_step_ratio = 0.5
    train_steps = train._define_train_step()
    self.assertEqual(1, train_steps.generator_train_steps)
    self.assertEqual(2, train_steps.discriminator_train_steps)
    FLAGS.gen_disc_step_ratio = 3
    train_steps = train._define_train_step()
    self.assertEqual(3, train_steps.generator_train_steps)
    self.assertEqual(1, train_steps.discriminator_train_steps)

  @mock.patch.object(
      train.data_provider, 'provide_data', autospec=True)
  def test_main(self, mock_provide_data):
    """End-to-end smoke test of `train.main` with mocked input data."""
    FLAGS.image_file_patterns = ['/tmp/A/*.jpg', '/tmp/B/*.jpg', '/tmp/C/*.jpg']
    FLAGS.max_number_of_steps = 10
    FLAGS.batch_size = 2
    num_domains = 3
    images_shape = [FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3]
    img_list = [tf.zeros(images_shape)] * num_domains
    lbl_list = [tf.one_hot([0] * FLAGS.batch_size, num_domains)] * num_domains
    mock_provide_data.return_value = (img_list, lbl_list)
    train.main(None)
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment