Commit 9f0a567d authored by Andrew M. Dai

Fixes to multiclass training. Change DBpedia data generation to generate labels starting from 0.

PiperOrigin-RevId: 185783053
parent 9e30188e
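As a minimal sketch (not part of this commit, and assuming the DBpedia CSV stores its 14 classes as 1-based ids), the reason for the label shift below is that the multiclass loss, tf.nn.sparse_softmax_cross_entropy_with_logits, expects class labels in [0, num_classes):

```python
# Minimal sketch, not from the commit: illustrates why labels must be 0-based.
# Assumes the DBpedia CSV stores 1-based class ids (1..14).
import tensorflow as tf

num_classes = 14
raw_label = 14                       # example 1-based id as read from the CSV
label = raw_label - 1                # shift into [0, num_classes)

logits = tf.zeros([1, num_classes])  # stand-in logits for one example
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=tf.constant([label]))
```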
@@ -10,7 +10,7 @@ py_binary(
     deps = [
         ":graphs",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal dep,
     ],
 )
@@ -21,7 +21,7 @@ py_binary(
         ":graphs",
         ":train_utils",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal dep,
     ],
 )
@@ -34,7 +34,8 @@ py_binary(
         ":graphs",
         ":train_utils",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal gpu deps
+        # tensorflow internal dep,
     ],
 )
...
@@ -154,3 +154,4 @@ control which dataset is processed and how.
 ## Contact for Issues
 * Ryan Sepassi, @rsepassi
+* Andrew M. Dai, @a-dai
@@ -38,6 +38,8 @@ flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
 # Parameters for building the graph
 flags.DEFINE_string('adv_training_method', None,
                     'The flag which specifies training method. '
+                    '"" : non-adversarial training (e.g. for running the '
+                    ' semi-supervised sequence learning model) '
                     '"rp" : random perturbation training '
                     '"at" : adversarial training '
                     '"vat" : virtual adversarial training '
...
@@ -271,7 +271,7 @@ def build_labeled_sequence(seq, class_label, label_gain=False):
   Args:
     seq: SequenceWrapper.
-    class_label: bool.
+    class_label: integer, starting from 0.
     label_gain: bool. If True, class_label will be put on every timestep and
       weight will increase linearly from 0 to 1.
...
@@ -259,7 +259,7 @@ def dbpedia_documents(dataset='train',
         content=content,
         is_validation=is_validation,
         is_test=False,
-        label=int(row[0]),
+        label=int(row[0]) - 1,  # Labels should start from 0
         add_tokens=True)
...
@@ -20,7 +20,7 @@ from __future__ import print_function
 # Dependency imports
 import tensorflow as tf
-K = tf.contrib.keras
+K = tf.keras
 def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
@@ -148,6 +148,7 @@ class SoftmaxLoss(K.layers.Layer):
     self.num_candidate_samples = num_candidate_samples
     self.vocab_freqs = vocab_freqs
     super(SoftmaxLoss, self).__init__(**kwargs)
+    self.multiclass_dense_layer = K.layers.Dense(self.vocab_size)
   def build(self, input_shape):
     input_shape = input_shape[0]
@@ -160,6+161,7 @@ class SoftmaxLoss(K.layers.Layer):
         shape=(self.vocab_size,),
         name='lm_lin_b',
         initializer=K.initializers.glorot_uniform())
+    self.multiclass_dense_layer.build(input_shape)
     super(SoftmaxLoss, self).build(input_shape)
@@ -190,7 +192,7 @@ class SoftmaxLoss(K.layers.Layer):
           lm_loss,
           [int(x.get_shape()[0]), int(x.get_shape()[1])])
     else:
-      logits = tf.matmul(x, self.lin_w) + self.lin_b
+      logits = self.multiclass_dense_layer(x)
       lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           logits=logits, labels=labels)
@@ -255,7 +257,7 @@ def predictions(logits):
     pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
   # For multi-class classification
   else:
-    pred = tf.argmax(logits, 1)
+    pred = tf.argmax(logits, 2)
   return pred
@@ -354,10 +356,9 @@ def optimize(loss,
                          opt.ready_for_local_init_op)
   else:
     # Non-sync optimizer
-    variables_averages_op = variable_averages.apply(tvars)
     apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step)
-    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
-      train_op = tf.no_op(name='train_op')
+    with tf.control_dependencies([apply_gradient_op]):
+      train_op = variable_averages.apply(tvars)
   return train_op
...
@@ -64,8 +64,8 @@ def run_training(train_op,
   sv = tf.train.Supervisor(
       logdir=FLAGS.train_dir,
       is_chief=is_chief,
-      save_summaries_secs=5 * 60,
-      save_model_secs=5 * 60,
+      save_summaries_secs=30,
+      save_model_secs=30,
       local_init_op=local_init_op,
       ready_for_local_init_op=ready_for_local_init_op,
       global_step=global_step)
@@ -90,10 +90,9 @@ def run_training(train_op,
     global_step_val = 0
     while not sv.should_stop() and global_step_val < FLAGS.max_steps:
       global_step_val = train_step(sess, train_op, loss, global_step)
-    sv.stop()
     # Final checkpoint
-    if is_chief:
+    if is_chief and global_step_val >= FLAGS.max_steps:
       sv.saver.save(sess, sv.save_path, global_step=global_step)
...