Commit 0d961be2 authored by Lukasz Kaiser, committed by GitHub

Merge pull request #1439 from rsepassi/adv_text

Updates to adversarial_text model
parents afe2e68b a97304d5
...
@@ -83,8 +83,10 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   """
   # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
   logits = tf.stop_gradient(logits)

+  # Only care about the KL divergence on the final timestep.
   weights = _end_of_seq_mask(inputs.labels)

+  # Initialize perturbation with random noise.
   # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
   d = _mask_by_length(tf.random_normal(shape=tf.shape(embedded)), inputs.length)
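Note (not part of the diff): a rough sketch of how the randomly initialized perturbation `d` is typically refined by power iteration in virtual adversarial training, per the paper cited above. The names `logits_from_embedding_fn`, `num_power_iterations`, and `small_constant_for_finite_diff` are assumptions for illustration, not identifiers from this commit.

```python
# Hypothetical VAT power-iteration step (assumed names, not this repo's exact code).
for _ in range(num_power_iterations):
  d = _scale_l2(d, small_constant_for_finite_diff)    # keep the perturbation tiny
  d_logits = logits_from_embedding_fn(embedded + d)   # re-run the model on perturbed embeddings
  kl = _kl_divergence_with_logits(logits, d_logits, weights)
  d, = tf.gradients(kl, d)                            # direction that most increases the KL
  d = tf.stop_gradient(d)
```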
...
@@ -173,11 +175,15 @@ def _mask_by_length(t, length):
 def _scale_l2(x, norm_length):
   # shape(x) = (batch, num_timesteps, d)
-  x /= (1e-12 + tf.reduce_max(tf.abs(x), 2, keep_dims=True))
-  x_2 = tf.reduce_sum(tf.pow(x, 2), 2, keep_dims=True)
-  x /= tf.sqrt(1e-6 + x_2)
-  return norm_length * x
+  # Divide x by max(abs(x)) for a numerically stable L2 norm.
+  # 2norm(x) = a * 2norm(x/a)
+  # Scale over the full sequence, dims (1, 2)
+  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
+  l2_norm = alpha * tf.sqrt(tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2),
+                                          keep_dims=True) + 1e-6)
+  x_unit = x / l2_norm
+  return norm_length * x_unit


 def _end_of_seq_mask(tokens):
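A standalone NumPy restatement of the rewritten `_scale_l2`, only to illustrate why dividing by `max(abs(x))` before squaring keeps the L2 norm finite for large inputs, using the identity `2norm(x) = a * 2norm(x/a)`. This is a sketch, not the module's code.

```python
import numpy as np

def scale_l2(x, norm_length):
  # shape(x) = (batch, num_timesteps, d); scale over the full sequence, dims (1, 2).
  alpha = np.max(np.abs(x), axis=(1, 2), keepdims=True) + 1e-12
  l2_norm = alpha * np.sqrt(
      np.sum((x / alpha) ** 2, axis=(1, 2), keepdims=True) + 1e-6)
  return norm_length * x / l2_norm

# Entries this large overflow to inf when squared directly in float32,
# while (x / alpha) ** 2 stays well-behaved.
x = (np.random.randn(2, 5, 3) * 1e20).astype(np.float32)
scaled = scale_l2(x, norm_length=5.0)
print(np.linalg.norm(scaled.reshape(2, -1), axis=1))  # ~5.0 per batch example
```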
...
@@ -225,5 +231,8 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
   num_labels = tf.reduce_sum(weights)
   num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
-  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
+  kl.get_shape().assert_has_rank(2)
+  weights.get_shape().assert_has_rank(1)
+  loss = tf.identity(tf.reduce_sum(tf.expand_dims(weights, -1) * kl) /
+                     num_labels, name='kl')
   return loss
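A small NumPy illustration (not from the commit) of why the rank asserts and `tf.expand_dims(weights, -1)` are needed: `kl` is rank 2 and `weights` is rank 1, so the weights must gain a trailing axis to broadcast before the weighted average.

```python
import numpy as np

kl = np.array([[0.2], [0.4], [0.6]])   # rank 2, one KL value per sequence
weights = np.array([0., 1., 1.])       # rank 1, which sequences count
num_labels = max(weights.sum(), 1.0)   # guard against dividing by zero
loss = np.sum(weights[:, None] * kl) / num_labels
print(loss)                            # (0.4 + 0.6) / 2 = 0.5
```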
...
@@ -84,28 +84,35 @@ def run_eval(eval_ops, summary_writer, saver):
   metric_names, ops = zip(*eval_ops.items())
   value_ops, update_ops = zip(*ops)

+  value_ops_dict = dict(zip(metric_names, value_ops))
+
   # Run update ops
   num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
   tf.logging.info('Running %d batches for evaluation.', num_batches)
   for i in range(num_batches):
     if (i + 1) % 10 == 0:
       tf.logging.info('Running batch %d/%d...', i + 1, num_batches)
+    if (i + 1) % 50 == 0:
+      _log_values(sess, value_ops_dict)
     sess.run(update_ops)

-  values = sess.run(value_ops)
-  metric_values = dict(zip(metric_names, values))
+  _log_values(sess, value_ops_dict, summary_writer=summary_writer)
+
+
+def _log_values(sess, value_ops, summary_writer=None):
+  metric_names, value_ops = zip(*value_ops.items())
+  values = sess.run(value_ops)

   tf.logging.info('Eval metric values:')
   summary = tf.summary.Summary()
-  for name, val in metric_values.items():
+  for name, val in zip(metric_names, values):
     summary.value.add(tag=name, simple_value=val)
     tf.logging.info('%s = %.3f', name, val)

-  global_step_val = sess.run(tf.train.get_global_step())
-  summary_writer.add_summary(summary, global_step_val)
-
-  return metric_values
+  if summary_writer is not None:
+    global_step_val = sess.run(tf.train.get_global_step())
+    summary_writer.add_summary(summary, global_step_val)


 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
...
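For context, a toy example (made-up labels and predictions, not from this commit) of the `(value_op, update_op)` pairs that `run_eval` unpacks and that the new `_log_values` helper evaluates every 50 batches and once at the end:

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])
# Each tf.metrics.* call returns a (value_op, update_op) pair, the structure
# that run_eval unpacks with zip(*eval_ops.items()).
eval_ops = {
    'accuracy': tf.metrics.accuracy(labels, predictions),
    'recall': tf.metrics.recall(labels, predictions),
}
metric_names, ops = zip(*eval_ops.items())
value_ops, update_ops = zip(*ops)
with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric accumulators are local variables
  sess.run(update_ops)
  print(dict(zip(metric_names, sess.run(value_ops))))
```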
...
@@ -81,11 +81,10 @@ class Embedding(K.layers.Layer):
   def _normalize(self, emb):
     weights = self.vocab_freqs / tf.reduce_sum(self.vocab_freqs)
-    emb -= tf.reduce_sum(weights * emb, 0, keep_dims=True)
-    emb /= tf.sqrt(1e-6 + tf.reduce_sum(
-        weights * tf.pow(emb, 2.), 0, keep_dims=True))
-    return emb
+    mean = tf.reduce_sum(weights * emb, 0, keep_dims=True)
+    var = tf.reduce_sum(weights * tf.pow(emb - mean, 2.), 0, keep_dims=True)
+    stddev = tf.sqrt(1e-6 + var)
+    return (emb - mean) / stddev


 class LSTM(object):
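The rewrite computes the same quantity as the old in-place version (center first, then divide by the weighted standard deviation of the centered values) but names the intermediates. A NumPy sketch with illustrative values, not repository code:

```python
import numpy as np

vocab_freqs = np.array([[10.], [5.], [1.]])    # made-up counts, shape (vocab_size, 1)
emb = np.random.randn(3, 4)                    # (vocab_size, embedding_dim)

weights = vocab_freqs / vocab_freqs.sum()
mean = np.sum(weights * emb, axis=0, keepdims=True)
var = np.sum(weights * (emb - mean) ** 2, axis=0, keepdims=True)
normalized = (emb - mean) / np.sqrt(1e-6 + var)

# Frequency-weighted mean ~0 and variance ~1 per embedding dimension.
print(np.sum(weights * normalized, axis=0))
print(np.sum(weights * normalized ** 2, axis=0))
```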
...
@@ -201,7 +200,7 @@ def classification_loss(logits, labels, weights):
     logits: 2-D [timesteps*batch_size, m] float tensor, where m=1 if
       num_classes=2, otherwise m=num_classes.
     labels: 1-D [timesteps*batch_size] integer tensor.
-    weights: 2-D [timesteps*batch_size] float tensor.
+    weights: 1-D [timesteps*batch_size] float tensor.
   Returns:
     Loss scalar of type float.
...
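The docstring fix matters because the loss weights line up one-to-one with the 1-D labels. A minimal, self-contained example of 1-D per-position weights with a sparse softmax loss (illustrative values, and not necessarily how `classification_loss` is implemented):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5], [0.2, 1.7], [1.0, 1.0]])  # [timesteps*batch_size, m]
labels = tf.constant([0, 1, 1])                             # 1-D integer labels
weights = tf.constant([1.0, 1.0, 0.0])                      # 1-D per-position weights
loss = tf.losses.sparse_softmax_cross_entropy(
    labels=labels, logits=logits, weights=weights)
with tf.Session() as sess:
  print(sess.run(loss))  # average cross entropy over positions with nonzero weight
```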