Commit 36db2450 authored by Poorva Potdar's avatar Poorva Potdar Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 346570482
parent a5ed2bb7
......@@ -25,7 +25,7 @@ from official.nlp.modeling.ops import decoding_module
def greedy(log_probs):
"""Returns the top ids and scores based on greedy decoding."""
log_probs, ids = tf.nn.top_k(log_probs, k=1)
log_probs, ids = tf.math.top_k(log_probs, k=1)
return log_probs, ids
......@@ -56,9 +56,9 @@ def sample_top_k(logits, top_k):
Logits with top_k filtering applied.
"""
top_k_logits = tf.math.top_k(logits, k=top_k)
indices_to_remove = logits < top_k_logits[0][..., -1, None]
top_k_logits = set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
indices_to_remove = logits < tf.expand_dims(top_k_logits[0][..., -1], -1)
top_k_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
np.NINF)
return top_k_logits
......@@ -425,7 +425,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
finished_cond, finished_seq)
score_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq, finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment