Commit 2da86542 authored by Poorva Potdar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 364378436
parent f9491103
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Tuple
 import tensorflow as tf
 from tensorflow.python.framework import dtypes
+from official.modeling import tf_utils
 
 Output = Tuple[tf.Tensor, tf.Tensor]
 InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
@@ -64,15 +65,7 @@ def log_prob_from_logits(logits):
 def shape_list(tensor):
   """Return a list of the tensor's shape, and ensure no None values in list."""
-  # Get statically known shape (may contain None's for unknown dimensions)
-  shape = tensor.get_shape().as_list()
-  # Ensure that the shape values are not None
-  dynamic_shape = tf.shape(tensor)
-  for i in range(len(shape)):  # pylint: disable=consider-using-enumerate
-    if shape[i] is None:
-      shape[i] = dynamic_shape[i]
-  return shape
+  return tf_utils.get_shape_list(tensor)
 
 
 def get_shape_keep_last_dim(tensor):
...
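Note: the deleted body is the usual mixed static/dynamic shape helper, and tf_utils.get_shape_list in the Model Garden is its equivalent, so this hunk is a pure deduplication. A minimal sketch of what both compute (the demo function below is hypothetical, for illustration only):

import tensorflow as tf

def shape_list(tensor):
  # Static dims come back as Python ints; any dim that is None statically
  # is replaced by the matching scalar from the runtime tf.shape.
  static = tensor.get_shape().as_list()
  dynamic = tf.shape(tensor)
  return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

# With an unknown batch dimension, the first entry is a scalar tensor and
# the second is the plain int 4, so both can be used in shape arithmetic.
@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def demo(x):
  batch, width = shape_list(x)
  return tf.reshape(x, [batch * width])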
@@ -76,14 +76,15 @@ def sample_top_p(logits, top_p):
   """
   sorted_indices = tf.argsort(logits, direction="DESCENDING")
   # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
-  range_for_gather = tf.expand_dims(tf.range(0, logits.shape[0]), axis=1)
-  range_for_gather = tf.tile(range_for_gather * logits.shape[1],
-                             [1, logits.shape[1]]) + sorted_indices
+  logits_shape = decoding_module.shape_list(logits)
+  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
+  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
+                             [1, logits_shape[1]]) + sorted_indices
   flattened_logits = tf.reshape(logits, [-1])
   flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
   sorted_logits = tf.reshape(
       tf.gather(flattened_logits, flattened_sorted_indices),
-      [logits.shape[0], logits.shape[1]])
+      [logits_shape[0], logits_shape[1]])
   cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
   # Remove tokens with cumulative probability above the threshold.
...
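The motivation for routing shapes through shape_list here: when the batch dimension is dynamic, logits.shape[0] is None and the index arithmetic fails, whereas shape_list returns a usable scalar tensor. Independent of that fix, the flat-index trick the hunk preserves is worth spelling out: since tf.gather on TPU needs its axis to be a compile-time constant, the per-row gather is rewritten as row * num_columns + column lookups into a flattened tensor. A self-contained sketch with made-up 2x3 logits:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 1.0],
                      [0.1, 3.0, 0.2]])
sorted_indices = tf.argsort(logits, direction="DESCENDING")  # [[0, 2, 1], [1, 2, 0]]

# Row-major flat index of element (r, c) is r * num_columns + c, so a 1-D
# gather over the flattened logits reproduces a per-row batched gather.
batch, vocab = 2, 3
rows = tf.expand_dims(tf.range(0, batch), axis=1)            # [[0], [1]]
flat_indices = tf.tile(rows * vocab, [1, vocab]) + sorted_indices
sorted_logits = tf.reshape(
    tf.gather(tf.reshape(logits, [-1]), tf.reshape(flat_indices, [-1])),
    [batch, vocab])
# sorted_logits matches tf.sort(logits, direction="DESCENDING"):
# [[2.0, 1.0, 0.5], [3.0, 0.2, 0.1]]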