"git@developer.sourcefind.cn:change/sglang.git" did not exist on "bc068e96181d4b42989e1c13b59f4b64de94bd99"
Commit f7cea8d0 authored by Neal Wu

Rename sampled_loss argument inputs to logits in preparation for named arguments requirement

parent e4cbe9ee
@@ -100,13 +100,13 @@ class Seq2SeqModel(object):
       b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
       output_projection = (w, b)
 
-      def sampled_loss(labels, inputs):
+      def sampled_loss(labels, logits):
         labels = tf.reshape(labels, [-1, 1])
         # We need to compute the sampled_softmax_loss using 32bit floats to
         # avoid numerical instabilities.
         local_w_t = tf.cast(w_t, tf.float32)
         local_b = tf.cast(b, tf.float32)
-        local_inputs = tf.cast(inputs, tf.float32)
+        local_inputs = tf.cast(logits, tf.float32)
         return tf.cast(
             tf.nn.sampled_softmax_loss(
                 weights=local_w_t,
...
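
Why the rename matters: per the commit message, the surrounding seq2seq code is moving toward calling the loss function with named arguments (`labels=...`, `logits=...`) instead of positionally, so a closure whose second parameter is still called `inputs` would raise a `TypeError`. The sketch below is plain, self-contained Python for illustration only (the toy values and function names are made up, not library code); it demonstrates the difference between the old and new signatures under that keyword-call convention.

```python
# Illustrative sketch only -- not library code.

def old_sampled_loss(labels, inputs):      # pre-rename signature
    return labels, inputs

def new_sampled_loss(labels, logits):      # post-rename signature
    return labels, logits

# Positional calls work with either version.
old_sampled_loss([1, 2], [[0.1, 0.9]])
new_sampled_loss([1, 2], [[0.1, 0.9]])

# Keyword calls (the "named arguments requirement") only match the renamed
# signature; the old parameter name `inputs` no longer fits the call site.
new_sampled_loss(labels=[1, 2], logits=[[0.1, 0.9]])      # OK
try:
    old_sampled_loss(labels=[1, 2], logits=[[0.1, 0.9]])  # TypeError
except TypeError as exc:
    print("old signature fails:", exc)
```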