Commit 1d3cd3cf authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 283622345
parent b478430d
...@@ -402,7 +402,9 @@ class SequenceBeamSearch(object): ...@@ -402,7 +402,9 @@ class SequenceBeamSearch(object):
topk_ids = topk_indices % self.vocab_size topk_ids = topk_indices % self.vocab_size
if self.padded_decode: if self.padded_decode:
topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1]) topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
topk_seq = tf.tensor_scatter_nd_update(topk_seq, [i + 1], topk_ids) # TODO(b/145533236, hongkuny): Reverts once TF fix the validation.
topk_seq = tf.tensor_scatter_nd_update(topk_seq, [[i + 1]],
tf.expand_dims(topk_ids, axis=0))
topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0]) topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
else: else:
topk_ids = tf.expand_dims(topk_ids, axis=2) topk_ids = tf.expand_dims(topk_ids, axis=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment