"examples/vscode:/vscode.git/clone" did not exist on "3fca52022fe0ea9aaf0a0ea8a0fc13308bf69a9f"
Commit 7d921c12 authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Add option for when there is only one label for classification, for speed....

Add option for when there is only one label for classification, for speed. (The case for the standard datasets).

PiperOrigin-RevId: 172774825
parent fda7e6dc
......@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_timesteps', 100, 'Number of timesteps for BPTT')
# Model architechture
flags.DEFINE_bool('bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
flags.DEFINE_bool('single_label', True, 'Whether the sequence has a single '
'label, for optimization.')
flags.DEFINE_integer('rnn_num_layers', 1, 'Number of LSTM layers.')
flags.DEFINE_integer('rnn_cell_size', 512,
'Number of hidden units in the LSTM.')
......@@ -181,7 +183,14 @@ class VatxtModel(object):
self.tensors['cl_logits'] = logits
self.tensors['cl_loss'] = loss
acc = layers_lib.accuracy(logits, inputs.labels, inputs.weights)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
labels = tf.gather_nd(inputs.labels, indices)
weights = tf.gather_nd(inputs.weights, indices)
else:
labels = inputs.labels
weights = inputs.weights
acc = layers_lib.accuracy(logits, labels, weights)
tf.summary.scalar('accuracy', acc)
adv_loss = (self.adversarial_loss() * tf.constant(
......@@ -248,10 +257,17 @@ class VatxtModel(object):
_, next_state, logits, _ = self.cl_loss_from_embedding(
embedded, inputs=inputs, return_intermediates=True)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
labels = tf.gather_nd(inputs.labels, indices)
weights = tf.gather_nd(inputs.weights, indices)
else:
labels = inputs.labels
weights = inputs.weights
eval_ops = {
'accuracy':
tf.contrib.metrics.streaming_accuracy(
layers_lib.predictions(logits), inputs.labels, inputs.weights)
layers_lib.predictions(logits), labels, weights)
}
with tf.control_dependencies([inputs.save_state(next_state)]):
......@@ -285,8 +301,16 @@ class VatxtModel(object):
lstm_out, next_state = self.layers['lstm'](embedded, inputs.state,
inputs.length)
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
lstm_out = tf.gather_nd(lstm_out, indices)
labels = tf.gather_nd(inputs.labels, indices)
weights = tf.gather_nd(inputs.weights, indices)
else:
labels = inputs.labels
weights = inputs.weights
logits = self.layers['cl_logits'](lstm_out)
loss = layers_lib.classification_loss(logits, inputs.labels, inputs.weights)
loss = layers_lib.classification_loss(logits, labels, weights)
if return_intermediates:
return lstm_out, next_state, logits, loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment