Commit e8dd2bf3 authored by Andrew M. Dai

Fix adversarial training with recent shape changes.

PiperOrigin-RevId: 173414999
parent 7d921c12
@@ -73,7 +73,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
     between the new logits and the original logits.

   Args:
-    logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
+    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
       num_classes=2, otherwise m=num_classes.
     embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
     inputs: VatxtInput.
@@ -89,6 +89,9 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   # Only care about the KL divergence on the final timestep.
   weights = inputs.eos_weights
   assert weights is not None
+  if FLAGS.single_label:
+    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
+    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)

   # Initialize perturbation with random noise.
   # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
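The added single_label branch uses a common indexing pattern: pair each batch row with its final valid timestep, gather those entries, and put the time axis back. A minimal sketch of the mechanics (the tensor values here are illustrative, not from the model):

import tensorflow as tf

# Illustrative stand-ins for inputs.eos_weights ([batch, time]) and
# inputs.length (the true length of each example).
eos_weights = tf.constant([[0., 0., 1.],
                           [0., 1., 0.]])
length = tf.constant([3, 2])

# [[0, 2], [1, 1]]: one (row, final timestep) pair per example.
indices = tf.stack([tf.range(2), length - 1], 1)

# gather_nd drops the indexed axes, returning shape [batch] ...
final_weights = tf.gather_nd(eos_weights, indices)

# ... and expand_dims restores a time axis of size 1, shape [batch, 1],
# matching the rank-2 weights the KL loss below now expects.
weights = tf.expand_dims(final_weights, 1)

with tf.Session() as sess:
  print(sess.run(weights))  # [[1.], [1.]]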
@@ -101,6 +104,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   for _ in xrange(FLAGS.num_power_iteration):
     d = _scale_l2(
         _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
     d_logits = logits_from_embedding_fn(embedded + d)
     kl = _kl_divergence_with_logits(logits, d_logits, weights)
     d, = tf.gradients(
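For context, this loop is the VAT power iteration: the perturbation d is rescaled to a small L2 norm, applied to the embeddings, and replaced by the gradient of the resulting KL divergence. The body of _scale_l2 is outside this diff; a plausible sketch of what it does, stated as an assumption:

import tensorflow as tf

def scale_l2_sketch(x, norm_length):
  # Assumed behavior of _scale_l2: rescale each example's perturbation
  # x ([batch, time, dim]) to have L2 norm equal to norm_length.
  # Dividing by alpha first keeps the sum of squares from overflowing.
  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
  l2_norm = alpha * tf.sqrt(
      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
  return norm_length * x / l2_norm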
@@ -141,6 +145,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
   logits = tf.stop_gradient(logits)
   f_inputs, _ = inputs
   weights = f_inputs.eos_weights
+  if FLAGS.single_label:
+    indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
+    weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
   assert weights is not None

   perturbs = [
@@ -194,10 +201,10 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):

   Args:
     q_logits: logits for 1st argument of KL divergence shape
-      [num_timesteps * batch_size, num_classes] if num_classes > 2, and
-      [num_timesteps * batch_size] if num_classes == 2.
+      [batch_size, num_timesteps, num_classes] if num_classes > 2, and
+      [batch_size, num_timesteps] if num_classes == 2.
     p_logits: logits for 2nd argument of KL divergence with same shape q_logits.
-    weights: 1-D float tensor with shape [num_timesteps * batch_size].
+    weights: 2-D float tensor with shape [batch_size, num_timesteps].
       Elements should be 1.0 only on end of sequences

   Returns:
@@ -208,18 +215,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
     q = tf.nn.sigmoid(q_logits)
     kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
           tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
-    kl = tf.squeeze(kl)
+    kl = tf.squeeze(kl, 2)
   # For softmax regression
   else:
     q = tf.nn.softmax(q_logits)
     kl = tf.reduce_sum(
-        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
+        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)

   num_labels = tf.reduce_sum(weights)
   num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

-  kl.get_shape().assert_has_rank(1)
-  weights.get_shape().assert_has_rank(1)
+  kl.get_shape().assert_has_rank(2)
+  weights.get_shape().assert_has_rank(2)
   loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
   return loss
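The binary branch computes Bernoulli KL through cross entropies: KL(q||p) = H(q, p) - H(q, q), which is exactly the difference of the two sigmoid_cross_entropy_with_logits terms kept above. A quick scalar check of that identity (plain NumPy, values arbitrary):

import numpy as np

q_logit, p_logit = 0.7, -0.3
q = 1.0 / (1.0 + np.exp(-q_logit))  # sigmoid(q_logit)
p = 1.0 / (1.0 + np.exp(-p_logit))

# Direct Bernoulli KL(q || p).
kl_direct = q * np.log(q / p) + (1. - q) * np.log((1. - q) / (1. - p))

def sigmoid_xent(logit, label):
  # Numerically stable sigmoid cross entropy, the same formula
  # tf.nn.sigmoid_cross_entropy_with_logits uses.
  return max(logit, 0.) - logit * label + np.log1p(np.exp(-abs(logit)))

# H(q, p) - H(q, q), mirroring the code above.
kl_from_logits = -sigmoid_xent(q_logit, q) + sigmoid_xent(p_logit, q)

assert np.isclose(kl_direct, kl_from_logits)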
@@ -185,8 +185,8 @@ class VatxtModel(object):
     if FLAGS.single_label:
       indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
-      labels = tf.gather_nd(inputs.labels, indices)
-      weights = tf.gather_nd(inputs.weights, indices)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
     else:
       labels = inputs.labels
       weights = inputs.weights
@@ -259,8 +259,8 @@ class VatxtModel(object):
     if FLAGS.single_label:
       indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
-      labels = tf.gather_nd(inputs.labels, indices)
-      weights = tf.gather_nd(inputs.weights, indices)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
     else:
       labels = inputs.labels
       weights = inputs.weights
@@ -303,9 +303,9 @@ class VatxtModel(object):
                                      inputs.length)
     if FLAGS.single_label:
       indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
-      lstm_out = tf.gather_nd(lstm_out, indices)
-      labels = tf.gather_nd(inputs.labels, indices)
-      weights = tf.gather_nd(inputs.weights, indices)
+      lstm_out = tf.expand_dims(tf.gather_nd(lstm_out, indices), 1)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
     else:
       labels = inputs.labels
       weights = inputs.weights
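All three model hunks fix the same problem: tf.gather_nd removes the time axis, so in single-label mode these tensors came back one rank short of what the downstream losses and the new rank-2 asserts expect. tf.expand_dims(..., 1) turns the result back into a length-1 sequence. A shape-only sketch (sizes are arbitrary):

import tensorflow as tf

batch_size, num_timesteps, dim = 4, 10, 8
lstm_out = tf.zeros([batch_size, num_timesteps, dim])
length = tf.fill([batch_size], num_timesteps)

indices = tf.stack([tf.range(batch_size), length - 1], 1)

gathered = tf.gather_nd(lstm_out, indices)
print(gathered.get_shape())    # (4, 8): the time axis is gone

restored = tf.expand_dims(gathered, 1)
print(restored.get_shape())    # (4, 1, 8): a sequence of length 1 again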
@@ -217,7 +217,7 @@ def classification_loss(logits, labels, weights):
   # Logistic loss
   if inner_dim == 1:
     loss = tf.nn.sigmoid_cross_entropy_with_logits(
-        logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32))
+        logits=tf.squeeze(logits, -1), labels=tf.cast(labels, tf.float32))
   # Softmax loss
   else:
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -252,7 +252,7 @@ def predictions(logits):
   with tf.name_scope('predictions'):
     # For binary classification
     if inner_dim == 1:
-      pred = tf.cast(tf.greater(tf.squeeze(logits), 0.5), tf.int64)
+      pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
     # For multi-class classification
     else:
       pred = tf.argmax(logits, 1)
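Giving tf.squeeze an explicit axis is the safer call here: with no axis it removes every size-1 dimension, so the new [batch_size, 1] single-label logits (or any batch of 1) could lose more dimensions than intended. A minimal illustration:

import tensorflow as tf

logits = tf.zeros([1, 1])  # e.g. batch_size=1 with inner_dim=1

print(tf.squeeze(logits).get_shape())      # (): both size-1 axes removed
print(tf.squeeze(logits, -1).get_shape())  # (1,): only the inner axis removed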