"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "fd3112bc4954071456660dbb3904a6c9b7a54d86"
Commit e4cbe9ee authored by Neal Wu's avatar Neal Wu
Browse files

Fix the argument order for tf.nn.sampled_softmax_loss in textsum

parent dcb49d91
...@@ -227,8 +227,9 @@ class Seq2SeqAttentionModel(object): ...@@ -227,8 +227,9 @@ class Seq2SeqAttentionModel(object):
def sampled_loss_func(inputs, labels): def sampled_loss_func(inputs, labels):
with tf.device('/cpu:0'): # Try gpu. with tf.device('/cpu:0'): # Try gpu.
labels = tf.reshape(labels, [-1, 1]) labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, v, inputs, labels, return tf.nn.sampled_softmax_loss(
hps.num_softmax_samples, vsize) weights=w_t, biases=v, labels=labels, inputs=inputs,
num_sampled=hps.num_softmax_samples, num_classes=vsize)
if hps.num_softmax_samples != 0 and hps.mode == 'train': if hps.num_softmax_samples != 0 and hps.mode == 'train':
self._loss = seq2seq_lib.sampled_sequence_loss( self._loss = seq2seq_lib.sampled_sequence_loss(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment