# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DCGAN generator and discriminator from https://arxiv.org/abs/1511.06434."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from math import log

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim


def _validate_image_inputs(inputs):
  inputs.get_shape().assert_has_rank(4)
  inputs.get_shape()[1:3].assert_is_fully_defined()
  if inputs.get_shape()[1] != inputs.get_shape()[2]:
    raise ValueError('Input tensor does not have equal width and height: ',
                     inputs.get_shape()[1:3])
  width = inputs.get_shape().as_list()[1]
  if log(width, 2) != int(log(width, 2)):
    raise ValueError('Input tensor `width` is not a power of 2: ', width)


# TODO(joelshor): Use fused batch norm by default. Investigate why some GAN
# setups need the gradient of gradient FusedBatchNormGrad.
def discriminator(inputs,
                  depth=64,
                  is_training=True,
                  reuse=None,
                  scope='Discriminator',
                  fused_batch_norm=False):
  """Discriminator network for DCGAN.

  Construct discriminator network from inputs to the final endpoint.

  Args:
    inputs: A tensor of size [batch_size, height, width, channels]. Must be
      floating point.
    depth: Number of channels in first convolution layer. Each subsequent
      convolution doubles the channel count.
    is_training: Whether the network is for training or not.
    reuse: Whether or not the network variables should be reused. `scope`
      must be given to be reused.
    scope: Optional variable_scope.
    fused_batch_norm: If `True`, use a faster, fused implementation of
      batch norm.

  Returns:
    logits: The raw (pre-activation) outputs, a tensor of size
      [batch_size, 1].
    end_points: A dictionary from components of the network to their
      activation. Keys are 'conv1', 'conv2', ... and 'logits'.

  Raises:
    ValueError: If the input image shape is not 4-dimensional, if the spatial
      dimensions aren't defined at graph construction time, if the spatial
      dimensions aren't square, or if the spatial dimensions aren't a power of
      two.
  """

  normalizer_fn = slim.batch_norm
  normalizer_fn_args = {
      'is_training': is_training,
      'zero_debias_moving_mean': True,
      'fused': fused_batch_norm,
  }

  _validate_image_inputs(inputs)
  inp_shape = inputs.get_shape().as_list()[1]
  # `inp_shape` has been validated as a power of two, so this is exactly
  # log2(inp_shape). Using integer arithmetic avoids `int(log(...))`
  # truncating to one layer too few on a floating-point rounding error.
  num_layers = inp_shape.bit_length() - 1

  end_points = {}
  with tf.compat.v1.variable_scope(scope, values=[inputs], reuse=reuse):
    with slim.arg_scope([normalizer_fn], **normalizer_fn_args):
      with slim.arg_scope([slim.conv2d],
                          stride=2,
                          kernel_size=4,
                          activation_fn=tf.nn.leaky_relu):
        net = inputs
        # One stride-2 convolution per factor of two in the input size.
        for i in xrange(num_layers):
          layer_scope = 'conv%i' % (i + 1)
          current_depth = depth * 2**i
          # The first convolution is not normalized.
          normalizer_fn_ = None if i == 0 else normalizer_fn
          net = slim.conv2d(
              net,
              current_depth,
              normalizer_fn=normalizer_fn_,
              scope=layer_scope)
          end_points[layer_scope] = net

        # Linear 1x1 projection to a single channel, then flatten so there is
        # one column of logits.
        logits = slim.conv2d(net, 1, kernel_size=1, stride=1, padding='VALID',
                             normalizer_fn=None, activation_fn=None)
        logits = tf.reshape(logits, [-1, 1])
        end_points['logits'] = logits

        return logits, end_points


# TODO(joelshor): Use fused batch norm by default. Investigate why some GAN
# setups need the gradient of gradient FusedBatchNormGrad.
def generator(inputs,
              depth=64,
              final_size=32,
              num_outputs=3,
              is_training=True,
              reuse=None,
              scope='Generator',
              fused_batch_norm=False):
  """Generator network for DCGAN.

  Construct generator network from inputs to the final endpoint.

  Args:
    inputs: A tensor with any size N. [batch_size, N]
    depth: Number of channels in last deconvolution layer.
    final_size: The height and width of the final output. Must be a power of
      2 and at least 8.
    num_outputs: Number of output features. For images, this is the number of
      channels.
    is_training: Whether the network is for training or not.
    reuse: Whether or not the network variables should be reused. `scope`
      must be given to be reused.
    scope: Optional variable_scope.
    fused_batch_norm: If `True`, use a faster, fused implementation of
      batch norm.

  Returns:
    logits: The raw (pre-activation) outputs, a tensor of size
      [batch_size, final_size, final_size, num_outputs].
    end_points: A dictionary from components of the network to their
      activation. Keys are 'deconv1', 'deconv2', ... and 'logits'.

  Raises:
    ValueError: If `inputs` is not 2-dimensional.
    ValueError: If `final_size` isn't a power of 2 or is less than 8.
  """
  # Validate arguments before touching any graph-building code.
  inputs.get_shape().assert_has_rank(2)
  # Exact integer power-of-two test; comparing floating-point logarithms is
  # not reliable and fails with a "math domain error" for non-positive sizes.
  size = int(final_size)
  if size != final_size or size < 1 or size & (size - 1):
    raise ValueError('`final_size` (%i) must be a power of 2.' % final_size)
  # Note: 8 itself is accepted. The message text is kept unchanged because
  # callers may match on it.
  if final_size < 8:
    raise ValueError('`final_size` (%i) must be greater than 8.' % final_size)

  normalizer_fn = slim.batch_norm
  normalizer_fn_args = {
      'is_training': is_training,
      'zero_debias_moving_mean': True,
      'fused': fused_batch_norm,
  }

  end_points = {}
  # log2(final_size) - 1, computed exactly: `size` is a power of two, so
  # `size.bit_length() - 1` is its base-2 logarithm.
  num_layers = size.bit_length() - 2

  with tf.compat.v1.variable_scope(scope, values=[inputs], reuse=reuse):
    with slim.arg_scope([normalizer_fn], **normalizer_fn_args):
      with slim.arg_scope([slim.conv2d_transpose],
                          normalizer_fn=normalizer_fn,
                          stride=2,
                          kernel_size=4):
        # [batch_size, N] -> [batch_size, 1, 1, N]
        net = tf.expand_dims(tf.expand_dims(inputs, 1), 1)

        # First upscaling is different because it takes the input vector.
        current_depth = depth * 2 ** (num_layers - 1)
        layer_scope = 'deconv1'
        net = slim.conv2d_transpose(
            net, current_depth, stride=1, padding='VALID', scope=layer_scope)
        end_points[layer_scope] = net

        # Intermediate upscaling layers; the channel count halves each time.
        for i in xrange(2, num_layers):
          layer_scope = 'deconv%i' % (i)
          current_depth = depth * 2 ** (num_layers - i)
          net = slim.conv2d_transpose(net, current_depth, scope=layer_scope)
          end_points[layer_scope] = net

        # Last layer has different normalizer and activation.
        layer_scope = 'deconv%i' % (num_layers)
        net = slim.conv2d_transpose(
            net,
            depth,
            normalizer_fn=None,
            activation_fn=None,
            scope=layer_scope)
        end_points[layer_scope] = net

        # Convert to proper channels.
        layer_scope = 'logits'
        logits = slim.conv2d(
            net,
            num_outputs,
            normalizer_fn=None,
            activation_fn=None,
            kernel_size=1,
            stride=1,
            padding='VALID',
            scope=layer_scope)
        end_points[layer_scope] = logits

        # Sanity-check the static output shape against the requested size.
        logits.get_shape().assert_has_rank(4)
        logits.get_shape().assert_is_compatible_with(
            [None, final_size, final_size, num_outputs])

        return logits, end_points