# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the CycleGAN generator and discriminator networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import util as contrib_util

layers = contrib_layers


def cyclegan_arg_scope(instance_norm_center=True,
                       instance_norm_scale=True,
                       instance_norm_epsilon=0.001,
                       weights_init_stddev=0.02,
                       weight_decay=0.0):
  """Returns a default argument scope for all generators and discriminators.

  Args:
    instance_norm_center: Whether instance normalization applies centering.
    instance_norm_scale: Whether instance normalization applies scaling.
    instance_norm_epsilon: Small float added to the variance in the instance
      normalization to avoid dividing by zero.
    weights_init_stddev: Standard deviation of the random values to initialize
      the convolution kernels with.
    weight_decay: Magnitude of weight decay applied to all convolution kernel
      variables of the generator.

  Returns:
    An arg-scope.
  """
  instance_norm_params = {
      'center': instance_norm_center,
      'scale': instance_norm_scale,
      'epsilon': instance_norm_epsilon,
  }

  weights_regularizer = None
  if weight_decay and weight_decay > 0.0:
    weights_regularizer = layers.l2_regularizer(weight_decay)

  with contrib_framework.arg_scope(
      [layers.conv2d],
      normalizer_fn=layers.instance_norm,
      normalizer_params=instance_norm_params,
      weights_initializer=tf.compat.v1.random_normal_initializer(
          0, weights_init_stddev),
      weights_regularizer=weights_regularizer) as sc:
    return sc
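

# Example usage (sketch): entering the returned scope makes every
# `layers.conv2d` call pick up instance normalization, the random-normal
# initializer, and the optional weight decay. `images` is an assumed NHWC
# float tensor, shown for illustration only.
#
#   with contrib_framework.arg_scope(cyclegan_arg_scope(weight_decay=1e-5)):
#     net = layers.conv2d(images, 64, kernel_size=[7, 7])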


def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose',
                      pad_mode='REFLECT', align_corners=False):
  """Upsamples the given inputs.

  Args:
    net: A Tensor of size [batch_size, height, width, filters].
    num_outputs: The number of output filters.
    stride: A list of 2 scalars or a 1x2 Tensor indicating the scale,
      relative to the inputs, of the output dimensions. For example, if stride
      is [2, 3], then the output height and width will be twice and three
      times the input size.
    method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv',
      or 'conv2d_transpose'.
    pad_mode: Mode for tf.pad: one of "CONSTANT", "REFLECT", or "SYMMETRIC".
    align_corners: Option for the 'bilinear_upsample_conv' method. If true, the
      centers of the 4 corner pixels of the input and output tensors are
      aligned, preserving the values at the corner pixels.

  Returns:
    A Tensor which was upsampled using the specified method.

  Raises:
    ValueError: if `method` is not recognized.
  """
  with tf.compat.v1.variable_scope('upconv'):
    net_shape = tf.shape(input=net)
    height = net_shape[1]
    width = net_shape[2]

    # Pad by 1 in the spatial dimensions (axes 1, 2 = h, w), using `pad_mode`,
    # so that a 3x3 'VALID' convolution produces an output with the same
    # spatial dimensions as its input.
    spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]])

    if method == 'nn_upsample_conv':
      net = tf.image.resize(
          net, [stride[0] * height, stride[1] * width],
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
      net = tf.pad(tensor=net, paddings=spatial_pad_1, mode=pad_mode)
      net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
    elif method == 'bilinear_upsample_conv':
      net = tf.compat.v1.image.resize_bilinear(
          net, [stride[0] * height, stride[1] * width],
          align_corners=align_corners)
      net = tf.pad(tensor=net, paddings=spatial_pad_1, mode=pad_mode)
      net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
    elif method == 'conv2d_transpose':
      # This corrects 1 pixel offset for images with even width and height.
      # conv2d is left aligned and conv2d_transpose is right aligned for even
      # sized images (while doing 'SAME' padding).
      # Note: This doesn't reflect actual model in paper.
      net = layers.conv2d_transpose(
          net, num_outputs, kernel_size=[3, 3], stride=stride, padding='valid')
      net = net[:, 1:, 1:, :]
    else:
      raise ValueError('Unknown method: [%s]' % method)

    return net
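

# Example usage (sketch): upsample a feature map 2x per spatial dimension.
# `feats` is an assumed NHWC float tensor; the output shape follows from the
# resize -> pad-by-1 -> 3x3 'VALID' conv sequence above.
#
#   feats = tf.zeros([1, 64, 64, 128])
#   with contrib_framework.arg_scope(cyclegan_arg_scope()):
#     up = cyclegan_upsample(feats, num_outputs=64, stride=[2, 2],
#                            method='bilinear_upsample_conv')
#   # up: [1, 128, 128, 64]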


def _dynamic_or_static_shape(tensor):
  """Returns the static shape of `tensor` if fully defined, else tf.shape."""
  shape = tf.shape(input=tensor)
  static_shape = contrib_util.constant_value(shape)
  return static_shape if static_shape is not None else shape
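

# Example (sketch): with a fully-defined input the helper returns a constant
# numpy shape; with an unknown batch dimension it falls back to the dynamic
# `tf.shape` tensor. The placeholders are illustrative only.
#
#   x = tf.compat.v1.placeholder(tf.float32, [2, 32, 32, 3])
#   _dynamic_or_static_shape(x)  # -> np.array([2, 32, 32, 3])
#   y = tf.compat.v1.placeholder(tf.float32, [None, 32, 32, 3])
#   _dynamic_or_static_shape(y)  # -> dynamic shape Tensor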


def cyclegan_generator_resnet(images,
                              arg_scope_fn=cyclegan_arg_scope,
                              num_resnet_blocks=6,
                              num_filters=64,
                              upsample_fn=cyclegan_upsample,
                              kernel_size=3,
                              tanh_linear_slope=0.0,
                              is_training=False):
  """Defines the cyclegan resnet network architecture.

  As closely as possible following
  https://github.com/junyanz/CycleGAN/blob/master/models/architectures.lua#L232

  FYI: This network requires input height and width to be divisible by 4 in
  order to generate an output with shape equal to input shape. When the input
  dimensions are known at graph construction time, a ValueError is raised if
  they are not multiples of 4; when they are unknown at graph construction
  time, there is no check and the network will fail at run time instead.

  Args:
    images: Input image tensor of shape [batch_size, h, w, 3].
    arg_scope_fn: Function to create the global arg_scope for the network.
    num_resnet_blocks: Number of ResNet blocks in the middle of the generator.
    num_filters: Number of filters of the first hidden layer.
    upsample_fn: Upsampling function for the decoder part of the generator.
    kernel_size: Size w or list/tuple [h, w] of the filter kernels for all inner
      layers.
    tanh_linear_slope: Slope of the linear function to add to the tanh over the
      logits.
    is_training: Whether the network is created in training mode or inference
      only mode. Not actually needed, just for compliance with other generator
      network functions.

  Returns:
    A `Tensor` representing the model output and a dictionary of model end
      points.

  Raises:
    ValueError: If the input height or width is known at graph construction time
      and not a multiple of 4.
  """
  # Neither dropout nor batch norm is used, so `is_training` is not needed.
  del is_training

  end_points = {}

  input_size = images.shape.as_list()
  height, width = input_size[1], input_size[2]
  if height and height % 4 != 0:
    raise ValueError('The input height must be a multiple of 4.')
  if width and width % 4 != 0:
    raise ValueError('The input width must be a multiple of 4.')
  num_outputs = input_size[3]

  if not isinstance(kernel_size, (list, tuple)):
    kernel_size = [kernel_size, kernel_size]

  kernel_height = kernel_size[0]
  kernel_width = kernel_size[1]
  pad_top = (kernel_height - 1) // 2
  pad_bottom = kernel_height // 2
  pad_left = (kernel_width - 1) // 2
  pad_right = kernel_width // 2
  # Explicit per-side paddings that emulate 'SAME' behavior for the 'VALID'
  # convolutions below, so reflection padding can be used instead of zeros.
  paddings = np.array(
      [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
      dtype=np.int32)
  # Pad of 3 per spatial side for the 7x7 convolutions at input and output.
  spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]])

  with contrib_framework.arg_scope(arg_scope_fn()):

    ###########
    # Encoder #
    ###########
    with tf.compat.v1.variable_scope('input'):
      # 7x7 input stage
      net = tf.pad(tensor=images, paddings=spatial_pad_3, mode='REFLECT')
      net = layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID')
      end_points['encoder_0'] = net

    with tf.compat.v1.variable_scope('encoder'):
      with contrib_framework.arg_scope([layers.conv2d],
                                       kernel_size=kernel_size,
                                       stride=2,
                                       activation_fn=tf.nn.relu,
                                       padding='VALID'):
        net = tf.pad(tensor=net, paddings=paddings, mode='REFLECT')
        net = layers.conv2d(net, num_filters * 2)
        end_points['encoder_1'] = net
        net = tf.pad(tensor=net, paddings=paddings, mode='REFLECT')
        net = layers.conv2d(net, num_filters * 4)
        end_points['encoder_2'] = net

    ###################
    # Residual Blocks #
    ###################
    with tf.compat.v1.variable_scope('residual_blocks'):
      with contrib_framework.arg_scope([layers.conv2d],
                                       kernel_size=kernel_size,
                                       stride=1,
                                       activation_fn=tf.nn.relu,
                                       padding='VALID'):
        for block_id in xrange(num_resnet_blocks):
          with tf.compat.v1.variable_scope('block_{}'.format(block_id)):
            res_net = tf.pad(tensor=net, paddings=paddings, mode='REFLECT')
            res_net = layers.conv2d(res_net, num_filters * 4)
            res_net = tf.pad(tensor=res_net, paddings=paddings, mode='REFLECT')
            res_net = layers.conv2d(res_net, num_filters * 4,
                                    activation_fn=None)
            net += res_net

            end_points['resnet_block_%d' % block_id] = net

    ###########
    # Decoder #
    ###########
    with tf.compat.v1.variable_scope('decoder'):
      with contrib_framework.arg_scope([layers.conv2d],
                                       kernel_size=kernel_size,
                                       stride=1,
                                       activation_fn=tf.nn.relu):
        with tf.compat.v1.variable_scope('decoder1'):
          net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2])
        end_points['decoder1'] = net

        with tf.compat.v1.variable_scope('decoder2'):
          net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2])
        end_points['decoder2'] = net

    with tf.compat.v1.variable_scope('output'):
      net = tf.pad(tensor=net, paddings=spatial_pad_3, mode='REFLECT')
      logits = layers.conv2d(
          net,
          num_outputs, [7, 7],
          activation_fn=None,
          normalizer_fn=None,
          padding='valid')
      logits = tf.reshape(logits, _dynamic_or_static_shape(images))

      end_points['logits'] = logits
      end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope

  return end_points['predictions'], end_points
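

# Example usage (sketch): translate a batch of 128x128 RGB images. Height and
# width must be multiples of 4, and pixel values are assumed scaled to
# [-1, 1] as in the CycleGAN setup.
#
#   images = tf.random.uniform([4, 128, 128, 3], minval=-1.0, maxval=1.0)
#   outputs, end_points = cyclegan_generator_resnet(
#       images, num_resnet_blocks=6, num_filters=64)
#   # outputs has the same shape as images: tanh(logits) + slope * logits.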