Commit 9f0a567d authored by Andrew M. Dai

Fixes to multiclass training. Change DBpedia data generation to generate labels starting from 0.

PiperOrigin-RevId: 185783053
parent 9e30188e
@@ -10,7 +10,7 @@ py_binary(
     deps = [
         ":graphs",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal dep,
     ],
 )
@@ -21,7 +21,7 @@ py_binary(
         ":graphs",
         ":train_utils",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal dep,
     ],
 )
@@ -34,7 +34,8 @@ py_binary(
         ":graphs",
         ":train_utils",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal gpu deps
+        # tensorflow internal dep,
     ],
 )
@@ -154,3 +154,4 @@ control which dataset is processed and how.
 ## Contact for Issues
 * Ryan Sepassi, @rsepassi
+* Andrew M. Dai, @a-dai
@@ -38,6 +38,8 @@ flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
 # Parameters for building the graph
 flags.DEFINE_string('adv_training_method', None,
                     'The flag which specifies training method. '
+                    '"" : non-adversarial training (e.g. for running the '
+                    ' semi-supervised sequence learning model) '
                     '"rp" : random perturbation training '
                     '"at" : adversarial training '
                     '"vat" : virtual adversarial training '
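The new `""` entry documents the flag's default: with `adv_training_method` unset (None or empty), training runs with no adversarial perturbation at all, which is how the plain semi-supervised sequence-learning model is trained. A minimal sketch of how code might dispatch on the flag value; the helper names below are hypothetical stand-ins, not this codebase's actual functions:

```python
import tensorflow as tf

# Hypothetical stand-ins for the per-method loss terms; the real
# implementations live elsewhere in this codebase.
def random_perturbation_loss():
  return tf.constant(0.0)

def adversarial_training_loss():
  return tf.constant(0.0)

def virtual_adversarial_loss():
  return tf.constant(0.0)

def adversarial_loss(method):
  """`method` comes from FLAGS.adv_training_method; ''/None means off."""
  if not method:  # non-adversarial (plain or semi-supervised) training
    return tf.constant(0.0)
  return {'rp': random_perturbation_loss,
          'at': adversarial_training_loss,
          'vat': virtual_adversarial_loss}[method]()
```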
@@ -271,7 +271,7 @@ def build_labeled_sequence(seq, class_label, label_gain=False):
   Args:
     seq: SequenceWrapper.
-    class_label: bool.
+    class_label: integer, starting from 0.
     label_gain: bool. If True, class_label will be put on every timestep and
       weight will increase linearly from 0 to 1.
@@ -259,7 +259,7 @@ def dbpedia_documents(dataset='train',
         content=content,
         is_validation=is_validation,
         is_test=False,
-        label=int(row[0]),
+        label=int(row[0]) - 1,  # Labels should start from 0
         add_tokens=True)
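DBpedia's ontology-classification CSVs number their 14 classes 1 through 14, while `tf.nn.sparse_softmax_cross_entropy_with_logits` requires integer labels in `[0, num_classes)`; the `- 1` shift brings them into that range. A minimal sketch of the constraint (values illustrative):

```python
import tensorflow as tf

num_classes = 14
logits = tf.random_normal([1, num_classes])

raw_label = 14                        # as read from the CSV, int(row[0])
label = tf.constant([raw_label - 1])  # shift into the valid range 0..13

# An unshifted label of 14 would be out of range here (an error on CPU,
# NaN loss on GPU).
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=label)
```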
@@ -20,7 +20,7 @@ from __future__ import print_function
 # Dependency imports
 import tensorflow as tf

-K = tf.contrib.keras
+K = tf.keras

 def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
@@ -148,6 +148,7 @@ class SoftmaxLoss(K.layers.Layer):
     self.num_candidate_samples = num_candidate_samples
     self.vocab_freqs = vocab_freqs
     super(SoftmaxLoss, self).__init__(**kwargs)
+    self.multiclass_dense_layer = K.layers.Dense(self.vocab_size)

   def build(self, input_shape):
     input_shape = input_shape[0]
@@ -160,6 +161,7 @@ class SoftmaxLoss(K.layers.Layer):
         shape=(self.vocab_size,),
         name='lm_lin_b',
         initializer=K.initializers.glorot_uniform())
+    self.multiclass_dense_layer.build(input_shape)
     super(SoftmaxLoss, self).build(input_shape)
@@ -190,7 +192,7 @@ class SoftmaxLoss(K.layers.Layer):
           lm_loss,
           [int(x.get_shape()[0]), int(x.get_shape()[1])])
     else:
-      logits = tf.matmul(x, self.lin_w) + self.lin_b
+      logits = self.multiclass_dense_layer(x)
       lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           logits=logits, labels=labels)
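The manual `tf.matmul(x, self.lin_w) + self.lin_b` assumed a 2-D `x`, whereas `K.layers.Dense` applies the same affine map to the last axis of a tensor of any rank, so per-timestep outputs of shape `[batch, time, hidden]` yield logits of shape `[batch, time, vocab]` directly. A small sketch, with shapes assumed for illustration:

```python
import tensorflow as tf

K = tf.keras
batch, time, hidden, vocab = 8, 20, 128, 5000

x = tf.zeros([batch, time, hidden])  # per-timestep RNN outputs
dense = K.layers.Dense(vocab)

# Dense contracts only the last axis, so 3-D input stays 3-D:
logits = dense(x)                    # shape [batch, time, vocab]

# A plain matmul against a [hidden, vocab] weight would instead require
# flattening x to [batch * time, hidden] and reshaping afterwards.
```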
@@ -255,7 +257,7 @@ def predictions(logits):
     pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
   # For multi-class classification
   else:
-    pred = tf.argmax(logits, 1)
+    pred = tf.argmax(logits, 2)
   return pred
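With multiclass sequence logits of shape `[batch, time, num_classes]`, the class dimension is axis 2; `tf.argmax(logits, 1)` was reducing over the time axis instead. A quick shape check:

```python
import tensorflow as tf

# Multiclass logits carry a time dimension: [batch, time, num_classes].
logits = tf.zeros([4, 10, 14])

pred_time = tf.argmax(logits, 1)   # reduces over time    -> shape [4, 14]
pred_class = tf.argmax(logits, 2)  # reduces over classes -> shape [4, 10]
```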
@@ -354,10 +356,9 @@ def optimize(loss,
           opt.ready_for_local_init_op)
   else:
     # Non-sync optimizer
-    variables_averages_op = variable_averages.apply(tvars)
     apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step)
-    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
-      train_op = tf.no_op(name='train_op')
+    with tf.control_dependencies([apply_gradient_op]):
+      train_op = variable_averages.apply(tvars)

   return train_op
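In the old code the moving-average update and the gradient application were both control dependencies of a no-op, with no ordering between them, so the averages could sample the variables either before or after the step. Sequencing `variable_averages.apply(tvars)` after `apply_gradient_op` makes the averages track the post-update weights, and since `apply()` returns an op it can serve as the train op itself. A minimal, self-contained sketch of the pattern (toy loss and variable):

```python
import tensorflow as tf

var = tf.Variable(0.0)
loss = tf.square(var - 1.0)

opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss, [var])
apply_gradient_op = opt.apply_gradients(grads_and_vars)

ema = tf.train.ExponentialMovingAverage(decay=0.999)
# Run the EMA update only after the gradients have been applied, so the
# average reads the post-step variable values.
with tf.control_dependencies([apply_gradient_op]):
  train_op = ema.apply([var])
```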
@@ -64,8 +64,8 @@ def run_training(train_op,
   sv = tf.train.Supervisor(
       logdir=FLAGS.train_dir,
       is_chief=is_chief,
-      save_summaries_secs=5 * 60,
-      save_model_secs=5 * 60,
+      save_summaries_secs=30,
+      save_model_secs=30,
       local_init_op=local_init_op,
       ready_for_local_init_op=ready_for_local_init_op,
       global_step=global_step)
@@ -90,10 +90,9 @@ def run_training(train_op,
     global_step_val = 0
     while not sv.should_stop() and global_step_val < FLAGS.max_steps:
       global_step_val = train_step(sess, train_op, loss, global_step)
-    sv.stop()

     # Final checkpoint
-    if is_chief:
+    if is_chief and global_step_val >= FLAGS.max_steps:
       sv.saver.save(sess, sv.save_path, global_step=global_step)
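The explicit `sv.stop()` is redundant when the session comes from `sv.managed_session()`, which stops the Supervisor's services on context exit, and the new guard skips the extra "final" checkpoint when the loop ended early on a stop request rather than by reaching `max_steps`. A hedged reconstruction of the surrounding pattern (the session setup itself is not shown in this hunk):

```python
# Assumes sess is created via the Supervisor, e.g.:
with sv.managed_session() as sess:
  global_step_val = 0
  while not sv.should_stop() and global_step_val < FLAGS.max_steps:
    global_step_val = train_step(sess, train_op, loss, global_step)

  # managed_session handles sv.stop() on exit; only write the extra
  # final checkpoint when training actually completed.
  if is_chief and global_step_val >= FLAGS.max_steps:
    sv.saver.save(sess, sv.save_path, global_step=global_step)
```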