networks.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN compression example using TFGAN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from slim.nets import dcgan
from slim.nets import pix2pix


def _last_conv_layer(end_points):
  """Returns the last convolutional layer from an endpoints dictionary."""
  conv_list = [k for k in end_points if k[:4] == 'conv']
  conv_list.sort()
  return end_points[conv_list[-1]]


def _encoder(img_batch, is_training=True, bits=64, depth=64):
  """Maps images to internal representation.

  Args:
    img_batch: A batch of images to encode. Images should already be
      normalized. Shape is [batch, height, width, channels].
    is_training: A python bool. If False, run in evaluation mode.
    bits: Number of bits per patch.
    depth: The base number of filters for the encoder network.

  Returns:
    Real-valued 2D Tensor of size [batch_size, bits].
  """
  _, end_points = dcgan.discriminator(
      img_batch, depth=depth, is_training=is_training, scope='Encoder')

  # TODO(joelshor): Make the DCGAN convolutional layer that converts to logits
  # not trainable, since it doesn't affect the encoder output.

  # Get the pre-logit layer, which is the last conv.
  net = _last_conv_layer(end_points)

  # Transform the features to the proper number of bits.
  with tf.variable_scope('EncoderTransformer'):
    encoded = tf.contrib.layers.conv2d(net, bits, kernel_size=1, stride=1,
                                       padding='VALID', normalizer_fn=None,
                                       activation_fn=None)
  encoded = tf.squeeze(encoded, [1, 2])
  encoded.shape.assert_has_rank(2)

  # Map encoded to the range [-1, 1].
  return tf.nn.softsign(encoded)


def _binarizer(prebinary_codes, is_training):
  """Binarize compression logits.

  During training, add noise, as in https://arxiv.org/pdf/1611.01704.pdf. During
  eval, map [-1, 1] -> {-1, 1}.

  Args:
    prebinary_codes: Floating-point tensors corresponding to pre-binary codes.
      Shape is [batch, code_length].
    is_training: A python bool. If True, add noise. If False, binarize.

  Returns:
    Binarized codes. Shape is [batch, code_length].

  Raises:
    ValueError: If the shape of `prebinary_codes` isn't static.
  """
  if is_training:
    # In order to train codes that can be binarized during eval, we add noise as
    # in https://arxiv.org/pdf/1611.01704.pdf. Another option is to use a
    # stochastic node, as in https://arxiv.org/abs/1608.05148.
    noise = tf.random_uniform(
        prebinary_codes.shape,
        minval=-1.0,
        maxval=1.0)
    return prebinary_codes + noise
  else:
    return tf.sign(prebinary_codes)


def _decoder(codes, final_size, is_training, depth=64):
  """Compression decoder."""
  decoded_img, _ = dcgan.generator(
      codes,
      depth=depth,
      final_size=final_size,
      num_outputs=3,
      is_training=is_training,
      scope='Decoder')

  # Map output to [-1, 1].
  # Use softsign instead of tanh, as per empirical results of
  # http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
  return tf.nn.softsign(decoded_img)


def _validate_image_inputs(image_batch):
  image_batch.shape.assert_has_rank(4)
  image_batch.shape[1:].assert_is_fully_defined()


def compression_model(image_batch, num_bits=64, depth=64, is_training=True):
  """Image compression model.

  Args:
    image_batch: A batch of images to compress and reconstruct. Images should
      be normalized already. Shape is [batch, height, width, channels].
    num_bits: Desired number of bits per image in the compressed representation.
    depth: The base number of filters for the encoder and decoder networks.
    is_training: A python bool. If False, run in evaluation mode.

  Returns:
    A tuple of (uncompressed images, binary codes, prebinary codes). Both code
    Tensors have shape [batch, num_bits].
  """
  image_batch = tf.convert_to_tensor(image_batch)
  _validate_image_inputs(image_batch)
  final_size = image_batch.shape.as_list()[1]

  prebinary_codes = _encoder(image_batch, is_training, num_bits, depth)
  binary_codes = _binarizer(prebinary_codes, is_training)
  uncompressed_imgs = _decoder(binary_codes, final_size, is_training, depth)
  return uncompressed_imgs, binary_codes, prebinary_codes


def discriminator(image_batch, unused_conditioning=None, depth=64):
  """A thin wrapper around the pix2pix discriminator to conform to TFGAN API."""
  logits, _ = pix2pix.pix2pix_discriminator(
      image_batch, num_filters=[depth, 2 * depth, 4 * depth, 8 * depth])
  return tf.layers.flatten(logits)
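

# A minimal smoke-test sketch, not part of the original example. It assumes a
# TF 1.x graph-mode environment and a square, power-of-two patch size (32 here
# is an arbitrary choice); it only wires the networks together on a dummy
# batch and prints the resulting static shapes.
if __name__ == '__main__':
  dummy_images = tf.zeros([8, 32, 32, 3])  # Stand-in for a normalized batch.
  reconstructions, codes, precodes = compression_model(
      dummy_images, num_bits=64, depth=64, is_training=True)
  disc_logits = discriminator(reconstructions)
  print(reconstructions.shape)  # (8, 32, 32, 3)
  print(codes.shape)            # (8, 64)
  print(precodes.shape)         # (8, 64)
  print(disc_logits.shape)      # (8, num_patch_logits) from the patch GAN.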