Commit 4f7074f6 authored by Joel Shor's avatar Joel Shor Committed by Joel Shor
Browse files

Project import generated by Copybara.

PiperOrigin-RevId: 199251174
parent 30cf3752
......@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
......@@ -31,7 +32,7 @@ class DataProviderTest(tf.test.TestCase):
def test_cifar10_train_set(self):
dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir,
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cifar/testdata')
batch_size = 4
......
......@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
import data_provider
......@@ -25,8 +27,7 @@ import networks
import util
flags = tf.flags
FLAGS = tf.flags.FLAGS
FLAGS = flags.FLAGS
tfgan = tf.contrib.gan
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
......@@ -155,4 +156,4 @@ def _get_generated_data(num_images_generated, conditional_eval, num_classes):
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -19,16 +19,22 @@ from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
import eval # pylint:disable=redefined-builtin
FLAGS = tf.flags.FLAGS
FLAGS = flags.FLAGS
mock = tf.test.mock
class EvalTest(tf.test.TestCase):
class EvalTest(tf.test.TestCase, parameterized.TestCase):
def _test_build_graph_helper(self, eval_real_images, conditional_eval):
@parameterized.named_parameters(
('RealData', True, False),
('GeneratedData', False, False),
('GeneratedDataConditional', False, True))
def test_build_graph(self, eval_real_images, conditional_eval):
FLAGS.eval_real_images = eval_real_images
FLAGS.conditional_eval = conditional_eval
# Mock `frechet_inception_distance` and `inception_score`, which are
......@@ -40,15 +46,6 @@ class EvalTest(tf.test.TestCase):
mock_iscore.return_value = 1.0
eval.main(None, run_eval_loop=False)
def test_build_graph_realdata(self):
self._test_build_graph_helper(True, False)
def test_build_graph_generateddata(self):
self._test_build_graph_helper(False, False)
def test_build_graph_generateddataconditional(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__':
tf.test.main()
......@@ -18,6 +18,8 @@ from __future__ import division
from __future__ import print_function
from absl import flags
from absl import logging
import tensorflow as tf
import data_provider
......@@ -25,7 +27,6 @@ import networks
tfgan = tf.contrib.gan
flags = tf.flags
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
......@@ -173,6 +174,6 @@ def _optimizer(gen_lr, dis_lr, use_sync_replicas):
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
logging.set_verbosity(logging.INFO)
tf.app.run()
......@@ -19,17 +19,23 @@ from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
FLAGS = flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def _test_build_graph_helper(self, conditional, use_sync_replicas):
@parameterized.named_parameters(
('Unconditional', False, False),
('Conditional', True, False),
('SyncReplicas', False, True))
def test_build_graph_helper(self, conditional, use_sync_replicas):
FLAGS.max_number_of_steps = 0
FLAGS.conditional = conditional
FLAGS.use_sync_replicas = use_sync_replicas
......@@ -45,14 +51,6 @@ class TrainTest(tf.test.TestCase):
mock_imgs, mock_lbls, None, None)
train.main(None)
def test_build_graph_unconditional(self):
self._test_build_graph_helper(False, False)
def test_build_graph_conditional(self):
self._test_build_graph_helper(True, False)
def test_build_graph_syncreplicas(self):
self._test_build_graph_helper(False, True)
if __name__ == '__main__':
tf.test.main()
......@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
......@@ -31,6 +32,12 @@ mock = tf.test.mock
class DataProviderTest(tf.test.TestCase):
def setUp(self):
super(DataProviderTest, self).setUp()
self.testdata_dir = os.path.join(
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
def test_normalize_image(self):
image = tf.random_uniform(shape=(8, 8, 3), maxval=256, dtype=tf.int32)
rescaled_image = data_provider.normalize_image(image)
......@@ -51,13 +58,8 @@ class DataProviderTest(tf.test.TestCase):
self.assertTupleEqual((10, 10, 3), sess.run(patch2).shape)
self.assertTupleEqual((10, 10, 3), sess.run(patch3).shape)
def _get_testdata_dir(self):
return os.path.join(
tf.flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/cyclegan/testdata')
def test_custom_dataset_provider(self):
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images = data_provider._provide_custom_dataset(
......@@ -75,7 +77,7 @@ class DataProviderTest(tf.test.TestCase):
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_custom_datasets_provider(self):
file_pattern = os.path.join(self._get_testdata_dir(), '*.jpg')
file_pattern = os.path.join(self.testdata_dir, '*.jpg')
batch_size = 3
patch_size = 8
images_list = data_provider.provide_custom_datasets(
......
......@@ -21,6 +21,8 @@ from __future__ import print_function
import os
from absl import app
from absl import flags
import numpy as np
import PIL
import tensorflow as tf
......@@ -28,7 +30,6 @@ import tensorflow as tf
import data_provider
import networks
flags = tf.flags
tfgan = tf.contrib.gan
flags.DEFINE_string('checkpoint_path', '',
......@@ -147,4 +148,4 @@ def main(_):
if __name__ == '__main__':
tf.app.run()
app.run()
......@@ -5,8 +5,10 @@ from __future__ import division
from __future__ import print_function
import os
from absl import logging
import numpy as np
import PIL
import tensorflow as tf
import inference_demo
......@@ -59,7 +61,7 @@ class InferenceDemoTest(tf.test.TestCase):
# Create inference graph
tf.reset_default_graph()
FLAGS.patch_dim = FLAGS.patch_size
tf.logging.info('dir_path: {}'.format(os.listdir(self._export_dir)))
logging.info('dir_path: %s', os.listdir(self._export_dir))
FLAGS.checkpoint_path = self._ckpt_path
FLAGS.image_set_x_glob = self._image_glob
FLAGS.image_set_y_glob = self._image_glob
......@@ -67,7 +69,7 @@ class InferenceDemoTest(tf.test.TestCase):
FLAGS.generated_y_dir = self._geny_dir
inference_demo.main(None)
tf.logging.info('gen x: {}'.format(os.listdir(self._genx_dir)))
logging.info('gen x: %s', os.listdir(self._genx_dir))
# Check that the image names match
self.assertSetEqual(
......@@ -84,7 +86,7 @@ class InferenceDemoTest(tf.test.TestCase):
self.assertRealisticImage(image_path)
def assertRealisticImage(self, image_path):
tf.logging.info('Testing {} for realism.'.format(image_path))
logging.info('Testing %s for realism.', image_path)
# If the normalization is off or forgotten, then the generated image is
# all one pixel value. This tests that different pixel values are achieved.
input_np = np.asarray(PIL.Image.open(image_path))
......
......@@ -19,13 +19,12 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from absl import flags
import tensorflow as tf
import data_provider
import networks
flags = tf.flags
tfgan = tf.contrib.gan
......@@ -87,10 +86,7 @@ def _define_model(images_x, images_y):
data_y=images_y)
# Add summaries for generated images.
tfgan.eval.add_image_comparison_summaries(
cyclegan_model, num_comparisons=3, display_diffs=False)
tfgan.eval.add_gan_model_image_summaries(
cyclegan_model, grid_size=int(np.sqrt(FLAGS.batch_size)))
tfgan.eval.add_cyclegan_image_summaries(cyclegan_model)
return cyclegan_model
......
......@@ -19,12 +19,13 @@ from __future__ import division
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
FLAGS = flags.FLAGS
mock = tf.test.mock
tfgan = tf.contrib.gan
......@@ -60,9 +61,6 @@ class TrainTest(tf.test.TestCase):
self.assertShapeEqual(images_x_np, cyclegan_model.reconstructed_x)
self.assertShapeEqual(images_y_np, cyclegan_model.reconstructed_y)
mock_eval.add_image_comparison_summaries.assert_called_once()
mock_eval.add_gan_model_image_summaries.assert_called_once()
@mock.patch.object(train.networks, 'generator', autospec=True)
@mock.patch.object(train.networks, 'discriminator', autospec=True)
@mock.patch.object(
......
......@@ -20,6 +20,8 @@ from __future__ import print_function
import os
from absl import flags
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -27,11 +29,14 @@ import tensorflow as tf
import data_provider
class DataProviderTest(tf.test.TestCase):
class DataProviderTest(tf.test.TestCase, parameterized.TestCase):
def _test_data_provider_helper(self, split_name):
@parameterized.named_parameters(
('train', 'train'),
('validation', 'validation'))
def test_data_provider(self, split_name):
dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir,
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/image_compression/testdata/')
batch_size = 3
......@@ -49,12 +54,6 @@ class DataProviderTest(tf.test.TestCase):
# Check range.
self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def test_data_provider_train(self):
self._test_data_provider_helper('train')
def test_data_provider_validation(self):
self._test_data_provider_helper('validation')
if __name__ == '__main__':
tf.test.main()
......@@ -20,13 +20,14 @@ from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
import data_provider
import networks
import summaries
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
......@@ -98,4 +99,4 @@ def main(_, run_eval_loop=True):
if __name__ == '__main__':
tf.app.run()
app.run()
......@@ -20,7 +20,8 @@ from __future__ import division
from __future__ import print_function
from absl import flags
from absl import logging
import tensorflow as tf
import data_provider
......@@ -29,7 +30,6 @@ import summaries
tfgan = tf.contrib.gan
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')
......@@ -212,6 +212,6 @@ def _get_gan_model(generator_inputs, generated_data, real_data,
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
logging.set_verbosity(logging.INFO)
tf.app.run()
......@@ -19,17 +19,22 @@ from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import train
FLAGS = tf.flags.FLAGS
FLAGS = flags.FLAGS
mock = tf.test.mock
class TrainTest(tf.test.TestCase):
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def _test_build_graph_helper(self, weight_factor):
@parameterized.named_parameters(
('NoAdversarialLoss', 0.0),
('AdversarialLoss', 1.0))
def test_build_graph(self, weight_factor):
FLAGS.max_number_of_steps = 0
FLAGS.weight_factor = weight_factor
......@@ -45,12 +50,6 @@ class TrainTest(tf.test.TestCase):
mock_data_provider.provide_data.return_value = mock_imgs
train.main(None)
def test_build_graph_noadversarialloss(self):
self._test_build_graph_helper(0.0)
def test_build_graph_adversarialloss(self):
self._test_build_graph_helper(1.0)
if __name__ == '__main__':
tf.test.main()
......
......@@ -19,13 +19,14 @@ from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
import data_provider
import networks
import util
flags = tf.flags
tfgan = tf.contrib.gan
......@@ -107,4 +108,4 @@ def _get_generator_inputs(num_images_per_class, num_classes, noise_dims):
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from absl.testing import absltest
import conditional_eval
class ConditionalEvalTest(tf.test.TestCase):
class ConditionalEvalTest(absltest.TestCase):
def test_build_graph(self):
conditional_eval.main(None, run_eval_loop=False)
if __name__ == '__main__':
tf.test.main()
absltest.main()
......@@ -21,6 +21,7 @@ from __future__ import print_function
import os
from absl import flags
import tensorflow as tf
import data_provider
......@@ -30,7 +31,7 @@ class DataProviderTest(tf.test.TestCase):
def test_mnist_data_reading(self):
dataset_dir = os.path.join(
tf.flags.FLAGS.test_srcdir,
flags.FLAGS.test_srcdir,
'google3/third_party/tensorflow_models/gan/mnist/testdata')
batch_size = 5
......
......@@ -20,13 +20,14 @@ from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
import data_provider
import networks
import util
flags = tf.flags
FLAGS = flags.FLAGS
tfgan = tf.contrib.gan
......@@ -100,4 +101,4 @@ def main(_, run_eval_loop=True):
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -20,21 +20,21 @@ from __future__ import print_function
import tensorflow as tf
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import eval # pylint:disable=redefined-builtin
class EvalTest(tf.test.TestCase):
class EvalTest(parameterized.TestCase):
def _test_build_graph_helper(self, eval_real_images):
tf.flags.FLAGS.eval_real_images = eval_real_images
@parameterized.named_parameters(
('RealData', True),
('GeneratedData', False))
def test_build_graph(self, eval_real_images):
flags.FLAGS.eval_real_images = eval_real_images
eval.main(None, run_eval_loop=False)
def test_build_graph_realdata(self):
self._test_build_graph_helper(True)
def test_build_graph_generateddata(self):
self._test_build_graph_helper(False)
if __name__ == '__main__':
tf.test.main()
absltest.main()
......@@ -26,6 +26,8 @@ from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
......@@ -34,7 +36,6 @@ import data_provider
import networks
import util
flags = tf.flags
tfgan = tf.contrib.gan
......@@ -156,4 +157,4 @@ def _get_write_image_ops(eval_dir, filename, images):
if __name__ == '__main__':
tf.app.run()
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment