Commit 8cb5ac1e authored by Poorva Potdar, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 364378436
parent 0e6f8848
......@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Tuple
import tensorflow as tf
from tensorflow.python.framework import dtypes
from official.modeling import tf_utils
Output = Tuple[tf.Tensor, tf.Tensor]
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
......@@ -64,15 +65,7 @@ def log_prob_from_logits(logits):
def shape_list(tensor):
"""Return a list of the tensor's shape, and ensure no None values in list."""
# NOTE(review): this hunk is a diff rendered without +/- markers or
# indentation. The lines from here through `return shape` look like the
# pre-change body -- TODO confirm against the actual file.
# Get statically known shape (may contain None's for unknown dimensions)
shape = tensor.get_shape().as_list()
# Ensure that the shape values are not None
dynamic_shape = tf.shape(tensor)
for i in range(len(shape)): # pylint: disable=consider-using-enumerate
if shape[i] is None:
# Static size unknown for this dimension: substitute the matching entry of
# tf.shape(tensor).
shape[i] = dynamic_shape[i]
return shape
# NOTE(review): presumably the post-change body, delegating to the tf_utils
# helper. As flattened here it follows `return shape`, so only one of the two
# returns can belong to the real function -- TODO confirm.
return tf_utils.get_shape_list(tensor)
def get_shape_keep_last_dim(tensor):
......
......@@ -76,14 +76,15 @@ def sample_top_p(logits, top_p):
"""
sorted_indices = tf.argsort(logits, direction="DESCENDING")
# Flatten logits as tf.gather on TPU needs axis to be compile time constant.
range_for_gather = tf.expand_dims(tf.range(0, logits.shape[0]), axis=1)
range_for_gather = tf.tile(range_for_gather * logits.shape[1],
[1, logits.shape[1]]) + sorted_indices
logits_shape = decoding_module.shape_list(logits)
range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
range_for_gather = tf.tile(range_for_gather * logits_shape[1],
[1, logits_shape[1]]) + sorted_indices
flattened_logits = tf.reshape(logits, [-1])
flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
sorted_logits = tf.reshape(
tf.gather(flattened_logits, flattened_sorted_indices),
[logits.shape[0], logits.shape[1]])
[logits_shape[0], logits_shape[1]])
cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
# Remove tokens with cumulative probability above the threshold.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment