Commit e8dd2bf3 authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Fix adversarial training with recent shape changes.

PiperOrigin-RevId: 173414999
parent 7d921c12
......@@ -73,7 +73,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
between the new logits and the original logits.
Args:
logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
num_classes=2, otherwise m=num_classes.
embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
inputs: VatxtInput.
......@@ -89,6 +89,9 @@ def virtual_adversarial_loss(logits, embedded, inputs,
# Only care about the KL divergence on the final timestep.
weights = inputs.eos_weights
assert weights is not None
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)
# Initialize perturbation with random noise.
# shape(embedded) = (batch_size, num_timesteps, embedding_dim)
......@@ -101,6 +104,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
for _ in xrange(FLAGS.num_power_iteration):
d = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
d_logits = logits_from_embedding_fn(embedded + d)
kl = _kl_divergence_with_logits(logits, d_logits, weights)
d, = tf.gradients(
......@@ -141,6 +145,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
logits = tf.stop_gradient(logits)
f_inputs, _ = inputs
weights = f_inputs.eos_weights
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
assert weights is not None
perturbs = [
......@@ -194,10 +201,10 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
Args:
q_logits: logits for 1st argument of KL divergence shape
[num_timesteps * batch_size, num_classes] if num_classes > 2, and
[num_timesteps * batch_size] if num_classes == 2.
[batch_size, num_timesteps, num_classes] if num_classes > 2, and
[batch_size, num_timesteps] if num_classes == 2.
p_logits: logits for 2nd argument of KL divergence with same shape q_logits.
weights: 1-D float tensor with shape [num_timesteps * batch_size].
weights: 1-D float tensor with shape [batch_size, num_timesteps].
Elements should be 1.0 only on end of sequences
Returns:
......@@ -208,18 +215,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
q = tf.nn.sigmoid(q_logits)
kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
kl = tf.squeeze(kl)
kl = tf.squeeze(kl, 2)
# For softmax regression
else:
q = tf.nn.softmax(q_logits)
kl = tf.reduce_sum(
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
kl.get_shape().assert_has_rank(1)
weights.get_shape().assert_has_rank(1)
kl.get_shape().assert_has_rank(2)
weights.get_shape().assert_has_rank(2)
loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
return loss
......@@ -185,8 +185,8 @@ class VatxtModel(object):
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
labels = tf.gather_nd(inputs.labels, indices)
weights = tf.gather_nd(inputs.weights, indices)
labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
else:
labels = inputs.labels
weights = inputs.weights
......@@ -259,8 +259,8 @@ class VatxtModel(object):
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
labels = tf.gather_nd(inputs.labels, indices)
weights = tf.gather_nd(inputs.weights, indices)
labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
else:
labels = inputs.labels
weights = inputs.weights
......@@ -303,9 +303,9 @@ class VatxtModel(object):
inputs.length)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
lstm_out = tf.gather_nd(lstm_out, indices)
labels = tf.gather_nd(inputs.labels, indices)
weights = tf.gather_nd(inputs.weights, indices)
lstm_out = tf.expand_dims(tf.gather_nd(lstm_out, indices), 1)
labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
else:
labels = inputs.labels
weights = inputs.weights
......
......@@ -217,7 +217,7 @@ def classification_loss(logits, labels, weights):
# Logistic loss
if inner_dim == 1:
loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32))
logits=tf.squeeze(logits, -1), labels=tf.cast(labels, tf.float32))
# Softmax loss
else:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
......@@ -252,7 +252,7 @@ def predictions(logits):
with tf.name_scope('predictions'):
# For binary classification
if inner_dim == 1:
pred = tf.cast(tf.greater(tf.squeeze(logits), 0.5), tf.int64)
pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
# For multi-class classification
else:
pred = tf.argmax(logits, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment