"vscode:/vscode.git/clone" did not exist on "25ab939690032b71e0ea50f813d7e93a62263d58"
Unverified Commit af6527c9 authored by Andrew M Dai's avatar Andrew M Dai Committed by GitHub
Browse files

Merge pull request #3402 from a-dai/master

Merging in improvements and fixes to adversarial_text
parents f51da4bb 9a9e4228
/official/ @nealwu @k-w-w @karmel
/research/adversarial_crypto/ @dave-andersen
/research/adversarial_text/ @rsepassi
/research/adversarial_text/ @rsepassi @a-dai
/research/adv_imagenet_models/ @AlexeyKurakin
/research/attention_ocr/ @alexgorban
/research/audioset/ @plakal @dpwe
......
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# Binaries
# ==============================================================================
py_binary(
......@@ -8,7 +10,7 @@ py_binary(
deps = [
":graphs",
# google3 file dep,
# tensorflow dep,
# tensorflow internal dep,
],
)
......@@ -19,7 +21,7 @@ py_binary(
":graphs",
":train_utils",
# google3 file dep,
# tensorflow dep,
# tensorflow internal dep,
],
)
......@@ -32,7 +34,8 @@ py_binary(
":graphs",
":train_utils",
# google3 file dep,
# tensorflow dep,
# tensorflow internal gpu deps
# tensorflow internal dep,
],
)
......
......@@ -154,3 +154,4 @@ control which dataset is processed and how.
## Contact for Issues
* Ryan Sepassi, @rsepassi
* Andrew M. Dai, @a-dai
......@@ -16,7 +16,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
# Dependency imports
......@@ -39,6 +38,8 @@ flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
# Parameters for building the graph
flags.DEFINE_string('adv_training_method', None,
'The flag which specifies training method. '
'"" : non-adversarial training (e.g. for running the '
' semi-supervised sequence learning model) '
'"rp" : random perturbation training '
'"at" : adversarial training '
'"vat" : virtual adversarial training '
......@@ -74,7 +75,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
between the new logits and the original logits.
Args:
logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
num_classes=2, otherwise m=num_classes.
embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
inputs: VatxtInput.
......@@ -90,6 +91,9 @@ def virtual_adversarial_loss(logits, embedded, inputs,
# Only care about the KL divergence on the final timestep.
weights = inputs.eos_weights
assert weights is not None
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)
# Initialize perturbation with random noise.
# shape(embedded) = (batch_size, num_timesteps, embedding_dim)
......@@ -102,6 +106,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
for _ in xrange(FLAGS.num_power_iteration):
d = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
d_logits = logits_from_embedding_fn(embedded + d)
kl = _kl_divergence_with_logits(logits, d_logits, weights)
d, = tf.gradients(
......@@ -142,6 +147,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
logits = tf.stop_gradient(logits)
f_inputs, _ = inputs
weights = f_inputs.eos_weights
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
assert weights is not None
perturbs = [
......@@ -195,10 +203,10 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
Args:
q_logits: logits for 1st argument of KL divergence shape
[num_timesteps * batch_size, num_classes] if num_classes > 2, and
[num_timesteps * batch_size] if num_classes == 2.
[batch_size, num_timesteps, num_classes] if num_classes > 2, and
[batch_size, num_timesteps] if num_classes == 2.
p_logits: logits for 2nd argument of KL divergence with same shape q_logits.
weights: 1-D float tensor with shape [num_timesteps * batch_size].
weights: 1-D float tensor with shape [batch_size, num_timesteps].
Elements should be 1.0 only on end of sequences
Returns:
......@@ -209,18 +217,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
q = tf.nn.sigmoid(q_logits)
kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
kl = tf.squeeze(kl)
kl = tf.squeeze(kl, 2)
# For softmax regression
else:
q = tf.nn.softmax(q_logits)
kl = tf.reduce_sum(
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
kl.get_shape().assert_has_rank(1)
weights.get_shape().assert_has_rank(1)
kl.get_shape().assert_has_rank(2)
weights.get_shape().assert_has_rank(2)
loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
return loss
......@@ -271,7 +271,7 @@ def build_labeled_sequence(seq, class_label, label_gain=False):
Args:
seq: SequenceWrapper.
class_label: bool.
class_label: integer, starting from 0.
label_gain: bool. If True, class_label will be put on every timestep and
weight will increase linearly from 0 to 1.
......
......@@ -259,7 +259,7 @@ def dbpedia_documents(dataset='train',
content=content,
is_validation=is_validation,
is_test=False,
label=int(row[0]),
label=int(row[0]) - 1, # Labels should start from 0
add_tokens=True)
......
......@@ -25,7 +25,7 @@ import time
import tensorflow as tf
import graphs
from adversarial_text import graphs
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -75,7 +75,8 @@ def run_eval(eval_ops, summary_writer, saver):
Returns:
dict<metric name, value>, with value being the average over all examples.
"""
sv = tf.train.Supervisor(logdir=FLAGS.eval_dir, saver=None, summary_op=None)
sv = tf.train.Supervisor(
logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
with sv.managed_session(
master=FLAGS.master, start_standard_services=False) as sess:
if not restore_from_checkpoint(sess, saver):
......@@ -113,6 +114,7 @@ def _log_values(sess, value_ops, summary_writer=None):
if summary_writer is not None:
global_step_val = sess.run(tf.train.get_global_step())
tf.logging.info('Finished eval for step ' + str(global_step_val))
summary_writer.add_summary(summary, global_step_val)
......
......@@ -24,9 +24,9 @@ import os
import tensorflow as tf
import adversarial_losses as adv_lib
import inputs as inputs_lib
import layers as layers_lib
from adversarial_text import adversarial_losses as adv_lib
from adversarial_text import inputs as inputs_lib
from adversarial_text import layers as layers_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_timesteps', 100, 'Number of timesteps for BPTT')
# Model architechture
flags.DEFINE_bool('bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
flags.DEFINE_bool('single_label', True, 'Whether the sequence has a single '
'label, for optimization.')
flags.DEFINE_integer('rnn_num_layers', 1, 'Number of LSTM layers.')
flags.DEFINE_integer('rnn_cell_size', 512,
'Number of hidden units in the LSTM.')
......@@ -181,7 +183,14 @@ class VatxtModel(object):
self.tensors['cl_logits'] = logits
self.tensors['cl_loss'] = loss
acc = layers_lib.accuracy(logits, inputs.labels, inputs.weights)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
else:
labels = inputs.labels
weights = inputs.weights
acc = layers_lib.accuracy(logits, labels, weights)
tf.summary.scalar('accuracy', acc)
adv_loss = (self.adversarial_loss() * tf.constant(
......@@ -189,11 +198,10 @@ class VatxtModel(object):
tf.summary.scalar('adversarial_loss', adv_loss)
total_loss = loss + adv_loss
tf.summary.scalar('total_classification_loss', total_loss)
with tf.control_dependencies([inputs.save_state(next_state)]):
total_loss = tf.identity(total_loss)
tf.summary.scalar('total_classification_loss', total_loss)
return total_loss
def language_model_graph(self, compute_loss=True):
......@@ -249,10 +257,17 @@ class VatxtModel(object):
_, next_state, logits, _ = self.cl_loss_from_embedding(
embedded, inputs=inputs, return_intermediates=True)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
else:
labels = inputs.labels
weights = inputs.weights
eval_ops = {
'accuracy':
tf.contrib.metrics.streaming_accuracy(
layers_lib.predictions(logits), inputs.labels, inputs.weights)
layers_lib.predictions(logits), labels, weights)
}
with tf.control_dependencies([inputs.save_state(next_state)]):
......@@ -286,8 +301,16 @@ class VatxtModel(object):
lstm_out, next_state = self.layers['lstm'](embedded, inputs.state,
inputs.length)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
lstm_out = tf.expand_dims(tf.gather_nd(lstm_out, indices), 1)
labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
else:
labels = inputs.labels
weights = inputs.weights
logits = self.layers['cl_logits'](lstm_out)
loss = layers_lib.classification_loss(logits, inputs.labels, inputs.weights)
loss = layers_lib.classification_loss(logits, labels, weights)
if return_intermediates:
return lstm_out, next_state, logits, loss
......@@ -419,12 +442,12 @@ class VatxtBidirModel(VatxtModel):
tf.summary.scalar('adversarial_loss', adv_loss)
total_loss = loss + adv_loss
tf.summary.scalar('total_classification_loss', total_loss)
saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
with tf.control_dependencies(saves):
total_loss = tf.identity(total_loss)
tf.summary.scalar('total_classification_loss', total_loss)
return total_loss
def language_model_graph(self, compute_loss=True):
......
......@@ -29,7 +29,7 @@ import tempfile
import tensorflow as tf
import graphs
from adversarial_text import graphs
from adversarial_text.data import data_utils
flags = tf.app.flags
......
......@@ -51,27 +51,16 @@ class VatxtInput(object):
batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
self._num_states = num_states
# Once the tokens have passed through embedding and LSTM, the output Tensor
# shapes will be time-major, i.e. shape = (time, batch, dim). Here we make
# both weights and labels time-major with a transpose, and then merge the
# time and batch dimensions such that they are both vectors of shape
# (time*batch).
w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
w = tf.transpose(w, [1, 0])
w = tf.reshape(w, [-1])
self._weights = w
l = batch.sequences[data_utils.SequenceWrapper.F_LABEL]
l = tf.transpose(l, [1, 0])
l = tf.reshape(l, [-1])
self._labels = l
# eos weights
self._eos_weights = None
if eos_id:
ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
ew = tf.transpose(ew, [1, 0])
ew = tf.reshape(ew, [-1])
self._eos_weights = ew
@property
......
......@@ -16,12 +16,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
# Dependency imports
import tensorflow as tf
K = tf.contrib.keras
K = tf.keras
def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
......@@ -96,7 +95,7 @@ class Embedding(K.layers.Layer):
class LSTM(object):
"""LSTM layer using static_rnn.
"""LSTM layer using dynamic_rnn.
Exposes variables in `trainable_weights` property.
"""
......@@ -120,16 +119,11 @@ class LSTM(object):
])
# shape(x) = (batch_size, num_timesteps, embedding_dim)
# Convert into a time-major list for static_rnn
x = tf.unstack(tf.transpose(x, perm=[1, 0, 2]))
lstm_out, next_state = tf.contrib.rnn.static_rnn(
lstm_out, next_state = tf.nn.dynamic_rnn(
cell, x, initial_state=initial_state, sequence_length=seq_length)
# Merge time and batch dimensions
# shape(lstm_out) = timesteps * (batch_size, cell_size)
lstm_out = tf.concat(lstm_out, 0)
# shape(lstm_out) = (timesteps*batch_size, cell_size)
# shape(lstm_out) = (batch_size, timesteps, cell_size)
if self.keep_prob < 1.:
lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)
......@@ -154,6 +148,7 @@ class SoftmaxLoss(K.layers.Layer):
self.num_candidate_samples = num_candidate_samples
self.vocab_freqs = vocab_freqs
super(SoftmaxLoss, self).__init__(**kwargs)
self.multiclass_dense_layer = K.layers.Dense(self.vocab_size)
def build(self, input_shape):
input_shape = input_shape[0]
......@@ -166,6 +161,7 @@ class SoftmaxLoss(K.layers.Layer):
shape=(self.vocab_size,),
name='lm_lin_b',
initializer=K.initializers.glorot_uniform())
self.multiclass_dense_layer.build(input_shape)
super(SoftmaxLoss, self).build(input_shape)
......@@ -173,25 +169,30 @@ class SoftmaxLoss(K.layers.Layer):
x, labels, weights = inputs
if self.num_candidate_samples > -1:
assert self.vocab_freqs is not None
labels = tf.expand_dims(labels, -1)
labels_reshaped = tf.reshape(labels, [-1])
labels_reshaped = tf.expand_dims(labels_reshaped, -1)
sampled = tf.nn.fixed_unigram_candidate_sampler(
true_classes=labels,
true_classes=labels_reshaped,
num_true=1,
num_sampled=self.num_candidate_samples,
unique=True,
range_max=self.vocab_size,
unigrams=self.vocab_freqs)
inputs_reshaped = tf.reshape(x, [-1, int(x.get_shape()[2])])
lm_loss = tf.nn.sampled_softmax_loss(
weights=tf.transpose(self.lin_w),
biases=self.lin_b,
labels=labels,
inputs=x,
labels=labels_reshaped,
inputs=inputs_reshaped,
num_sampled=self.num_candidate_samples,
num_classes=self.vocab_size,
sampled_values=sampled)
lm_loss = tf.reshape(
lm_loss,
[int(x.get_shape()[0]), int(x.get_shape()[1])])
else:
logits = tf.matmul(x, self.lin_w) + self.lin_b
logits = self.multiclass_dense_layer(x)
lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=labels)
......@@ -218,7 +219,7 @@ def classification_loss(logits, labels, weights):
# Logistic loss
if inner_dim == 1:
loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32))
logits=tf.squeeze(logits, -1), labels=tf.cast(labels, tf.float32))
# Softmax loss
else:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
......@@ -253,10 +254,10 @@ def predictions(logits):
with tf.name_scope('predictions'):
# For binary classification
if inner_dim == 1:
pred = tf.cast(tf.greater(tf.squeeze(logits), 0.5), tf.int64)
pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
# For multi-class classification
else:
pred = tf.argmax(logits, 1)
pred = tf.argmax(logits, 2)
return pred
......@@ -355,10 +356,9 @@ def optimize(loss,
opt.ready_for_local_init_op)
else:
# Non-sync optimizer
variables_averages_op = variable_averages.apply(tvars)
apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step)
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train_op')
with tf.control_dependencies([apply_gradient_op]):
train_op = variable_averages.apply(tvars)
return train_op
......
......@@ -27,8 +27,8 @@ from __future__ import print_function
import tensorflow as tf
import graphs
import train_utils
from adversarial_text import graphs
from adversarial_text import train_utils
FLAGS = tf.app.flags.FLAGS
......
......@@ -35,8 +35,8 @@ from __future__ import print_function
import tensorflow as tf
import graphs
import train_utils
from adversarial_text import graphs
from adversarial_text import train_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -64,8 +64,8 @@ def run_training(train_op,
sv = tf.train.Supervisor(
logdir=FLAGS.train_dir,
is_chief=is_chief,
save_summaries_secs=5 * 60,
save_model_secs=5 * 60,
save_summaries_secs=30,
save_model_secs=30,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op,
global_step=global_step)
......@@ -90,10 +90,9 @@ def run_training(train_op,
global_step_val = 0
while not sv.should_stop() and global_step_val < FLAGS.max_steps:
global_step_val = train_step(sess, train_op, loss, global_step)
sv.stop()
# Final checkpoint
if is_chief:
if is_chief and global_step_val >= FLAGS.max_steps:
sv.saver.save(sess, sv.save_path, global_step=global_step)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment