# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN CIFAR example using TFGAN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from slim.nets import dcgan

tfgan = tf.contrib.gan


def _last_conv_layer(end_points):
  """Returns the last convolutional layer from an endpoints dictionary."""
  conv_list = [k for k in end_points if k[:4] == 'conv']
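  # End point keys produced by the slim `dcgan` nets are expected to look like
  # 'conv0', 'conv1', ..., so a lexical sort leaves the deepest convolutional
  # layer last (valid as long as layer indices stay single-digit).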
  conv_list.sort()
  return end_points[conv_list[-1]]


def generator(noise, is_training=True):
  """Generator to produce CIFAR images.

  Args:
    noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
      does not use conditioning, this Tensor represents a noise vector of some
      kind that will be transformed by the generator into CIFAR examples.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  images, _ = dcgan.generator(
      noise, is_training=is_training, fused_batch_norm=True)

  # Make sure output lies between [-1, 1].
  return tf.tanh(images)
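
# Example usage (a sketch, not exercised in this file; the 64-dimensional
# noise is an arbitrary choice for illustration):
#   noise = tf.random_normal([batch_size, 64])
#   generated_images = generator(noise)  # CIFAR-sized images in [-1, 1].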


def conditional_generator(inputs, is_training=True):
  """Conditional generator to produce CIFAR images.

  Args:
    inputs: A 2-tuple of Tensors (noise, one_hot_labels), which the generator
      uses to produce class-conditional CIFAR images.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  noise, one_hot_labels = inputs
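  # Fold the one-hot label information into the noise so the generated images
  # can depend on the class label.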
  noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)

  images, _ = dcgan.generator(
      noise, is_training=is_training, fused_batch_norm=True)

  # Make sure output lies between [-1, 1].
  return tf.tanh(images)
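
# Example usage (a sketch; 64-dimensional noise and 10 CIFAR-10 classes are
# illustrative assumptions):
#   noise = tf.random_normal([batch_size, 64])
#   one_hot_labels = tf.one_hot(labels, depth=10)
#   generated_images = conditional_generator((noise, one_hot_labels))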


def discriminator(img, unused_conditioning, is_training=True):
  """Discriminator for CIFAR images.

  Args:
    img: A Tensor of shape [batch size, height, width, channels] that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    unused_conditioning: The TFGAN API can help with conditional GANs, which
      would require passing extra `condition` information to both the generator
      and the discriminator. Since this example is not conditional, we do not
      use this argument.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  logits, _ = dcgan.discriminator(
      img, is_training=is_training, fused_batch_norm=True)
  return logits
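
# Example usage (a sketch; the conditioning argument is ignored because this
# discriminator is unconditional):
#   logits_on_real = discriminator(real_images, None)
#   logits_on_fake = discriminator(generator(noise), None)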


# TODO(joelshor): This discriminator creates variables that aren't used, and
# causes logging warnings. Improve the `dcgan` nets to accept a target end
# layer so extraneous variables aren't created.
def conditional_discriminator(img, conditioning, is_training=True):
  """Conditional discriminator for CIFAR images.

  Args:
    img: A Tensor of shape [batch size, height, width, channels] that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  logits, end_points = dcgan.discriminator(
      img, is_training=is_training, fused_batch_norm=True)

  # Condition the last convolution layer.
  _, one_hot_labels = conditioning
  net = _last_conv_layer(end_points)
  net = tfgan.features.condition_tensor_from_onehot(
      tf.contrib.layers.flatten(net), one_hot_labels)
  logits = tf.contrib.layers.linear(net, 1)

  return logits
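

# Example wiring with TFGAN (a sketch; `real_images` and `noise` are assumed
# to come from an input pipeline defined elsewhere):
#   gan_model = tfgan.gan_model(
#       generator_fn=generator,
#       discriminator_fn=discriminator,
#       real_data=real_images,
#       generator_inputs=noise)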