Commit fe6c3819 authored by Joel Shor's avatar Joel Shor Committed by joel-shor
Browse files

Project import generated by Copybara.

PiperOrigin-RevId: 176969064
parent b6907e8d
......@@ -57,8 +57,10 @@ Banner () {
echo -e "${green}${text}${nc}"
}
# Download the dataset.
"${git_repo}/research/slim/datasets/download_and_convert_imagenet.sh" ${DATASET_DIR}
# Download the dataset. You will be asked for an ImageNet username and password.
# To get one, register at http://www.image-net.org/.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
"./bazel-bin/download_and_convert_imagenet" ${DATASET_DIR}
# Run the compression model.
NUM_STEPS=10000
......
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Imagenet dataset.
# 2. Trains image compression model on patches from Imagenet.
# 3. Evaluates the models and writes sample images to disk.
#
# Usage:
# cd models/research/gan/image_compression
# ./launch_jobs.sh ${weight_factor} ${git_repo}
set -e

# Weight of the adversarial loss.
weight_factor=$1
if [[ "$weight_factor" == "" ]]; then
  echo "'weight_factor' must not be empty."
  exit 1  # Non-zero status so callers (and `set -e` chains) see the failure.
fi

# Location of the git repository.
git_repo=$2
if [[ "$git_repo" == "" ]]; then
  echo "'git_repo' must not be empty."
  exit 1
fi

# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/compression-model

# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/compression-model/eval

# Where the dataset is saved to.
DATASET_DIR=/tmp/imagenet-data

export PYTHONPATH=$PYTHONPATH:$git_repo:$git_repo/research:$git_repo/research/slim:$git_repo/research/slim/nets

# A helper function for printing pretty output.
Banner () {
  local text=$1
  local green='\033[0;32m'
  local nc='\033[0m'  # No color.
  echo -e "${green}${text}${nc}"
}

# Download the dataset.
bazel build "${git_repo}/research/slim:download_and_convert_imagenet"
# Quote the path in case DATASET_DIR ever contains spaces.
"./bazel-bin/download_and_convert_imagenet" "${DATASET_DIR}"

# Run the image compression model.
NUM_STEPS=10000
MODEL_TRAIN_DIR="${TRAIN_DIR}/wt${weight_factor}"
Banner "Starting training an image compression model for ${NUM_STEPS} steps..."
python "${git_repo}/research/gan/image_compression/train.py" \
  --train_log_dir="${MODEL_TRAIN_DIR}" \
  --dataset_dir="${DATASET_DIR}" \
  --max_number_of_steps=${NUM_STEPS} \
  --weight_factor=${weight_factor} \
  --alsologtostderr
Banner "Finished training the image compression model for ${NUM_STEPS} steps."
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN Pix2Pix example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from slim.nets import cyclegan
from slim.nets import pix2pix
def generator(input_images):
  """Translate a batch of images with the CycleGAN resnet generator.

  Thin wrapper that adapts `cyclegan.cyclegan_generator_resnet` to the
  TFGAN generator API.

  Args:
    input_images: A batch of images to translate. Images should be
      normalized already. Shape is [batch, height, width, channels].

  Returns:
    The generated image batch.

  Raises:
    ValueError: If `input_images` does not have rank 4.
  """
  input_images.shape.assert_has_rank(4)
  generator_scope = tf.contrib.framework.arg_scope(
      cyclegan.cyclegan_arg_scope())
  with generator_scope:
    # The second return value (endpoints) is not needed by TFGAN.
    generated, _ = cyclegan.cyclegan_generator_resnet(input_images)
  return generated
def discriminator(image_batch, unused_conditioning=None):
  """A thin wrapper around the Pix2Pix discriminator to conform to TFGAN API."""
  with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
    raw_logits, _ = pix2pix.pix2pix_discriminator(
        image_batch, num_filters=[64, 128, 256, 512])
  raw_logits.shape.assert_has_rank(4)
  # The patch discriminator emits a 4-D logit map; TFGAN expects 2-D
  # [batch, ?] logits, so collapse the spatial dimensions.
  return tf.contrib.layers.flatten(raw_logits)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.examples.networks.networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from google3.third_party.tensorflow_models.gan.pix2pix import networks
class Pix2PixTest(tf.test.TestCase):
  """Tests for the pix2pix generator/discriminator TFGAN wrappers."""

  def test_generator_run(self):
    # The generator graph should execute end-to-end on a zero batch.
    images = tf.zeros([3, 128, 128, 3])
    generated = networks.generator(images)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(generated)

  def test_generator_graph(self):
    # Output shape must match input shape for several batch/spatial sizes.
    for dims in ([4, 32, 32], [3, 128, 128], [2, 80, 400]):
      tf.reset_default_graph()
      generated = networks.generator(tf.ones(dims + [3]))
      self.assertAllEqual(dims + [3], generated.shape.as_list())

  def test_generator_graph_unknown_batch_dim(self):
    # An unknown batch dimension should propagate through the generator.
    placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    generated = networks.generator(placeholder)
    self.assertAllEqual([None, 32, 32, 3], generated.shape.as_list())

  def test_generator_invalid_input(self):
    # Rank-3 input must be rejected by the generator's rank assertion.
    with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
      networks.generator(tf.zeros([28, 28, 3]))

  def test_discriminator_run(self):
    # The discriminator graph should execute end-to-end on a zero batch.
    images = tf.zeros([3, 70, 70, 3])
    logits = networks.discriminator(images)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(logits)

  def test_discriminator_graph(self):
    # Check graph construction for a number of image size/depths and batch
    # sizes.
    for batch_size, patch_size in zip([3, 6], [70, 128]):
      tf.reset_default_graph()
      logits = networks.discriminator(
          tf.ones([batch_size, patch_size, patch_size, 3]))
      self.assertEqual(2, logits.shape.ndims)
      self.assertEqual(batch_size, logits.shape.as_list()[0])

  def test_discriminator_invalid_input(self):
    # Rank-3 input must be rejected when building the discriminator graph.
    with self.assertRaisesRegexp(ValueError, 'Shape must be rank 4'):
      networks.discriminator(tf.zeros([28, 28, 3]))
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains an image-to-image translation network with an adversarial loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import data_provider
from google3.third_party.tensorflow_models.gan.pix2pix import networks
# Shorthand aliases used throughout this file.
flags = tf.flags
tfgan = tf.contrib.gan

# --- Command-line flags -----------------------------------------------------
flags.DEFINE_integer('batch_size', 10, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('train_log_dir', '/tmp/pix2pix/',
                    'Directory where to write event logs.')

flags.DEFINE_float('generator_lr', 0.00001,
                   'The compression model learning rate.')

flags.DEFINE_float('discriminator_lr', 0.00001,
                   'The discriminator learning rate.')

flags.DEFINE_integer('max_number_of_steps', 2000000,
                     'The maximum number of gradient steps.')

flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

# weight_factor == 0 disables the adversarial loss entirely (see main()).
flags.DEFINE_float(
    'weight_factor', 0.0,
    'How much to weight the adversarial loss relative to pixel loss.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

FLAGS = flags.FLAGS
def main(_):
  """Builds the pix2pix training graph and runs the training loop.

  Reads all configuration from FLAGS. When FLAGS.max_number_of_steps is 0,
  the graph is constructed and the function returns without training (used
  by the graph-construction tests).
  """
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  # Place variables on parameter servers when ps_tasks > 0; otherwise local.
  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    # Get real and distorted images. Input pipeline is pinned to the CPU.
    with tf.device('/cpu:0'), tf.name_scope('inputs'):
      real_images = data_provider.provide_data(
          'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
          patch_size=FLAGS.patch_size)
      # Distorted images are the generator inputs: each real image is
      # downscaled by 2x then upscaled back, losing detail.
      distorted_images = _distort_images(
          real_images, downscale_size=int(FLAGS.patch_size / 2),
          upscale_size=FLAGS.patch_size)

    # Create a GANModel tuple.
    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=real_images,
        generator_inputs=distorted_images)
    tfgan.eval.add_image_comparison_summaries(
        gan_model, num_comparisons=3, display_diffs=True)
    tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=3)

    # Define the GANLoss tuple using standard library functions.
    with tf.name_scope('losses'):
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.least_squares_generator_loss,
          discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)

      # Define the standard L1 pixel loss, normalized by the patch area.
      l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data,
                              ord=1) / FLAGS.patch_size ** 2

      # Modify the loss tuple to include the pixel loss. Add summaries as well.
      gan_loss = tfgan.losses.combine_adversarial_loss(
          gan_loss, gan_model, l1_pixel_loss,
          weight_factor=FLAGS.weight_factor)

    with tf.name_scope('train_ops'):
      # Get the GANTrain ops using the custom optimizers and optional
      # discriminator weight clipping.
      gen_lr, dis_lr = _lr(FLAGS.generator_lr, FLAGS.discriminator_lr)
      gen_opt, dis_opt = _optimizer(gen_lr, dis_lr)
      train_ops = tfgan.gan_train_ops(
          gan_model,
          gan_loss,
          generator_optimizer=gen_opt,
          discriminator_optimizer=dis_opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True,
          aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N,
          transform_grads_fn=tf.contrib.training.clip_gradient_norms_fn(1e3))
      tf.summary.scalar('generator_lr', gen_lr)
      tf.summary.scalar('discriminator_lr', dis_lr)

    # Use GAN train step function if using adversarial loss, otherwise
    # only train the generator (discriminator_train_steps becomes 0).
    train_steps = tfgan.GANTrainSteps(
        generator_train_steps=1,
        discriminator_train_steps=int(FLAGS.weight_factor > 0))

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    if FLAGS.max_number_of_steps == 0: return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
               tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def _optimizer(gen_lr, dis_lr):
  """Create Adam optimizers for the generator and discriminator.

  Args:
    gen_lr: Generator learning rate (scalar or tensor).
    dis_lr: Discriminator learning rate (scalar or tensor).

  Returns:
    A (generator_optimizer, discriminator_optimizer) tuple.
  """
  # Both optimizers share the same Adam moment-decay settings.
  adam_kwargs = dict(beta1=0.5, beta2=0.999)
  return (tf.train.AdamOptimizer(gen_lr, **adam_kwargs),
          tf.train.AdamOptimizer(dis_lr, **adam_kwargs))
def _lr(gen_lr_base, dis_lr_base):
  """Return the generator and discriminator learning rates.

  The generator rate decays exponentially with the global step (staircase,
  x0.8 every 100k steps); the discriminator rate stays constant.
  """
  decayed_gen_lr = tf.train.exponential_decay(
      learning_rate=gen_lr_base,
      global_step=tf.train.get_or_create_global_step(),
      decay_steps=100000,
      decay_rate=0.8,
      staircase=True)
  return decayed_gen_lr, dis_lr_base
def _distort_images(images, downscale_size, upscale_size):
  """Degrade a batch of images by downscaling then upscaling them.

  Uses area resampling both ways, so fine detail is lost and the result is
  a blurry version of the input at `upscale_size` resolution.
  """
  shrunk = tf.image.resize_area(images, [downscale_size, downscale_size])
  return tf.image.resize_area(shrunk, [upscale_size, upscale_size])
# Parse flags and dispatch to main() when executed as a script.
if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pix2pix.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from google3.third_party.tensorflow_models.gan.pix2pix import train
FLAGS = tf.flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
  """Graph-construction tests for pix2pix train.main."""

  def _test_build_graph_helper(self, weight_factor):
    # With max_number_of_steps=0, main() only builds the graph and returns.
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    fake_batch = np.zeros(
        [FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3],
        dtype=np.float32)
    # Stub out the data provider so no dataset files are required.
    with mock.patch.object(train, 'data_provider') as mock_data_provider:
      mock_data_provider.provide_data.return_value = fake_batch
      train.main(None)

  def test_build_graph_noadversarialloss(self):
    self._test_build_graph_helper(0.0)

  def test_build_graph_adversarialloss(self):
    self._test_build_graph_helper(1.0)
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment