Unverified Commit 57b99319 authored by Joel Shor's avatar Joel Shor Committed by GitHub
Browse files

Merge pull request #3820 from joel-shor/master

Add cyclegan to open source tensorflow/models
parents ba2b8e00 f3a542b4
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def normalize_image(image):
"""Rescale from range [0, 255] to [-1, 1]."""
return (tf.to_float(image) - 127.5) / 127.5
def undo_normalize_image(normalized_image):
"""Convert to a numpy array that can be read by PIL."""
# Convert from NHWC to HWC.
normalized_image = np.squeeze(normalized_image, axis=0)
return np.uint8(normalized_image * 127.5 + 127.5)
def _sample_patch(image, patch_size):
"""Crop image to square shape and resize it to `patch_size`.
Args:
image: A 3D `Tensor` of HWC format.
patch_size: A Python scalar. The output image size.
Returns:
A 3D `Tensor` of HWC format which has the shape of
[patch_size, patch_size, 3].
"""
image_shape = tf.shape(image)
height, width = image_shape[0], image_shape[1]
target_size = tf.minimum(height, width)
image = tf.image.resize_image_with_crop_or_pad(image, target_size,
target_size)
# tf.image.resize_area only accepts 4D tensor, so expand dims first.
image = tf.expand_dims(image, axis=0)
image = tf.image.resize_images(image, [patch_size, patch_size])
image = tf.squeeze(image, axis=0)
# Force image num_channels = 3
image = tf.tile(image, [1, 1, tf.maximum(1, 4 - tf.shape(image)[2])])
image = tf.slice(image, [0, 0, 0], [patch_size, patch_size, 3])
return image
def full_image_to_patch(image, patch_size):
image = normalize_image(image)
# Sample a patch of fixed size.
image_patch = _sample_patch(image, patch_size)
image_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
return image_patch
def _provide_custom_dataset(image_file_pattern,
batch_size,
shuffle=True,
num_threads=1,
patch_size=128):
"""Provides batches of custom image data.
Args:
image_file_pattern: A string of glob pattern of image files.
batch_size: The number of images in each batch.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of prefetching threads. Defaults to 1.
patch_size: Size of the path to extract from the image. Defaults to 128.
Returns:
A float `Tensor` of shape [batch_size, patch_size, patch_size, 3]
representing a batch of images.
"""
filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once(image_file_pattern),
shuffle=shuffle,
capacity=5 * batch_size)
image_reader = tf.WholeFileReader()
_, image_bytes = image_reader.read(filename_queue)
image = tf.image.decode_image(image_bytes)
image_patch = full_image_to_patch(image, patch_size)
if shuffle:
return tf.train.shuffle_batch(
[image_patch],
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size,
min_after_dequeue=batch_size)
else:
return tf.train.batch(
[image_patch],
batch_size=batch_size,
num_threads=1, # no threads so it's deterministic
capacity=5 * batch_size)
def provide_custom_datasets(image_file_patterns,
batch_size,
shuffle=True,
num_threads=1,
patch_size=128):
"""Provides multiple batches of custom image data.
Args:
image_file_patterns: A list of glob patterns of image files.
batch_size: The number of images in each batch.
shuffle: Whether to shuffle the read images. Defaults to True.
num_threads: Number of prefetching threads. Defaults to 1.
patch_size: Size of the patch to extract from the image. Defaults to 128.
Returns:
A list of float `Tensor`s with the same size of `image_file_patterns`.
Each of the `Tensor` in the list has a shape of
[batch_size, patch_size, patch_size, 3] representing a batch of images.
Raises:
ValueError: If image_file_patterns is not a list or tuple.
"""
if not isinstance(image_file_patterns, (list, tuple)):
raise ValueError(
'`image_file_patterns` should be either list or tuple, but was {}.'.
format(type(image_file_patterns)))
custom_datasets = []
for pattern in image_file_patterns:
custom_datasets.append(
_provide_custom_dataset(
pattern,
batch_size=batch_size,
shuffle=shuffle,
num_threads=num_threads,
patch_size=patch_size))
return custom_datasets
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import data_provider
mock = tf.test.mock
class DataProviderTest(tf.test.TestCase):
def test_normalize_image(self):
image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32)
rescaled_image = data_provider.normalize_image(image)
self.assertEqual(tf.float32, rescaled_image.dtype)
self.assertListEqual(image.shape.as_list(), rescaled_image.shape.as_list())
with self.test_session(use_gpu=True) as sess:
rescaled_image_out = sess.run(rescaled_image)
self.assertTrue(np.all(np.abs(rescaled_image_out) <= 1.0))
def test_sample_patch(self):
image = tf.zeros(shape=(8, 8, 3))
patch1 = data_provider._sample_patch(image, 7)
patch2 = data_provider._sample_patch(image, 10)
image = tf.zeros(shape=(8, 8, 1))
patch3 = data_provider._sample_patch(image, 10)
with self.test_session(use_gpu=True) as sess:
self.assertTupleEqual((7, 7, 3), sess.run(patch1).shape)
self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape)
self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape)
def _get_testdata_dir(self):
return os.path.join(
tf.flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
def test_custom_dataset_provider(self):
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
batch_size = 3
patch_size = 8
images = data_provider._provide_custom_dataset(
file_pattern, batch_size=batch_size, patch_size=patch_size)
self.assertListEqual([batch_size, patch_size, patch_size, 3],
images.shape.as_list())
self.assertEqual(tf.float32, images.dtype)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
with tf.contrib.slim.queues.QueueRunners(sess):
images_out = sess.run(images)
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_datasets_provider(self):
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
batch_size = 3
patch_size = 8
images_list = data_provider.provide_custom_datasets(
[file_pattern, file_pattern],
batch_size=batch_size,
patch_size=patch_size)
for images in images_list:
self.assertListEqual([batch_size, patch_size, patch_size, 3],
images.shape.as_list())
self.assertEqual(tf.float32, images.dtype)
with self.test_session(use_gpu=True) as sess:
sess.run(tf.local_variables_initializer())
with tf.contrib.slim.queues.QueueRunners(sess):
images_out_list = sess.run(images_list)
for images_out in images_out_list:
self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Demo that makes inference requests against a running inference server."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import PIL
import tensorflow as tf
import data_provider
import networks
flags = tf.flags
tfgan = tf.contrib.gan
flags.DEFINE_string('checkpoint_path', '',
'CycleGAN checkpoint path created by train.py. '
'(e.g. "/mylogdir/model.ckpt-18442")')
flags.DEFINE_string(
'image_set_x_glob', '',
'Optional: Glob path to images of class X to feed through the CycleGAN.')
flags.DEFINE_string(
'image_set_y_glob', '',
'Optional: Glob path to images of class Y to feed through the CycleGAN.')
flags.DEFINE_string(
'generated_x_dir', '/tmp/generated_x/',
'If image_set_y_glob is defined, where to output the generated X '
'images.')
flags.DEFINE_string(
'generated_y_dir', '/tmp/generated_y/',
'If image_set_x_glob is defined, where to output the generated Y '
'images.')
flags.DEFINE_integer('patch_dim', 128,
'The patch size of images that was used in train.py.')
FLAGS = flags.FLAGS
def _make_dir_if_not_exists(dir_path):
"""Make a directory if it does not exist."""
if not tf.gfile.Exists(dir_path):
tf.gfile.MakeDirs(dir_path)
def _file_output_path(dir_path, input_file_path):
"""Create output path for an individual file."""
return os.path.join(dir_path, os.path.basename(input_file_path))
def make_inference_graph(model_name, patch_dim):
"""Build the inference graph for either the X2Y or Y2X GAN.
Args:
model_name: The var scope name 'ModelX2Y' or 'ModelY2X'.
patch_dim: An integer size of patches to feed to the generator.
Returns:
Tuple of (input_placeholder, generated_tensor).
"""
input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])
# Expand HWC to NHWC
images_x = tf.expand_dims(
data_provider.full_image_to_patch(input_hwc_pl, patch_dim), 0)
with tf.variable_scope(model_name):
with tf.variable_scope('Generator'):
generated = networks.generator(images_x)
return input_hwc_pl, generated
def export(sess, input_pl, output_tensor, input_file_pattern, output_dir):
"""Exports inference outputs to an output directory.
Args:
sess: tf.Session with variables already loaded.
input_pl: tf.Placeholder for input (HWC format).
output_tensor: Tensor for generated outut images.
input_file_pattern: Glob file pattern for input images.
output_dir: Output directory.
"""
if output_dir:
_make_dir_if_not_exists(output_dir)
if input_file_pattern:
for file_path in tf.gfile.Glob(input_file_pattern):
# Grab a single image and run it through inference
input_np = np.asarray(PIL.Image.open(file_path))
output_np = sess.run(output_tensor, feed_dict={input_pl: input_np})
image_np = data_provider.undo_normalize_image(output_np)
output_path = _file_output_path(output_dir, file_path)
PIL.Image.fromarray(image_np).save(output_path)
def _validate_flags():
flags.register_validator('checkpoint_path', bool,
'Must provide `checkpoint_path`.')
flags.register_validator(
'generated_x_dir',
lambda x: False if (FLAGS.image_set_y_glob and not x) else True,
'Must provide `generated_x_dir`.')
flags.register_validator(
'generated_y_dir',
lambda x: False if (FLAGS.image_set_x_glob and not x) else True,
'Must provide `generated_y_dir`.')
def main(_):
_validate_flags()
images_x_hwc_pl, generated_y = make_inference_graph('ModelX2Y',
FLAGS.patch_dim)
images_y_hwc_pl, generated_x = make_inference_graph('ModelY2X',
FLAGS.patch_dim)
# Restore all the variables that were saved in the checkpoint.
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, FLAGS.checkpoint_path)
export(sess, images_x_hwc_pl, generated_y, FLAGS.image_set_x_glob,
FLAGS.generated_y_dir)
export(sess, images_y_hwc_pl, generated_x, FLAGS.image_set_y_glob,
FLAGS.generated_x_dir)
if __name__ == '__main__':
tf.app.run()
"""Tests for CycleGAN inference demo."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import PIL
import tensorflow as tf
import inference_demo
import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _basenames_from_glob(file_glob):
return [os.path.basename(file_path) for file_path in tf.gfile.Glob(file_glob)]
class InferenceDemoTest(tf.test.TestCase):
def setUp(self):
self._export_dir = os.path.join(FLAGS.test_tmpdir, 'export')
self._ckpt_path = os.path.join(self._export_dir, 'model.ckpt')
self._image_glob = os.path.join(
FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata', '*.jpg')
self._genx_dir = os.path.join(FLAGS.test_tmpdir, 'genx')
self._geny_dir = os.path.join(FLAGS.test_tmpdir, 'geny')
@mock.patch.object(tfgan, 'gan_train', autospec=True)
def testTrainingAndInferenceGraphsAreCompatible(self, unused_mock_gan_train):
# Training and inference graphs can get out of sync if changes are made
# to one but not the other. This test will keep them in sync.
# Save the training graph
train_sess = tf.Session()
FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
FLAGS.batch_size = 3
FLAGS.patch_size = 128
FLAGS.generator_lr = 0.02
FLAGS.discriminator_lr = 0.3
FLAGS.train_log_dir = self._export_dir
FLAGS.master = 'master'
FLAGS.task = 0
FLAGS.cycle_consistency_loss_weight = 2.0
FLAGS.max_number_of_steps = 1
train.main(None)
init_op = tf.global_variables_initializer()
train_sess.run(init_op)
train_saver = tf.train.Saver()
train_saver.save(train_sess, save_path=self._ckpt_path)
# Create inference graph
tf.reset_default_graph()
FLAGS.patch_dim = FLAGS.patch_size
tf.logging.info('dir_path: {}'.format(os.listdir(self._export_dir)))
FLAGS.checkpoint_path = self._ckpt_path
FLAGS.image_set_x_glob = self._image_glob
FLAGS.image_set_y_glob = self._image_glob
FLAGS.generated_x_dir = self._genx_dir
FLAGS.generated_y_dir = self._geny_dir
inference_demo.main(None)
tf.logging.info('gen x: {}'.format(os.listdir(self._genx_dir)))
# Check that the image names match
self.assertSetEqual(
set(_basenames_from_glob(FLAGS.image_set_x_glob)),
set(os.listdir(FLAGS.generated_y_dir)))
self.assertSetEqual(
set(_basenames_from_glob(FLAGS.image_set_y_glob)),
set(os.listdir(FLAGS.generated_x_dir)))
# Check that each image in the directory looks as expected
for directory in [FLAGS.generated_x_dir, FLAGS.generated_x_dir]:
for base_name in os.listdir(directory):
image_path = os.path.join(directory, base_name)
self.assertRealisticImage(image_path)
def assertRealisticImage(self, image_path):
tf.logging.info('Testing {} for realism.'.format(image_path))
# If the normalization is off or forgotten, then the generated image is
# all one pixel value. This tests that different pixel values are achieved.
input_np = np.asarray(PIL.Image.open(image_path))
self.assertEqual(len(input_np.shape), 3)
self.assertGreaterEqual(input_np.shape[0], 50)
self.assertGreaterEqual(input_np.shape[1], 50)
self.assertGreater(np.mean(input_np), 20)
self.assertGreater(np.var(input_np), 100)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a CycleGAN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import data_provider
import networks
flags = tf.flags
tfgan = tf.contrib.gan
flags.DEFINE_string('image_set_x_file_pattern', None,
'File pattern of images in image set X')
flags.DEFINE_string('image_set_y_file_pattern', None,
'File pattern of images in image set Y')
flags.DEFINE_integer('batch_size', 1, 'The number of images in each batch.')
flags.DEFINE_integer('patch_size', 64, 'The patch size of images.')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('train_log_dir', '/tmp/cyclegan/',
'Directory where to write event logs.')
flags.DEFINE_float('generator_lr', 0.0002,
'The compression model learning rate.')
flags.DEFINE_float('discriminator_lr', 0.0001,
'The discriminator learning rate.')
flags.DEFINE_integer('max_number_of_steps', 500000,
'The maximum number of gradient steps.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_float('cycle_consistency_loss_weight', 10.0,
'The weight of cycle consistency loss')
FLAGS = flags.FLAGS
def _define_model(images_x, images_y):
"""Defines a CycleGAN model that maps between images_x and images_y.
Args:
images_x: A 4D float `Tensor` of NHWC format. Images in set X.
images_y: A 4D float `Tensor` of NHWC format. Images in set Y.
Returns:
A `CycleGANModel` namedtuple.
"""
cyclegan_model = tfgan.cyclegan_model(
generator_fn=networks.generator,
discriminator_fn=networks.discriminator,
data_x=images_x,
data_y=images_y)
# Add summaries for generated images.
tfgan.eval.add_image_comparison_summaries(
cyclegan_model, num_comparisons=3, display_diffs=False)
tfgan.eval.add_gan_model_image_summaries(
cyclegan_model, grid_size=int(np.sqrt(FLAGS.batch_size)))
return cyclegan_model
def _get_lr(base_lr):
"""Returns a learning rate `Tensor`.
Args:
base_lr: A scalar float `Tensor` or a Python number. The base learning
rate.
Returns:
A scalar float `Tensor` of learning rate which equals `base_lr` when the
global training step is less than FLAGS.max_number_of_steps / 2, afterwards
it linearly decays to zero.
"""
global_step = tf.train.get_or_create_global_step()
lr_constant_steps = FLAGS.max_number_of_steps // 2
def _lr_decay():
return tf.train.polynomial_decay(
learning_rate=base_lr,
global_step=(global_step - lr_constant_steps),
decay_steps=(FLAGS.max_number_of_steps - lr_constant_steps),
end_learning_rate=0.0)
return tf.cond(global_step < lr_constant_steps, lambda: base_lr, _lr_decay)
def _get_optimizer(gen_lr, dis_lr):
"""Returns generator optimizer and discriminator optimizer.
Args:
gen_lr: A scalar float `Tensor` or a Python number. The Generator learning
rate.
dis_lr: A scalar float `Tensor` or a Python number. The Discriminator
learning rate.
Returns:
A tuple of generator optimizer and discriminator optimizer.
"""
# beta1 follows
# https://github.com/junyanz/CycleGAN/blob/master/options.lua
gen_opt = tf.train.AdamOptimizer(gen_lr, beta1=0.5, use_locking=True)
dis_opt = tf.train.AdamOptimizer(dis_lr, beta1=0.5, use_locking=True)
return gen_opt, dis_opt
def _define_train_ops(cyclegan_model, cyclegan_loss):
"""Defines train ops that trains `cyclegan_model` with `cyclegan_loss`.
Args:
cyclegan_model: A `CycleGANModel` namedtuple.
cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
`cyclegan_model`.
Returns:
A `GANTrainOps` namedtuple.
"""
gen_lr = _get_lr(FLAGS.generator_lr)
dis_lr = _get_lr(FLAGS.discriminator_lr)
gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
train_ops = tfgan.gan_train_ops(
cyclegan_model,
cyclegan_loss,
generator_optimizer=gen_opt,
discriminator_optimizer=dis_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.summary.scalar('generator_lr', gen_lr)
tf.summary.scalar('discriminator_lr', dis_lr)
return train_ops
def main(_):
if not tf.gfile.Exists(FLAGS.train_log_dir):
tf.gfile.MakeDirs(FLAGS.train_log_dir)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
with tf.name_scope('inputs'):
images_x, images_y = data_provider.provide_custom_datasets(
[FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
batch_size=FLAGS.batch_size,
patch_size=FLAGS.patch_size)
# Define CycleGAN model.
cyclegan_model = _define_model(images_x, images_y)
# Define CycleGAN loss.
cyclegan_loss = tfgan.cyclegan_loss(
cyclegan_model,
cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
tensor_pool_fn=tfgan.features.tensor_pool)
# Define CycleGAN train ops.
train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)
# Training
train_steps = tfgan.GANTrainSteps(1, 1)
status_message = tf.string_join(
[
'Starting train step: ',
tf.as_string(tf.train.get_or_create_global_step())
],
name='status_message')
if not FLAGS.max_number_of_steps:
return
tfgan.gan_train(
train_ops,
FLAGS.train_log_dir,
get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
hooks=[
tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)
],
master=FLAGS.master,
is_chief=FLAGS.task == 0)
if __name__ == '__main__':
tf.flags.mark_flag_as_required('image_set_x_file_pattern')
tf.flags.mark_flag_as_required('image_set_y_file_pattern')
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cyclegan.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
def _test_generator(input_images):
"""Simple generator function."""
return input_images * tf.get_variable('dummy_g', initializer=2.0)
def _test_discriminator(image_batch, unused_conditioning=None):
"""Simple discriminator function."""
return tf.contrib.layers.flatten(
image_batch * tf.get_variable('dummy_d', initializer=2.0))
train.networks.generator = _test_generator
train.networks.discriminator = _test_discriminator
class TrainTest(tf.test.TestCase):
@mock.patch.object(tfgan, 'eval', autospec=True)
def test_define_model(self, mock_eval):
FLAGS.batch_size = 2
images_shape = [FLAGS.batch_size, 4, 4, 3]
images_x_np = np.zeros(shape=images_shape)
images_y_np = np.zeros(shape=images_shape)
images_x = tf.constant(images_x_np, dtype=tf.float32)
images_y = tf.constant(images_y_np, dtype=tf.float32)
cyclegan_model = train._define_model(images_x, images_y)
self.assertIsInstance(cyclegan_model, tfgan.CycleGANModel)
self.assertShapeEqual(images_x_np, cyclegan_model.reconstructed_x)
self.assertShapeEqual(images_y_np, cyclegan_model.reconstructed_y)
mock_eval.add_image_comparison_summaries.assert_called_once()
mock_eval.add_gan_model_image_summaries.assert_called_once()
@mock.patch.object(train.networks, 'generator', autospec=True)
@mock.patch.object(train.networks, 'discriminator', autospec=True)
@mock.patch.object(
tf.train, 'get_or_create_global_step', autospec=True)
def test_get_lr(self, mock_get_or_create_global_step,
unused_mock_discriminator, unused_mock_generator):
FLAGS.max_number_of_steps = 10
base_lr = 0.01
with self.test_session(use_gpu=True) as sess:
mock_get_or_create_global_step.return_value = tf.constant(2)
lr_step2 = sess.run(train._get_lr(base_lr))
mock_get_or_create_global_step.return_value = tf.constant(9)
lr_step9 = sess.run(train._get_lr(base_lr))
self.assertAlmostEqual(base_lr, lr_step2)
self.assertAlmostEqual(base_lr * 0.2, lr_step9)
@mock.patch.object(tf.train, 'AdamOptimizer', autospec=True)
def test_get_optimizer(self, mock_adam_optimizer):
gen_lr, dis_lr = 0.1, 0.01
train._get_optimizer(gen_lr=gen_lr, dis_lr=dis_lr)
mock_adam_optimizer.assert_has_calls([
mock.call(gen_lr, beta1=mock.ANY, use_locking=True),
mock.call(dis_lr, beta1=mock.ANY, use_locking=True)
])
@mock.patch.object(tf.summary, 'scalar', autospec=True)
def test_define_train_ops(self, mock_summary_scalar):
FLAGS.batch_size = 2
FLAGS.generator_lr = 0.1
FLAGS.discriminator_lr = 0.01
images_shape = [FLAGS.batch_size, 4, 4, 3]
images_x = tf.zeros(images_shape, dtype=tf.float32)
images_y = tf.zeros(images_shape, dtype=tf.float32)
cyclegan_model = train._define_model(images_x, images_y)
cyclegan_loss = tfgan.cyclegan_loss(
cyclegan_model, cycle_consistency_loss_weight=10.0)
train_ops = train._define_train_ops(cyclegan_model, cyclegan_loss)
self.assertIsInstance(train_ops, tfgan.GANTrainOps)
mock_summary_scalar.assert_has_calls([
mock.call('generator_lr', mock.ANY),
mock.call('discriminator_lr', mock.ANY)
])
@mock.patch.object(tf, 'gfile', autospec=True)
@mock.patch.object(train, 'data_provider', autospec=True)
@mock.patch.object(train, '_define_model', autospec=True)
@mock.patch.object(tfgan, 'cyclegan_loss', autospec=True)
@mock.patch.object(train, '_define_train_ops', autospec=True)
@mock.patch.object(tfgan, 'gan_train', autospec=True)
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
mock_define_model, mock_data_provider, mock_gfile):
FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
FLAGS.batch_size = 3
FLAGS.patch_size = 8
FLAGS.generator_lr = 0.02
FLAGS.discriminator_lr = 0.3
FLAGS.train_log_dir = '/tmp/foo'
FLAGS.master = 'master'
FLAGS.task = 0
FLAGS.cycle_consistency_loss_weight = 2.0
FLAGS.max_number_of_steps = 1
mock_data_provider.provide_custom_datasets.return_value = (tf.zeros(
[1, 2], dtype=tf.float32), tf.zeros([1, 2], dtype=tf.float32))
train.main(None)
mock_data_provider.provide_custom_datasets.assert_called_once_with(
['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
mock_cyclegan_loss.assert_called_once_with(
mock_define_model.return_value,
cycle_consistency_loss_weight=2.0,
tensor_pool_fn=mock.ANY)
mock_define_train_ops.assert_called_once_with(
mock_define_model.return_value, mock_cyclegan_loss.return_value)
mock_gan_train.assert_called_once_with(
mock_define_train_ops.return_value,
'/tmp/foo',
get_hooks_fn=mock.ANY,
hooks=mock.ANY,
master='master',
is_chief=True)
if __name__ == '__main__':
tf.test.main()
...@@ -19,7 +19,7 @@ from __future__ import division ...@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from google3.third_party.tensorflow_models.gan.pix2pix import networks import networks
class Pix2PixTest(tf.test.TestCase): class Pix2PixTest(tf.test.TestCase):
......
...@@ -23,7 +23,7 @@ from __future__ import print_function ...@@ -23,7 +23,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
import data_provider import data_provider
from google3.third_party.tensorflow_models.gan.pix2pix import networks import networks
flags = tf.flags flags = tf.flags
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from google3.third_party.tensorflow_models.gan.pix2pix import train import train
FLAGS = tf.flags.FLAGS FLAGS = tf.flags.FLAGS
mock = tf.test.mock mock = tf.test.mock
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment