Commit 518f0201 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix type annotation for length_normalization_fn, which should be Optional.

Both unit tests and real use cases are passing None to it.

PiperOrigin-RevId: 385468558
parent 7ecbac3c
......@@ -15,7 +15,7 @@
"""Sampling module for top_k, top_p and greedy decoding."""
import abc
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional
import numpy as np
import tensorflow as tf
......@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
], -1)
# Scatter sorted indices to original indexes.
indices_to_remove = scatter_values_on_batch_indices(
sorted_indices_to_remove, sorted_indices)
top_p_logits = set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
sorted_indices)
top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
np.NINF)
return top_p_logits
......@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
tensor_shape = decoding_module.shape_list(batch_indices)
broad_casted_batch_dims = tf.reshape(
tf.broadcast_to(
tf.expand_dims(tf.range(tensor_shape[0]), axis=-1),
tensor_shape), [1, -1])
tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
[1, -1])
pair_indices = tf.transpose(
tf.concat([broad_casted_batch_dims,
tf.reshape(batch_indices, [1, -1])], 0))
return tf.scatter_nd(pair_indices,
tf.reshape(values, [-1]), tensor_shape)
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)
def set_tensor_by_indices_to_value(input_tensor, indices, value):
......@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
......@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def __init__(self,
symbols_to_logits_fn,
length_normalization_fn: Callable[[int, tf.DType], float],
vocab_size: int,
max_decode_length: int,
eos_id: int,
padded_decode: bool,
length_normalization_fn: Optional[Callable[[int, tf.DType],
float]] = None,
top_k=0,
top_p=1.0,
sample_temperature=0.0,
......@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self.max_decode_length = max_decode_length
self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(sample_temperature,
dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(
sample_temperature, dtype=tf.float32)
self.enable_greedy = enable_greedy
super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype)
......@@ -330,10 +331,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
return state, state_shape_invariants
def _get_new_alive_state(
self,
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
"""Gather the sequences that are still alive.
......@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module.StateKeys.ALIVE_CACHE: new_cache
}
def _get_new_finished_state(self,
state: Dict[str, Any],
new_seq: tf.Tensor,
def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
batch_size: int) -> Dict[str, tf.Tensor]:
......@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
length_norm = self.length_normalization_fn(self.max_decode_length + 1,
self.dtype)
alive_log_probs = alive_log_probs / length_norm
seq_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_seq)
score_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_scores)
seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
score_cond = decoding_module.expand_to_same_rank(finished_cond,
finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment