Commit f3e7cc25 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 468759527
parent 5d340ff3
...@@ -129,14 +129,18 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -129,14 +129,18 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
self.dtype = tf.as_dtype(dtype) self.dtype = tf.as_dtype(dtype)
self.decoding_name = decoding_name self.decoding_name = decoding_name
def generate(self, initial_ids: tf.Tensor, def generate(self,
initial_cache: Dict[str, tf.Tensor]) -> Output: initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor],
initial_log_probs: Optional[tf.Tensor] = None) -> Output:
"""Implements the decoding strategy (beam_search or sampling). """Implements the decoding strategy (beam_search or sampling).
Args: Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn. int tensor initial_ids: initial ids to pass into the symbols_to_logits_fn. int tensor
with shape [batch_size, 1] with shape [batch_size, 1]
initial_cache: dictionary for caching model outputs from previous step. initial_cache: dictionary for caching model outputs from previous step.
      initial_log_probs: Optional initial log probs if there is a prefix
        sequence we want to start decoding from.
Returns: Returns:
Tuple of tensors representing Tuple of tensors representing
...@@ -148,9 +152,9 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -148,9 +152,9 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
initial_ids.shape.as_list()[0] initial_ids.shape.as_list()[0]
if self.padded_decode else tf.shape(initial_ids)[0]) if self.padded_decode else tf.shape(initial_ids)[0])
state, state_shapes = self._create_initial_state(initial_ids, state, state_shapes = self._create_initial_state(initial_ids, initial_cache,
initial_cache, batch_size,
batch_size) initial_log_probs)
def _generate_step(state): def _generate_step(state):
topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq( topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
...@@ -196,10 +200,12 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -196,10 +200,12 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
return final_state return final_state
@abc.abstractmethod @abc.abstractmethod
def _create_initial_state(self, def _create_initial_state(
self,
initial_ids: tf.Tensor, initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor], initial_cache: Dict[str, tf.Tensor],
batch_size: int) -> InitialState: batch_size: int,
initial_log_probs: Optional[tf.Tensor] = None) -> InitialState:
"""Return initial state dictionary and its shape invariants.""" """Return initial state dictionary and its shape invariants."""
pass pass
......
...@@ -250,10 +250,13 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -250,10 +250,13 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
topk_seq = tf.concat([alive_seq, topk_ids], axis=-1) topk_seq = tf.concat([alive_seq, topk_ids], axis=-1)
return topk_seq, topk_log_probs, topk_ids, new_cache return topk_seq, topk_log_probs, topk_ids, new_cache
def _create_initial_state(self, def _create_initial_state(
self,
initial_ids: tf.Tensor, initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor], initial_cache: Dict[str, tf.Tensor],
batch_size: int) -> decoding_module.InitialState: batch_size: int,
initial_log_probs: Optional[tf.Tensor] = None
) -> decoding_module.InitialState:
"""Return initial state dictionary and its shape invariants.""" """Return initial state dictionary and its shape invariants."""
for key, value in initial_cache.items(): for key, value in initial_cache.items():
for inner_value in tf.nest.flatten(value): for inner_value in tf.nest.flatten(value):
...@@ -273,8 +276,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -273,8 +276,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
alive_seq = tf.tile(alive_seq, [1, self.max_decode_length + 1]) alive_seq = tf.tile(alive_seq, [1, self.max_decode_length + 1])
# Initial log probabilities with shape [batch_size, 1]. # Initial log probabilities with shape [batch_size, 1].
if initial_log_probs is None:
initial_log_probs = tf.constant([[0.]], dtype=self.dtype) initial_log_probs = tf.constant([[0.]], dtype=self.dtype)
alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1]) alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
else:
alive_log_probs = initial_log_probs
alive_cache = initial_cache alive_cache = initial_cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment