# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for pix2pix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework

from nets import pix2pix


class GeneratorTest(tf.test.TestCase):
  """Output-shape and structure tests for `pix2pix.pix2pix_generator`."""

  def _reduced_default_blocks(self):
    """Returns the default blocks, scaled down to make the tests run faster.

    Each default block keeps its `decoder_keep_prob`, while its
    `num_filters` is floor-divided by 32.
    """
    return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob)
            for b in pix2pix._default_generator_blocks()]

  def _assert_output_size(self, upsample_method):
    """Builds a reduced generator and checks the evaluated logits' shape.

    The logits are expected to keep the input's batch and spatial size and
    to have `num_outputs` channels.

    Args:
      upsample_method: Forwarded unchanged to `pix2pix.pix2pix_generator`.
    """
    batch_size = 2
    height, width = 256, 256
    num_outputs = 4

    images = tf.ones((batch_size, height, width, 3))
    with contrib_framework.arg_scope(pix2pix.pix2pix_arg_scope()):
      logits, _ = pix2pix.pix2pix_generator(
          images, num_outputs, blocks=self._reduced_default_blocks(),
          upsample_method=upsample_method)

    # `cached_session` supersedes the deprecated `test_session`.
    with self.cached_session() as session:
      session.run(tf.compat.v1.global_variables_initializer())
      np_outputs = session.run(logits)
      self.assertListEqual([batch_size, height, width, num_outputs],
                           list(np_outputs.shape))

  def test_output_size_nn_upsample_conv(self):
    self._assert_output_size('nn_upsample_conv')

  def test_output_size_conv2d_transpose(self):
    self._assert_output_size('conv2d_transpose')

  def test_block_number_dictates_number_of_layers(self):
    batch_size = 2
    height, width = 256, 256
    num_outputs = 4

    images = tf.ones((batch_size, height, width, 3))
    blocks = [
        pix2pix.Block(64, 0.5),
        pix2pix.Block(128, 0),
    ]
    with contrib_framework.arg_scope(pix2pix.pix2pix_arg_scope()):
      _, end_points = pix2pix.pix2pix_generator(
          images, num_outputs, blocks=blocks)

    # Each block is expected to add exactly one end point whose name starts
    # with 'encoder' and one whose name starts with 'decoder'.
    num_encoder_layers = sum(
        1 for name in end_points if name.startswith('encoder'))
    num_decoder_layers = sum(
        1 for name in end_points if name.startswith('decoder'))

    self.assertEqual(num_encoder_layers, len(blocks))
    self.assertEqual(num_decoder_layers, len(blocks))


class DiscriminatorTest(tf.test.TestCase):
  """Shape and argument-validation tests for `pix2pix_discriminator`."""

  _BATCH_SIZE = 2
  _INPUT_SIZE = 256

  def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2):
    """Returns the spatial output size of a single convolution layer.

    Computes floor((input_size + 2 * pad - kernel_size) / stride) + 1.
    """
    return (input_size + pad * 2 - kernel_size) // stride + 1

  def _expected_output_size(self, input_size, pad=2):
    """Returns the expected discriminator output size for four num_filters.

    Applies three stride-2 layers followed by two stride-1 layers. Every
    layer uses the default kernel size and the given `pad`. Presumably these
    are the four `num_filters` layers plus a final 1-channel logits layer;
    confirm against `pix2pix.pix2pix_discriminator`.
    """
    output_size = input_size
    for stride in (2, 2, 2, 1, 1):
      output_size = self._layer_output_size(
          output_size, stride=stride, pad=pad)
    return output_size

  def _build_discriminator(self, **kwargs):
    """Builds the discriminator on a constant all-ones image batch.

    Args:
      **kwargs: Extra keyword arguments (e.g. `padding`) forwarded to
        `pix2pix.pix2pix_discriminator`.

    Returns:
      The value returned by `pix2pix.pix2pix_discriminator`. The tests
      unpack it as `(logits, end_points)`.
    """
    images = tf.ones(
        (self._BATCH_SIZE, self._INPUT_SIZE, self._INPUT_SIZE, 3))
    with contrib_framework.arg_scope(pix2pix.pix2pix_arg_scope()):
      return pix2pix.pix2pix_discriminator(
          images, num_filters=[64, 128, 256, 512], **kwargs)

  def _assert_output_shapes(self, logits, end_points, output_size):
    """Checks that logits and predictions are [batch, size, size, 1]."""
    expected = [self._BATCH_SIZE, output_size, output_size, 1]
    self.assertListEqual(expected, logits.shape.as_list())
    self.assertListEqual(expected, end_points['predictions'].shape.as_list())

  def test_four_layers(self):
    logits, end_points = self._build_discriminator()
    output_size = self._expected_output_size(self._INPUT_SIZE)
    self._assert_output_shapes(logits, end_points, output_size)

  def test_four_layers_no_padding(self):
    logits, end_points = self._build_discriminator(padding=0)
    output_size = self._expected_output_size(self._INPUT_SIZE, pad=0)
    self._assert_output_shapes(logits, end_points, output_size)

  def test_four_layers_wrong_padding(self):
    # A non-integer padding is expected to be rejected with a TypeError.
    with self.assertRaises(TypeError):
      self._build_discriminator(padding=1.5)

  def test_four_layers_negative_padding(self):
    # A negative padding is expected to be rejected with a ValueError.
    with self.assertRaises(ValueError):
      self._build_discriminator(padding=-1)

if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()