# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adversarial losses for text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 5.0,
                   'Norm length of adversarial perturbation to be '
                   'optimized with validation. '
                   '5.0 is optimal on IMDB with virtual adversarial training. ')

# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1,
                     'The number of power iterations')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
                   'Small constant for finite difference method')

# Parameters for building the graph
flags.DEFINE_string('adv_training_method', None,
                    'The flag which specifies training method. '
                    '"rp"  : random perturbation training '
                    '"at"  : adversarial training '
                    '"vat" : virtual adversarial training '
                    '"atvat" : at + vat ')
flags.DEFINE_float('adv_reg_coeff', 1.0,
                   'Regularization coefficient of adversarial loss.')
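
# Illustrative usage sketch (the names below are assumptions, not part of this
# module): graph-building code typically picks one of the losses in this file
# based on FLAGS.adv_training_method and adds it to the classification loss,
# scaled by FLAGS.adv_reg_coeff, e.g.
#
#   adv_loss = adversarial_loss(embedded, cl_loss, classification_loss_fn)
#   total_loss = cl_loss + FLAGS.adv_reg_coeff * adv_loss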


def random_perturbation_loss(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = tf.random_normal(shape=tf.shape(embedded))
  perturb = _scale_l2(_mask_by_length(noise, length), FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)


def adversarial_loss(embedded, loss, loss_fn):
  """Adds gradient to embedding and recomputes classification loss."""
  grad, = tf.gradients(
      loss,
      embedded,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  grad = tf.stop_gradient(grad)
  perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)
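
# Minimal usage sketch for adversarial_loss (caller-side names such as
# `model_logits_from_embedding` and `classification_loss` are hypothetical):
#
#   def loss_fn(perturbed_embedded):
#     logits = model_logits_from_embedding(perturbed_embedded)
#     return classification_loss(logits, labels)
#
#   adv_loss = adversarial_loss(embedded, cl_loss, loss_fn)
#
# The perturbation is the stop-gradient loss gradient w.r.t. the embeddings,
# rescaled to L2 norm FLAGS.perturb_norm_length before recomputing the loss.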


def virtual_adversarial_loss(logits, embedded, inputs,
                             logits_from_embedding_fn):
  """Virtual adversarial loss.

  Computes virtual adversarial perturbation by finite difference method and
  power iteration, adds it to the embedding, and computes the KL divergence
  between the new logits and the original logits.

  Args:
    logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
      num_classes=2, otherwise m=num_classes.
    embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
    inputs: VatxtInput.
    logits_from_embedding_fn: callable that takes embeddings and returns
      classifier logits.

  Returns:
    kl: float scalar.
  """
  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
  logits = tf.stop_gradient(logits)

  # Only care about the KL divergence on the final timestep.
  weights = inputs.eos_weights
  assert weights is not None

  # Initialize perturbation with random noise.
  # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
  d = tf.random_normal(shape=tf.shape(embedded))

  # Perform finite difference method and power iteration.
  # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf,
  # Adding small noise to input and taking gradient with respect to the noise
  # corresponds to 1 power iteration.
  for _ in xrange(FLAGS.num_power_iteration):
    d = _scale_l2(
        _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(
        kl,
        d,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)
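
# Sketch of the virtual adversarial procedure above (pseudocode, one power
# iteration):
#
#   d  <- random noise, masked by length and scaled to
#         FLAGS.small_constant_for_finite_diff
#   kl <- KL(p(.|embedded) || p(.|embedded + d))
#   d  <- stop_gradient(grad_d kl)              # one power-iteration step
#   r  <- d rescaled to L2 norm FLAGS.perturb_norm_length
#   return KL(p(.|embedded) || p(.|embedded + r))
#
# Repeating the gradient step FLAGS.num_power_iteration times approximates the
# direction in embedding space that most increases the KL divergence.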


def random_perturbation_loss_bidir(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = [tf.random_normal(shape=tf.shape(emb)) for emb in embedded]
  masked = [_mask_by_length(n, length) for n in noise]
  scaled = [_scale_l2(m, FLAGS.perturb_norm_length) for m in masked]
  return loss_fn([e + s for (e, s) in zip(embedded, scaled)])


def adversarial_loss_bidir(embedded, loss, loss_fn):
  """Adds gradient to embeddings and recomputes classification loss."""
  grads = tf.gradients(
      loss,
      embedded,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  adv_exs = [
      emb + _scale_l2(tf.stop_gradient(g), FLAGS.perturb_norm_length)
      for emb, g in zip(embedded, grads)
  ]
  return loss_fn(adv_exs)


def virtual_adversarial_loss_bidir(logits, embedded, inputs,
                                   logits_from_embedding_fn):
  """Virtual adversarial loss for bidirectional models."""
  logits = tf.stop_gradient(logits)
  f_inputs, _ = inputs
  weights = f_inputs.eos_weights
  assert weights is not None

  perturbs = [
      _mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
      for emb in embedded
  ]
  for _ in xrange(FLAGS.num_power_iteration):
    perturbs = [
        _scale_l2(d, FLAGS.small_constant_for_finite_diff) for d in perturbs
    ]
    d_logits = logits_from_embedding_fn(
        [emb + d for (emb, d) in zip(embedded, perturbs)])
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    perturbs = tf.gradients(
        kl,
        perturbs,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    perturbs = [tf.stop_gradient(d) for d in perturbs]

  perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
  vadv_logits = logits_from_embedding_fn(
      [emb + d for (emb, d) in zip(embedded, perturbs)])
  return _kl_divergence_with_logits(logits, vadv_logits, weights)


def _mask_by_length(t, length):
  """Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
  maxlen = t.get_shape().as_list()[1]

  # Subtract 1 from length to prevent the perturbation from going on 'eos'
  mask = tf.sequence_mask(length - 1, maxlen=maxlen)
  mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
  # shape(mask) = (batch, num_timesteps, 1)
  return t * mask
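
# Example (illustrative): with length = [3] and num_timesteps = 4,
# tf.sequence_mask(length - 1, maxlen=4) is [[True, True, False, False]], so
# the perturbation is zeroed on the 'eos' position and any padding after it.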


def _scale_l2(x, norm_length):
  # shape(x) = (batch, num_timesteps, d)
  # Divide x by max(abs(x)) for a numerically stable L2 norm.
  # 2norm(x) = a * 2norm(x/a)
  # Scale over the full sequence, dims (1, 2)
  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
  l2_norm = alpha * tf.sqrt(
      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
  x_unit = x / l2_norm
  return norm_length * x_unit
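
# Example (illustrative): for a (1, 1, 2) tensor x with entries [3., 4.],
# alpha ~= 4, sqrt(sum((x / alpha)**2)) ~= 1.25, so l2_norm ~= 5 and
# _scale_l2(x, norm_length) ~= norm_length * [0.6, 0.8].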


def _kl_divergence_with_logits(q_logits, p_logits, weights):
  """Returns weighted KL divergence between distributions q and p.

  Args:
    q_logits: logits for 1st argument of KL divergence, with shape
              [num_timesteps * batch_size, num_classes] if num_classes > 2, and
              [num_timesteps * batch_size] if num_classes == 2.
    p_logits: logits for 2nd argument of KL divergence, with the same shape as
              q_logits.
    weights: 1-D float tensor with shape [num_timesteps * batch_size].
             Elements should be 1.0 only on end of sequences.

  Returns:
    KL: float scalar.
  """
  # For logistic regression
  if FLAGS.num_classes == 2:
    q = tf.nn.sigmoid(q_logits)
    kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
          tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
    kl = tf.squeeze(kl)

  # For softmax regression
  else:
    q = tf.nn.softmax(q_logits)
    kl = tf.reduce_sum(
        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)

  num_labels = tf.reduce_sum(weights)
  num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

  kl.get_shape().assert_has_rank(1)
  weights.get_shape().assert_has_rank(1)
  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
  return loss
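

# Example (illustrative, softmax branch): with q_logits = [[2., 0.]],
# p_logits = [[0., 0.]] and weights = [1.], q ~= [0.881, 0.119] and
# kl ~= 0.881 * (-0.127 + 0.693) + 0.119 * (-2.127 + 0.693) ~= 0.33, so the
# returned loss is ~0.33 (a weighted mean over end-of-sequence positions).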