Commit 518f0201 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix type annotation for length_normalization_fn, which should be Optional.

Both unit tests and real use cases are passing None to it.

PiperOrigin-RevId: 385468558
parent 7ecbac3c
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Sampling module for top_k, top_p and greedy decoding.""" """Sampling module for top_k, top_p and greedy decoding."""
import abc import abc
from typing import Any, Callable, Dict from typing import Any, Callable, Dict, Optional
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p): ...@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
], -1) ], -1)
# Scatter sorted indices to original indexes. # Scatter sorted indices to original indexes.
indices_to_remove = scatter_values_on_batch_indices( indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
sorted_indices_to_remove, sorted_indices) sorted_indices)
top_p_logits = set_tensor_by_indices_to_value( top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
logits, indices_to_remove, np.NINF) np.NINF)
return top_p_logits return top_p_logits
...@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices): ...@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
tensor_shape = decoding_module.shape_list(batch_indices) tensor_shape = decoding_module.shape_list(batch_indices)
broad_casted_batch_dims = tf.reshape( broad_casted_batch_dims = tf.reshape(
tf.broadcast_to( tf.broadcast_to(
tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
tensor_shape), [1, -1]) [1, -1])
pair_indices = tf.transpose( pair_indices = tf.transpose(
tf.concat([broad_casted_batch_dims, tf.concat([broad_casted_batch_dims,
tf.reshape(batch_indices, [1, -1])], 0)) tf.reshape(batch_indices, [1, -1])], 0))
return tf.scatter_nd(pair_indices, return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)
tf.reshape(values, [-1]), tensor_shape)
def set_tensor_by_indices_to_value(input_tensor, indices, value): def set_tensor_by_indices_to_value(input_tensor, indices, value):
...@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value): ...@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
input_tensor: float (batch_size, dim) input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim) indices: bool (batch_size, dim)
value: float scalar value: float scalar
Returns: Returns:
output_tensor: same shape as input_tensor. output_tensor: same shape as input_tensor.
""" """
...@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
symbols_to_logits_fn, symbols_to_logits_fn,
length_normalization_fn: Callable[[int, tf.DType], float],
vocab_size: int, vocab_size: int,
max_decode_length: int, max_decode_length: int,
eos_id: int, eos_id: int,
padded_decode: bool, padded_decode: bool,
length_normalization_fn: Optional[Callable[[int, tf.DType],
float]] = None,
top_k=0, top_k=0,
top_p=1.0, top_p=1.0,
sample_temperature=0.0, sample_temperature=0.0,
...@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self.max_decode_length = max_decode_length self.max_decode_length = max_decode_length
self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32) self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32) self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(sample_temperature, self.sample_temperature = tf.convert_to_tensor(
dtype=tf.float32) sample_temperature, dtype=tf.float32)
self.enable_greedy = enable_greedy self.enable_greedy = enable_greedy
super(SamplingModule, self).__init__( super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype) length_normalization_fn=length_normalization_fn, dtype=dtype)
...@@ -330,10 +331,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -330,10 +331,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
return state, state_shape_invariants return state, state_shape_invariants
def _get_new_alive_state( def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
self,
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor, new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]: new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
"""Gather the sequences that are still alive. """Gather the sequences that are still alive.
...@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module.StateKeys.ALIVE_CACHE: new_cache decoding_module.StateKeys.ALIVE_CACHE: new_cache
} }
def _get_new_finished_state(self, def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
state: Dict[str, Any],
new_seq: tf.Tensor,
new_log_probs: tf.Tensor, new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor, new_finished_flags: tf.Tensor,
batch_size: int) -> Dict[str, tf.Tensor]: batch_size: int) -> Dict[str, tf.Tensor]:
...@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
length_norm = self.length_normalization_fn(self.max_decode_length + 1, length_norm = self.length_normalization_fn(self.max_decode_length + 1,
self.dtype) self.dtype)
alive_log_probs = alive_log_probs / length_norm alive_log_probs = alive_log_probs / length_norm
seq_cond = decoding_module.expand_to_same_rank( seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
finished_cond, finished_seq) score_cond = decoding_module.expand_to_same_rank(finished_cond,
score_cond = decoding_module.expand_to_same_rank( finished_scores)
finished_cond, finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq) finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs) finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores return finished_seq, finished_scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment