# inference_demo_test.py
"""Tests for CycleGAN inference demo."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import PIL
import tensorflow as tf

import inference_demo
import train

# Global flag registry. The test below overrides flag values on this object
# before invoking `train.main` and `inference_demo.main`.
FLAGS = tf.flags.FLAGS
# Mock module exposed through TensorFlow's test utilities; used for
# `mock.patch.object` in the test below.
mock = tf.test.mock
# Short alias for `tf.contrib.gan`; its `gan_train` attribute is patched out
# in the test below.
tfgan = tf.contrib.gan


def _basenames_from_glob(file_glob):
  """Returns the base file name of every path matched by `file_glob`.

  Args:
    file_glob: A glob pattern, expanded with `tf.gfile.Glob`.

  Returns:
    A list with one base name (directory part stripped) per matched path, in
    the order the glob returned them.
  """
  matched_paths = tf.gfile.Glob(file_glob)
  return list(map(os.path.basename, matched_paths))


class InferenceDemoTest(tf.test.TestCase):
  """Checks that the CycleGAN training and inference graphs stay compatible."""

  def setUp(self):
    """Computes the checkpoint, input-image and output paths used by the test."""
    # Checkpoint written by the training half of the test.
    self._export_dir = os.path.join(FLAGS.test_tmpdir, 'export')
    self._ckpt_path = os.path.join(self._export_dir, 'model.ckpt')
    # Input images fed to the inference demo (same glob is used for X and Y).
    self._image_glob = os.path.join(
        FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cyclegan/testdata', '*.jpg')
    # Output directories the inference demo writes generated images to.
    self._genx_dir = os.path.join(FLAGS.test_tmpdir, 'genx')
    self._geny_dir = os.path.join(FLAGS.test_tmpdir, 'geny')

  @mock.patch.object(tfgan, 'gan_train', autospec=True)
  def testTrainingAndInferenceGraphsAreCompatible(self, unused_mock_gan_train):
    """Saves a checkpoint from the training graph and runs inference on it.

    `tfgan.gan_train` is patched out, so `train.main` only builds the graph.
    The freshly initialized variables are saved to a checkpoint, which
    `inference_demo.main` then consumes. The test passes when inference writes
    one plausible-looking image per input image into each output directory.

    Args:
      unused_mock_gan_train: The mock injected by `mock.patch.object`; unused.
    """
    # Training and inference graphs can get out of sync if changes are made
    # to one but not the other. This test will keep them in sync.

    # Save the training graph
    train_sess = tf.Session()
    try:
      FLAGS.image_set_x_file_pattern = '/tmp/x/*.jpg'
      FLAGS.image_set_y_file_pattern = '/tmp/y/*.jpg'
      FLAGS.batch_size = 3
      FLAGS.patch_size = 128
      FLAGS.generator_lr = 0.02
      FLAGS.discriminator_lr = 0.3
      FLAGS.train_log_dir = self._export_dir
      FLAGS.master = 'master'
      FLAGS.task = 0
      FLAGS.cycle_consistency_loss_weight = 2.0
      FLAGS.max_number_of_steps = 1
      train.main(None)
      init_op = tf.global_variables_initializer()
      train_sess.run(init_op)
      train_saver = tf.train.Saver()
      train_saver.save(train_sess, save_path=self._ckpt_path)
    finally:
      # Release the session's resources even if graph construction or saving
      # fails; the original code never closed it.
      train_sess.close()

    # Create inference graph
    tf.reset_default_graph()
    FLAGS.patch_dim = FLAGS.patch_size
    tf.logging.info('dir_path: {}'.format(os.listdir(self._export_dir)))
    FLAGS.checkpoint_path = self._ckpt_path
    FLAGS.image_set_x_glob = self._image_glob
    FLAGS.image_set_y_glob = self._image_glob
    FLAGS.generated_x_dir = self._genx_dir
    FLAGS.generated_y_dir = self._geny_dir

    inference_demo.main(None)
    tf.logging.info('gen x: {}'.format(os.listdir(self._genx_dir)))

    # Check that the image names match
    self.assertSetEqual(
        set(_basenames_from_glob(FLAGS.image_set_x_glob)),
        set(os.listdir(FLAGS.generated_y_dir)))
    self.assertSetEqual(
        set(_basenames_from_glob(FLAGS.image_set_y_glob)),
        set(os.listdir(FLAGS.generated_x_dir)))

    # Check that each image in BOTH output directories looks as expected.
    # (Previously the X directory was listed twice, so generated Y images
    # were never inspected.)
    for directory in [FLAGS.generated_x_dir, FLAGS.generated_y_dir]:
      for base_name in os.listdir(directory):
        image_path = os.path.join(directory, base_name)
        self.assertRealisticImage(image_path)

  def assertRealisticImage(self, image_path):
    """Asserts that the image at `image_path` is not degenerate.

    Args:
      image_path: Path of an image file readable by `PIL.Image.open`.

    Raises:
      AssertionError: If the decoded array is not rank 3, if either of its
        first two dimensions is smaller than 50, or if its pixel mean or
        variance is at or below the thresholds used here (20 and 100).
    """
    tf.logging.info('Testing {} for realism.'.format(image_path))
    # If the normalization is off or forgotten, then the generated image is
    # all one pixel value. This tests that different pixel values are achieved.
    input_np = np.asarray(PIL.Image.open(image_path))
    self.assertEqual(len(input_np.shape), 3)
    self.assertGreaterEqual(input_np.shape[0], 50)
    self.assertGreaterEqual(input_np.shape[1], 50)
    self.assertGreater(np.mean(input_np), 20)
    self.assertGreater(np.var(input_np), 100)


# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()