network.py 3.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural network for a StarGAN model.

This module contains the Generator and Discriminator neural networks used to
build a StarGAN model.

See https://arxiv.org/abs/1711.09020 for details about the model.

See https://github.com/yunjey/StarGAN for the original pytorch implementation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



import tensorflow as tf

import layers
import ops


def generator(inputs, targets):
  """Builds the StarGAN generator graph.

  The inputs are first conditioned on the target domain, then passed through
  the down-sampling, bottleneck and up-sampling stages in that order. All
  stages are created under the 'generator' variable scope.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L22

  Args:
    inputs: Tensor of shape (batch_size, h, w, c) holding the
      images/information to be transformed.
    targets: Tensor of shape (batch_size, num_domains) describing the domain
      each input should be transformed to.

  Returns:
    Tensor of shape (batch_size, h, w, c), i.e. the same shape as `inputs`.
  """
  with tf.variable_scope('generator'):
    # Fuse the target-domain information into the image tensor.
    net = ops.condition_input_with_pixel_padding(inputs, targets)
    net = layers.generator_down_sample(net)
    net = layers.generator_bottleneck(net)
    # The last dimension of `inputs` is handed to the up-sampling stage so the
    # output can match the input's channel count.
    return layers.generator_up_sample(net, inputs.shape[-1])


def discriminator(input_net, class_num):
  """Builds the StarGAN discriminator graph.

  A shared hidden representation feeds two heads: a source head (real vs.
  fake) and a class head (domain prediction). The source head's output is
  flattened and averaged per example so that it yields one logit per image.
  All layers are created under the 'discriminator' variable scope.

  PyTorch Version:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/model.py#L63

  Notes:
  The PyTorch version performs the mean reduction later, in its solver:
  https://github.com/yunjey/StarGAN/blob/fbdb6a6ce2a4a92e1dc034faec765e0dbe4b8164/solver.py#L245

  Args:
    input_net: Tensor of shape (batch_size, h, w, c) holding a batch of images.
    class_num: (int) number of domains to be predicted.

  Returns:
    output_src: Tensor of shape (batch_size) where each value is a logit
    representing whether the image is real or fake.
    output_cls: Tensor of shape (batch_size, class_num) where each value is a
    logit representing whether the image is in the associated domain.
  """
  with tf.variable_scope('discriminator'):
    shared_features = layers.discriminator_input_hidden(input_net)

    # Source head: collapse everything but the batch dimension, then average
    # to obtain a single real/fake logit per example.
    source_map = layers.discriminator_output_source(shared_features)
    output_src = tf.reduce_mean(
        tf.contrib.layers.flatten(source_map), axis=1)

    # Class head: per-domain logits computed from the same shared features.
    output_cls = layers.discriminator_output_class(shared_features, class_num)

  return output_src, output_cls