"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "cfe2c2be5d4e384b4efd9c2f2266edd68876b34e"
Commit 36db2450 authored by Poorva Potdar's avatar Poorva Potdar Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 346570482
parent a5ed2bb7
...@@ -25,7 +25,7 @@ from official.nlp.modeling.ops import decoding_module ...@@ -25,7 +25,7 @@ from official.nlp.modeling.ops import decoding_module
def greedy(log_probs): def greedy(log_probs):
"""Returns the top ids and scores based on greedy decoding.""" """Returns the top ids and scores based on greedy decoding."""
log_probs, ids = tf.nn.top_k(log_probs, k=1) log_probs, ids = tf.math.top_k(log_probs, k=1)
return log_probs, ids return log_probs, ids
...@@ -56,9 +56,9 @@ def sample_top_k(logits, top_k): ...@@ -56,9 +56,9 @@ def sample_top_k(logits, top_k):
Logits with top_k filtering applied. Logits with top_k filtering applied.
""" """
top_k_logits = tf.math.top_k(logits, k=top_k) top_k_logits = tf.math.top_k(logits, k=top_k)
indices_to_remove = logits < top_k_logits[0][..., -1, None] indices_to_remove = logits < tf.expand_dims(top_k_logits[0][..., -1], -1)
top_k_logits = set_tensor_by_indices_to_value( top_k_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
logits, indices_to_remove, np.NINF) np.NINF)
return top_k_logits return top_k_logits
...@@ -425,7 +425,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -425,7 +425,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
finished_cond, finished_seq) finished_cond, finished_seq)
score_cond = decoding_module.expand_to_same_rank( score_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_scores) finished_cond, finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq, finished_scores) finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs) finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores return finished_seq, finished_scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment