Commit b6907e8d authored by Joel Shor, committed by joel-shor

Project import generated by Copybara.

PiperOrigin-RevId: 177165761
parent 220772b5
@@ -37,26 +37,25 @@ class UtilTest(tf.test.TestCase):
num_classes=3,
num_images_per_class=1)
def test_get_inception_scores(self):
# Mock `inception_score` which is expensive.
with mock.patch.object(
util.tfgan.eval, 'inception_score') as mock_inception_score:
mock_inception_score.return_value = 1.0
util.get_inception_scores(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
def test_get_frechet_inception_distance(self):
# Mock `frechet_inception_distance` which is expensive.
with mock.patch.object(
util.tfgan.eval, 'frechet_inception_distance') as mock_fid:
mock_fid.return_value = 1.0
util.get_frechet_inception_distance(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
# Mock `inception_score` which is expensive.
@mock.patch.object(util.tfgan.eval, 'inception_score', autospec=True)
def test_get_inception_scores(self, mock_inception_score):
mock_inception_score.return_value = 1.0
util.get_inception_scores(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
# Mock `frechet_inception_distance` which is expensive.
@mock.patch.object(util.tfgan.eval, 'frechet_inception_distance',
autospec=True)
def test_get_frechet_inception_distance(self, mock_fid):
mock_fid.return_value = 1.0
util.get_frechet_inception_distance(
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
batch_size=100,
num_inception_images=10)
if __name__ == '__main__':
......
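The hunk above replaces the `with mock.patch.object(...)` context managers with `@mock.patch.object(..., autospec=True)` decorators, which inject the mock as a test-method argument and check calls against the real function's signature. A minimal standalone sketch of the pattern (the module `mymodule` and its functions are hypothetical):

import unittest
from unittest import mock

import mymodule  # Hypothetical module whose expensive `expensive_fn` we stub out.


class ExampleTest(unittest.TestCase):

  # autospec=True makes the mock enforce `expensive_fn`'s real call signature;
  # the decorator passes the mock in as `mock_expensive_fn`.
  @mock.patch.object(mymodule, 'expensive_fn', autospec=True)
  def test_without_expensive_call(self, mock_expensive_fn):
    mock_expensive_fn.return_value = 1.0
    self.assertEqual(1.0, mymodule.run_cheaply())  # Hypothetical caller of expensive_fn.


if __name__ == '__main__':
  unittest.main()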
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the compression image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(split_name, batch_size, dataset_dir,
dataset_name='imagenet', num_readers=1, num_threads=1,
patch_size=128):
"""Provides batches of image data for compression.
Args:
split_name: Either 'train' or 'validation'.
batch_size: The number of images in each batch.
dataset_dir: The directory where the data can be found. If `None`, use
default.
dataset_name: Name of the dataset.
num_readers: Number of dataset readers.
num_threads: Number of prefetching threads.
patch_size: Size of the patch to extract from the image.
Returns:
images: A `Tensor` of size [batch_size, patch_size, patch_size, channels]
"""
randomize = split_name == 'train'
dataset = datasets.get_dataset(
dataset_name, split_name, dataset_dir=dataset_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
common_queue_capacity=5 * batch_size,
common_queue_min=batch_size,
shuffle=randomize)
[image] = provider.get(['image'])
# Sample a patch of fixed size.
patch = tf.image.resize_image_with_crop_or_pad(image, patch_size, patch_size)
patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
# Preprocess the images. Make the range lie in a strictly smaller range than
# [-1, 1], so that network outputs aren't forced to the extreme ranges.
patch = (tf.to_float(patch) - 128.0) / 142.0
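# For uint8 inputs this maps 0 -> ~-0.90, 128 -> 0.0, and 255 -> ~0.89, so the
# values stay strictly inside [-1, 1].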
if randomize:
image_batch = tf.train.shuffle_batch(
[patch],
batch_size=batch_size,
num_threads=num_threads,
capacity=5 * batch_size,
min_after_dequeue=batch_size)
else:
image_batch = tf.train.batch(
[patch],
batch_size=batch_size,
num_threads=1, # no threads so it's deterministic
capacity=5 * batch_size)
return image_batch
def float_image_to_uint8(image):
"""Convert float image in ~[-0.9, 0.9) to [0, 255] uint8.
Args:
image: An image tensor. Values should be in [-0.9, 0.9).
Returns:
Input image cast to uint8 and with integer values in [0, 255].
"""
image = (image * 142.0) + 128.0
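# This inverts the provide_data normalization (values approximate): -0.9014
# maps back to ~0, 0.0 to 128, and 0.8944 to ~255 before the uint8 cast.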
return tf.cast(image, tf.uint8)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):
def _test_data_provider_helper(self, split_name):
dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/image_compression/testdata/')
batch_size = 3
patch_size = 8
images = data_provider.provide_data(
split_name, batch_size, dataset_dir, patch_size=8)
self.assertListEqual([batch_size, patch_size, patch_size, 3],
images.shape.as_list())
with self.test_session(use_gpu=True) as sess:
with tf.contrib.slim.queues.QueueRunners(sess):
images_out = sess.run(images)
self.assertEqual((batch_size, patch_size, patch_size, 3),
images_out.shape)
# Check range.
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_data_provider_train(self):
self._test_data_provider_helper('train')
def test_data_provider_validation(self):
self._test_data_provider_helper('validation')
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained compression model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
import summaries
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/compression/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/compression/',
'Directory where the results are saved to.')
flags.DEFINE_integer('max_number_of_evaluations', None,
'Number of times to run evaluation. If `None`, run '
'forever.')
flags.DEFINE_string('dataset_dir', None, 'Location of data.')
# Compression-specific flags.
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')
flags.DEFINE_integer('bits_per_patch', 1230,
'The number of bits to produce per patch.')
flags.DEFINE_integer('model_depth', 64,
'Number of filters for compression model')
def main(_, run_eval_loop=True):
with tf.name_scope('inputs'):
images = data_provider.provide_data(
'validation', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
patch_size=FLAGS.patch_size)
# In order for variables to load, use the same variable scope as in the
# train job.
with tf.variable_scope('generator'):
reconstructions, _, prebinary = networks.compression_model(
images,
num_bits=FLAGS.bits_per_patch,
depth=FLAGS.model_depth,
is_training=False)
summaries.add_reconstruction_summaries(images, reconstructions, prebinary)
# Visualize losses.
pixel_loss_per_example = tf.reduce_mean(
tf.abs(images - reconstructions), axis=[1, 2, 3])
pixel_loss = tf.reduce_mean(pixel_loss_per_example)
tf.summary.histogram('pixel_l1_loss_hist', pixel_loss_per_example)
tf.summary.scalar('pixel_l1_loss', pixel_loss)
# Create ops to write images to disk.
uint8_images = data_provider.float_image_to_uint8(images)
uint8_reconstructions = data_provider.float_image_to_uint8(reconstructions)
uint8_reshaped = summaries.stack_images(uint8_images, uint8_reconstructions)
image_write_ops = tf.write_file(
'%s/%s' % (FLAGS.eval_dir, 'compression.png'),
tf.image.encode_png(uint8_reshaped[0]))
# For unit testing, use `run_eval_loop=False`.
if not run_eval_loop: return
tf.contrib.training.evaluate_repeatedly(
FLAGS.checkpoint_dir,
master=FLAGS.master,
hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
tf.contrib.training.StopAfterNEvalsHook(1)],
eval_ops=image_write_ops,
max_number_of_evaluations=FLAGS.max_number_of_evaluations)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import eval # pylint:disable=redefined-builtin
class EvalTest(tf.test.TestCase):
def test_build_graph(self):
eval.main(None, run_eval_loop=False)
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Imagenet dataset.
# 2. Trains image compression model on patches from Imagenet.
# 3. Evaluates the models and writes sample images to disk.
#
# Usage:
# cd models/research/gan/image_compression
# ./launch_jobs.sh ${weight_factor} ${git_repo}
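# e.g. (illustrative values; the weight factor matches train.py's default):
# ./launch_jobs.sh 10000 /path/to/models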
set -e
# Weight of the adversarial loss.
weight_factor=$1
if [[ "$weight_factor" == "" ]]; then
echo "'weight_factor' must not be empty."
exit 1
fi
# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
echo "'git_repo' must not be empty."
exit 1
fi
# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/compression-model
# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/compression-model/eval
# Where the dataset is saved to.
DATASET_DIR=/tmp/imagenet-data
export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/slim:$git_repo/research/slim/nets
# A helper function for printing pretty output.
Banner () {
local text=$1
local green='\033[0;32m'
local nc='\033[0m' # No color.
echo -e "${green}${text}${nc}"
}
# Download the dataset.
"${git_repo}/research/slim/datasets/download_and_convert_imagenet.sh" ${DATASET_DIR}
# Run the compression model.
NUM_STEPS=10000
MODEL_TRAIN_DIR="${TRAIN_DIR}/wt${weight_factor}"
Banner "Starting training an image compression model for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/image_compression/train.py" \
--train_log_dir=${MODEL_TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--max_number_of_steps=${NUM_STEPS} \
--weight_factor=${weight_factor} \
--alsologtostderr
Banner "Finished training image compression model ${NUM_STEPS} steps."
# Run evaluation.
MODEL_EVAL_DIR="${TRAIN_DIR}/eval/wt${weight_factor}"
Banner "Starting evaluation of image compression model..."
python "${git_repo}/research/gan/image_compression/eval.py" \
--checkpoint_dir=${MODEL_TRAIN_DIR} \
--eval_dir=${MODEL_EVAL_DIR} \
--dataset_dir=${DATASET_DIR} \
--max_number_of_evaluations=1
Banner "Finished evaluation. See ${MODEL_EVAL_DIR} for output images."
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN compression example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.nets import dcgan
from slim.nets import pix2pix
def _last_conv_layer(end_points):
""""Returns the last convolutional layer from an endpoints dictionary."""
conv_list = [k if k[:4] == 'conv' else None for k in end_points.keys()]
conv_list.sort()
return end_points[conv_list[-1]]
def _encoder(img_batch, is_training=True, bits=64, depth=64):
"""Maps images to internal representation.
Args:
img_batch: A batch of images to encode. Shape is [batch, height, width, channels].
is_training: A python bool. If `False`, build the encoder in evaluation mode.
bits: Number of bits per patch.
depth: The base number of filters for the encoder network.
Returns:
Real-valued 2D Tensor of size [batch_size, bits].
"""
_, end_points = dcgan.discriminator(
img_batch, depth=depth, is_training=is_training, scope='Encoder')
# TODO(joelshor): Make the DCGAN convolutional layer that converts to logits
# not trainable, since it doesn't affect the encoder output.
# Get the pre-logit layer, which is the last conv.
net = _last_conv_layer(end_points)
# Transform the features to the proper number of bits.
with tf.variable_scope('EncoderTransformer'):
encoded = tf.contrib.layers.conv2d(net, bits, kernel_size=1, stride=1,
padding='VALID', normalizer_fn=None,
activation_fn=None)
encoded = tf.squeeze(encoded, [1, 2])
encoded.shape.assert_has_rank(2)
# Map encoded to the range [-1, 1].
return tf.nn.softsign(encoded)
def _binarizer(prebinary_codes, is_training):
"""Binarize compression logits.
During training, add noise, as in https://arxiv.org/pdf/1611.01704.pdf. During
eval, map [-1, 1] -> {-1, 1}.
Args:
prebinary_codes: Floating-point tensors corresponding to pre-binary codes.
Shape is [batch, code_length].
is_training: A python bool. If True, add noise. If false, binarize.
Returns:
Binarized codes. Shape is [batch, code_length].
Raises:
ValueError: If the shape of `prebinary_codes` isn't static.
"""
if is_training:
# In order to train codes that can be binarized during eval, we add noise as
# in https://arxiv.org/pdf/1611.01704.pdf. Another option is to use a
# stochastic node, as in https://arxiv.org/abs/1608.05148.
noise = tf.random_uniform(
prebinary_codes.shape,
minval=-1.0,
maxval=1.0)
return prebinary_codes + noise
else:
return tf.sign(prebinary_codes)
def _decoder(codes, final_size, is_training, depth=64):
"""Compression decoder."""
decoded_img, _ = dcgan.generator(
codes,
depth=depth,
final_size=final_size,
num_outputs=3,
is_training=is_training,
scope='Decoder')
# Map output to [-1, 1].
# Use softsign instead of tanh, as per empirical results of
# http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
return tf.nn.softsign(decoded_img)
def _validate_image_inputs(image_batch):
image_batch.shape.assert_has_rank(4)
image_batch.shape[1:].assert_is_fully_defined()
def compression_model(image_batch, num_bits=64, depth=64, is_training=True):
"""Image compression model.
Args:
image_batch: A batch of images to compress and reconstruct. Images should
be normalized already. Shape is [batch, height, width, channels].
num_bits: Desired number of bits per image in the compressed representation.
depth: The base number of filters for the encoder and decoder networks.
is_training: A python bool. If False, run in evaluation mode.
Returns:
uncompressed images, binary codes, prebinary codes
"""
image_batch = tf.convert_to_tensor(image_batch)
_validate_image_inputs(image_batch)
final_size = image_batch.shape.as_list()[1]
prebinary_codes = _encoder(image_batch, is_training, num_bits, depth)
binary_codes = _binarizer(prebinary_codes, is_training)
uncompressed_imgs = _decoder(binary_codes, final_size, is_training, depth)
return uncompressed_imgs, binary_codes, prebinary_codes
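# Usage sketch (shapes follow the docstring; the batch, patch, and bit sizes
# below are illustrative):
#   imgs = tf.random_uniform([8, 32, 32, 3], minval=-0.9, maxval=0.9)
#   reconstructions, codes, prebinary = compression_model(imgs, num_bits=128)
#   # reconstructions: [8, 32, 32, 3]; codes and prebinary: [8, 128].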
def discriminator(image_batch, unused_conditioning=None, depth=64):
"""A thin wrapper around the pix2pix discriminator to conform to TFGAN API."""
logits, _ = pix2pix.pix2pix_discriminator(
image_batch, num_filters=[depth, 2 * depth, 4 * depth, 8 * depth])
return tf.layers.flatten(logits)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import networks
class NetworksTest(tf.test.TestCase):
def test_last_conv_layer(self):
x = tf.constant(1.0)
y = tf.constant(0.0)
end_points = {
'silly': y,
'conv2': y,
'conv4': x,
'logits': y,
'conv-1': y,
}
self.assertEqual(x, networks._last_conv_layer(end_points))
def test_generator_run(self):
img_batch = tf.zeros([3, 16, 16, 3])
model_output = networks.compression_model(img_batch)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(model_output)
def test_generator_graph(self):
for i, batch_size in zip(xrange(3, 7), xrange(3, 11, 2)):
tf.reset_default_graph()
patch_size = 2 ** i
bits = 2 ** i
img = tf.ones([batch_size, patch_size, patch_size, 3])
uncompressed, binary_codes, prebinary = networks.compression_model(
img, bits)
self.assertAllEqual([batch_size, patch_size, patch_size, 3],
uncompressed.shape.as_list())
self.assertEqual([batch_size, bits], binary_codes.shape.as_list())
self.assertEqual([batch_size, bits], prebinary.shape.as_list())
def test_generator_invalid_input(self):
wrong_dim_input = tf.zeros([5, 32, 32])
with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 4'):
networks.compression_model(wrong_dim_input)
not_fully_defined = tf.placeholder(tf.float32, [3, None, 32, 3])
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
networks.compression_model(not_fully_defined)
def test_discriminator_run(self):
img_batch = tf.zeros([3, 70, 70, 3])
disc_output = networks.discriminator(img_batch)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(disc_output)
def test_discriminator_graph(self):
# Check graph construction for a number of image size/depths and batch
# sizes.
for batch_size, patch_size in zip([3, 6], [70, 128]):
tf.reset_default_graph()
img = tf.ones([batch_size, patch_size, patch_size, 3])
disc_output = networks.discriminator(img)
self.assertEqual(2, disc_output.shape.ndims)
self.assertEqual(batch_size, disc_output.shape[0])
def test_discriminator_invalid_input(self):
wrong_dim_input = tf.zeros([5, 32, 32])
with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
networks.discriminator(wrong_dim_input)
not_fully_defined = tf.placeholder(tf.float32, [3, None, 32, 3])
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
networks.compression_model(not_fully_defined)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Summaries utility file to share between train and eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
tfgan = tf.contrib.gan
def add_reconstruction_summaries(images, reconstructions, prebinary,
num_imgs_to_visualize=8):
"""Adds image summaries."""
reshaped_img = stack_images(images, reconstructions, num_imgs_to_visualize)
tf.summary.image('real_vs_reconstruction', reshaped_img, max_outputs=1)
if prebinary is not None:
tf.summary.histogram('prebinary_codes', prebinary)
def stack_images(images, reconstructions, num_imgs_to_visualize=8):
"""Stack and reshape images to see compression effects."""
to_reshape = (tf.unstack(images)[:num_imgs_to_visualize] +
tf.unstack(reconstructions)[:num_imgs_to_visualize])
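# With the default num_imgs_to_visualize=8, the first 8 real images and the
# first 8 reconstructions are tiled into a single 8-column grid summary image.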
reshaped_img = tfgan.eval.image_reshaper(
to_reshape, num_cols=num_imgs_to_visualize)
return reshaped_img
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint:disable=line-too-long
"""Trains an image compression network with an adversarial loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
import networks
import summaries
tfgan = tf.contrib.gan
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')
flags.DEFINE_integer('bits_per_patch', 1230,
'The number of bits to produce per patch.')
flags.DEFINE_integer('model_depth', 64,
'Number of filters for compression model')
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_string('train_log_dir', '/tmp/compression/',
'Directory where to write event logs.')
flags.DEFINE_float('generator_lr', 1e-5,
'The compression model learning rate.')
flags.DEFINE_float('discriminator_lr', 1e-6,
'The discriminator learning rate.')
flags.DEFINE_integer('max_number_of_steps', 2000000,
'The maximum number of gradient steps.')
flags.DEFINE_integer(
'ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
flags.DEFINE_integer(
'task', 0,
'The Task ID. This value is used when training with multiple workers to '
'identify each worker.')
flags.DEFINE_float(
'weight_factor', 10000.0,
'How much to weight the adversarial loss relative to pixel loss.')
flags.DEFINE_string('dataset_dir', None, 'Location of data.')
def main(_):
if not tf.gfile.Exists(FLAGS.train_log_dir):
tf.gfile.MakeDirs(FLAGS.train_log_dir)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
# Put input pipeline on CPU to reserve GPU for training.
with tf.name_scope('inputs'), tf.device('/cpu:0'):
images = data_provider.provide_data(
'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
patch_size=FLAGS.patch_size)
# Manually define a GANModel tuple. This is useful when we have custom
# code to track variables. Note that we could replace all of this with a
# call to `tfgan.gan_model`, but we don't in order to demonstrate some of
# TFGAN's flexibility.
with tf.variable_scope('generator') as gen_scope:
reconstructions, _, prebinary = networks.compression_model(
images,
num_bits=FLAGS.bits_per_patch,
depth=FLAGS.model_depth)
gan_model = _get_gan_model(
generator_inputs=images,
generated_data=reconstructions,
real_data=images,
generator_scope=gen_scope)
summaries.add_reconstruction_summaries(images, reconstructions, prebinary)
tfgan.eval.add_gan_model_summaries(gan_model)
# Define the GANLoss tuple using standard library functions.
with tf.name_scope('loss'):
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.least_squares_generator_loss,
discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss,
add_summaries=FLAGS.weight_factor > 0)
# Define the standard pixel loss.
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
ord=1)
# Modify the loss tuple to include the pixel loss. Add summaries as well.
gan_loss = tfgan.losses.combine_adversarial_loss(
gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)
# Get the GANTrain ops using the custom optimizers and optional
# discriminator weight clipping.
with tf.name_scope('train_ops'):
gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=gen_opt,
discriminator_optimizer=dis_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
tf.summary.scalar('generator_lr', gen_lr)
tf.summary.scalar('discriminator_lr', dis_lr)
# Determine the number of generator vs discriminator steps.
train_steps = tfgan.GANTrainSteps(
generator_train_steps=1,
discriminator_train_steps=int(FLAGS.weight_factor > 0))
# Run the alternating training loop. Skip it if no steps should be taken
# (used for graph construction tests).
status_message = tf.string_join(
['Starting train step: ',
tf.as_string(tf.train.get_or_create_global_step())],
name='status_message')
if FLAGS.max_number_of_steps == 0: return
tfgan.gan_train(
train_ops,
FLAGS.train_log_dir,
tfgan.get_sequential_train_hooks(train_steps),
hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
master=FLAGS.master,
is_chief=FLAGS.task == 0)
def _optimizer(gen_lr, dis_lr):
# First is generator optimizer, second is discriminator.
adam_kwargs = {
'epsilon': 1e-8,
'beta1': 0.5,
}
return (tf.train.AdamOptimizer(gen_lr, **adam_kwargs),
tf.train.AdamOptimizer(dis_lr, **adam_kwargs))
def _lr(gen_lr_base, dis_lr_base):
"""Return the generator and discriminator learning rates."""
gen_lr_kwargs = {
'decay_steps': 60000,
'decay_rate': 0.9,
'staircase': True,
}
gen_lr = tf.train.exponential_decay(
learning_rate=gen_lr_base,
global_step=tf.train.get_or_create_global_step(),
**gen_lr_kwargs)
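# With staircase=True this is gen_lr_base * 0.9 ** floor(global_step / 60000),
# e.g. the default 1e-5 drops to ~9e-6 after 60k steps and ~8.1e-6 after 120k.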
dis_lr = dis_lr_base
return gen_lr, dis_lr
def _get_gan_model(generator_inputs, generated_data, real_data,
generator_scope):
"""Manually construct and return a GANModel tuple."""
generator_vars = tf.contrib.framework.get_trainable_variables(generator_scope)
discriminator_fn = networks.discriminator
with tf.variable_scope('discriminator') as dis_scope:
discriminator_gen_outputs = discriminator_fn(generated_data)
with tf.variable_scope(dis_scope, reuse=True):
discriminator_real_outputs = discriminator_fn(real_data)
discriminator_vars = tf.contrib.framework.get_trainable_variables(
dis_scope)
# Manually construct GANModel tuple.
gan_model = tfgan.GANModel(
generator_inputs=generator_inputs,
generated_data=generated_data,
generator_variables=generator_vars,
generator_scope=generator_scope,
generator_fn=None, # not necessary
real_data=real_data,
discriminator_real_outputs=discriminator_real_outputs,
discriminator_gen_outputs=discriminator_gen_outputs,
discriminator_variables=discriminator_vars,
discriminator_scope=dis_scope,
discriminator_fn=discriminator_fn)
return gan_model
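# For comparison (a sketch), the library call referenced in main() would be
# roughly:
#   gan_model = tfgan.gan_model(
#       generator_fn=...,  # would rebuild the generator internally
#       discriminator_fn=networks.discriminator,
#       real_data=real_data,
#       generator_inputs=generator_inputs)
# The manual tuple above is used instead so the already-built reconstructions
# and their variable scope are reused.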
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for image_compression.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
def _test_build_graph_helper(self, weight_factor):
FLAGS.max_number_of_steps = 0
FLAGS.weight_factor = weight_factor
batch_size = 3
patch_size = 16
FLAGS.batch_size = batch_size
FLAGS.patch_size = patch_size
mock_imgs = np.zeros([batch_size, patch_size, patch_size, 3],
dtype=np.float32)
with mock.patch.object(train, 'data_provider') as mock_data_provider:
mock_data_provider.provide_data.return_value = mock_imgs
train.main(None)
def test_build_graph_noadversarialloss(self):
self._test_build_graph_helper(0.0)
def test_build_graph_adversarialloss(self):
self._test_build_graph_helper(1.0)
if __name__ == '__main__':
tf.test.main()
@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from google3.third_party.tensorflow_models.gan.mnist import eval # pylint:disable=redefined-builtin
import eval # pylint:disable=redefined-builtin
class EvalTest(tf.test.TestCase):
......
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.mnist.train."""
"""Tests for mnist.train."""
from __future__ import absolute_import
from __future__ import division
@@ -30,7 +30,8 @@ mock = tf.test.mock
class TrainTest(tf.test.TestCase):
def test_run_one_train_step(self):
@mock.patch.object(train, 'data_provider', autospec=True)
def test_run_one_train_step(self, mock_data_provider):
FLAGS.max_number_of_steps = 1
FLAGS.gan_type = 'unconditional'
FLAGS.batch_size = 5
@@ -42,10 +43,9 @@ class TrainTest(tf.test.TestCase):
mock_lbls = np.concatenate(
(np.ones([FLAGS.batch_size, 1], dtype=np.int32),
np.zeros([FLAGS.batch_size, 9], dtype=np.int32)), axis=1)
with mock.patch.object(train, 'data_provider') as mock_data_provider:
mock_data_provider.provide_data.return_value = (
mock_imgs, mock_lbls, None)
train.main(None)
mock_data_provider.provide_data.return_value = (mock_imgs, mock_lbls, None)
train.main(None)
def _test_build_graph_helper(self, gan_type):
FLAGS.max_number_of_steps = 0
......
@@ -31,7 +31,8 @@ mock = tf.test.mock
class TrainTest(tf.test.TestCase):
def test_full_flow(self):
@mock.patch.object(train, 'data_provider', autospec=True)
def test_full_flow(self, mock_data_provider):
FLAGS.eval_dir = self.get_temp_dir()
FLAGS.batch_size = 16
FLAGS.max_number_of_steps = 2
@@ -42,10 +43,9 @@ class TrainTest(tf.test.TestCase):
mock_lbls = np.concatenate(
(np.ones([FLAGS.batch_size, 1], dtype=np.int32),
np.zeros([FLAGS.batch_size, 9], dtype=np.int32)), axis=1)
with mock.patch.object(train, 'data_provider') as mock_data_provider:
mock_data_provider.provide_data.return_value = (
mock_imgs, mock_lbls, None)
train.main(None)
mock_data_provider.provide_data.return_value = (mock_imgs, mock_lbls, None)
train.main(None)
if __name__ == '__main__':
......
@@ -23,48 +23,27 @@
"metadata": {},
"source": [
"## Table of Contents\n",
"\n",
"<a href=#installation_and_setup>Installation and Setup</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#download_data>Download Data</a>\n",
"\n",
"<a href=#unconditional_example>Unconditional GAN example</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#unconditional_input>Input pipeline</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#unconditional_model>Model</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#unconditional_loss>Loss</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#unconditional_train>Train and evaluation</a>\n",
"\n",
"<a href=#ganestimator_example>GANEstimator example</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#ganestimator_input>Input pipeline</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#ganestimator_train>Train</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#ganestimator_eval>Eval</a>\n",
"\n",
"<a href=#conditional_example>Conditional GAN example</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#conditional_input>Input pipeline</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#conditional_model>Model</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#conditional_loss>Loss</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#conditional_train>Train and evaluation</a>\n",
"\n",
"<a href=#infogan_example>InfoGAN example</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#infogan_input>Input pipeline</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#infogan_model>Model</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#infogan_loss>Loss</a>\n",
"\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href=#infogan_train>Train and evaluation</a>"
"<a href='#installation_and_setup'>Installation and Setup</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#download_data'>Download Data</a><br>\n",
"<a href='#unconditional_example'>Unconditional GAN example</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#unconditional_input'>Input pipeline</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#unconditional_model'>Model</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#unconditional_loss'>Loss</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#unconditional_train'>Train and evaluation</a><br>\n",
"<a href='#ganestimator_example'>GANEstimator example</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#ganestimator_input'>Input pipeline</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#ganestimator_train'>Train</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#ganestimator_eval'>Eval</a><br>\n",
"<a href='#conditional_example'>Conditional GAN example</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#conditional_input'>Input pipeline</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#conditional_model'>Model</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#conditional_loss'>Loss</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#conditional_train'>Train and evaluation</a><br>\n",
"<a href='#infogan_example'>InfoGAN example</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#infogan_input'>Input pipeline</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#infogan_model'>Model</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#infogan_loss'>Loss</a><br>\n",
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='#infogan_train'>Train and evaluation</a><br>"
]
},
{