Commit 30826f6b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 466203688
parent b519ea47
...@@ -22,7 +22,7 @@ import tensorflow as tf ...@@ -22,7 +22,7 @@ import tensorflow as tf
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from official.modeling import tf_utils from official.modeling import tf_utils
Output = Tuple[tf.Tensor, tf.Tensor] Output = Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict] InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
InitialState = Tuple[Dict[str, Any], Dict[str, Any]] InitialState = Tuple[Dict[str, Any], Dict[str, Any]]
...@@ -46,6 +46,10 @@ class StateKeys: ...@@ -46,6 +46,10 @@ class StateKeys:
# the previous iteration. # the previous iteration.
ALIVE_CACHE = "ALIVE_CACHE" ALIVE_CACHE = "ALIVE_CACHE"
# The initial model state/cache after model processing the initial token.
# The cache will be filled if extra_cache_output is true.
INITIAL_OUTPUT_CACHE = "INITIAL_OUTPUT_CACHE"
# Top finished sequences for each batch item. # Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s. # shorter than CUR_INDEX + 1 are padded with 0s.
...@@ -109,7 +113,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -109,7 +113,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
length_normalization_fn: Callable[[int, tf.DType], float], length_normalization_fn: Callable[[int, tf.DType], float],
dtype: tf.DType = tf.float32, dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None): decoding_name: Optional[str] = None,
extra_cache_output: bool = False):
"""Initialize the Decoding Module. """Initialize the Decoding Module.
Args: Args:
...@@ -118,24 +123,26 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -118,24 +123,26 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
dtype: A tensorflow data type used for score computation. The default is dtype: A tensorflow data type used for score computation. The default is
tf.float32. tf.float32.
decoding_name: an optional name for the decoding loop tensors. decoding_name: an optional name for the decoding loop tensors.
extra_cache_output: If true, the first cache will be in the states.
""" """
self.length_normalization_fn = length_normalization_fn self.length_normalization_fn = length_normalization_fn
self.dtype = tf.as_dtype(dtype) self.dtype = tf.as_dtype(dtype)
self.decoding_name = decoding_name self.decoding_name = decoding_name
def generate(self, def generate(self, initial_ids: tf.Tensor,
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor]) -> Output: initial_cache: Dict[str, tf.Tensor]) -> Output:
"""Implements the decoding strategy (beam_search or sampling). """Implements the decoding strategy (beam_search or sampling).
Args: Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn. initial_ids: initial ids to pass into the symbols_to_logits_fn. int tensor
int tensor with shape [batch_size, 1] with shape [batch_size, 1]
initial_cache: dictionary for caching model outputs from previous step. initial_cache: dictionary for caching model outputs from previous step.
Returns: Returns:
Tuple of tensors representing Tuple of tensors representing
finished_sequence: shape [batch, max_seq_length] finished_sequence: shape [batch, max_seq_length]
finished_scores: [batch] finished_scores: [batch]
first_cache: The cache after init token
""" """
batch_size = ( batch_size = (
initial_ids.shape.as_list()[0] initial_ids.shape.as_list()[0]
...@@ -163,6 +170,17 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -163,6 +170,17 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
} }
new_state.update(alive_state) new_state.update(alive_state)
new_state.update(finished_state) new_state.update(finished_state)
if self.extra_cache_output:
i = state[StateKeys.CUR_INDEX]
old_cache = state[StateKeys.INITIAL_OUTPUT_CACHE]
def update_with_cache(new_state, cache):
"""Updates new_state with cache."""
new_state.update({StateKeys.INITIAL_OUTPUT_CACHE: cache})
tf.cond(
tf.equal(i, 0), lambda: update_with_cache(new_state, new_cache),
lambda: update_with_cache(new_state, old_cache))
return [new_state] return [new_state]
finished_state = tf.nest.map_structure( finished_state = tf.nest.map_structure(
......
...@@ -29,6 +29,7 @@ class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -29,6 +29,7 @@ class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
length_normalization_fn=length_normalization, length_normalization_fn=length_normalization,
extra_cache_output=True,
dtype=tf.float32): dtype=tf.float32):
super(TestSubclass, self).__init__( super(TestSubclass, self).__init__(
length_normalization_fn=length_normalization, dtype=dtype) length_normalization_fn=length_normalization, dtype=dtype)
......
...@@ -163,7 +163,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -163,7 +163,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
sample_temperature=0.0, sample_temperature=0.0,
enable_greedy: bool = True, enable_greedy: bool = True,
dtype: tf.DType = tf.float32, dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None): decoding_name: Optional[str] = None,
extra_cache_output: bool = False):
"""Initialize sampling module.""" """Initialize sampling module."""
self.symbols_to_logits_fn = symbols_to_logits_fn self.symbols_to_logits_fn = symbols_to_logits_fn
self.length_normalization_fn = length_normalization_fn self.length_normalization_fn = length_normalization_fn
...@@ -178,10 +179,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -178,10 +179,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
sample_temperature, dtype=tf.float32) sample_temperature, dtype=tf.float32)
self.enable_greedy = enable_greedy self.enable_greedy = enable_greedy
self.decoding_name = decoding_name self.decoding_name = decoding_name
self.extra_cache_output = extra_cache_output
super(SamplingModule, self).__init__( super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, length_normalization_fn=length_normalization_fn,
dtype=dtype, dtype=dtype,
decoding_name=decoding_name) decoding_name=decoding_name,
extra_cache_output=extra_cache_output)
def _grow_alive_seq(self, def _grow_alive_seq(self,
state: Dict[str, Any], state: Dict[str, Any],
...@@ -300,16 +303,14 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -300,16 +303,14 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module.StateKeys.CUR_INDEX: decoding_module.StateKeys.CUR_INDEX:
tf.TensorShape([]), tf.TensorShape([]),
decoding_module.StateKeys.ALIVE_SEQ: decoding_module.StateKeys.ALIVE_SEQ:
tf.TensorShape( tf.TensorShape([batch_size, self.max_decode_length + 1]),
[batch_size, self.max_decode_length + 1]),
decoding_module.StateKeys.ALIVE_LOG_PROBS: decoding_module.StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([batch_size, 1]), tf.TensorShape([batch_size, 1]),
decoding_module.StateKeys.ALIVE_CACHE: decoding_module.StateKeys.ALIVE_CACHE:
tf.nest.map_structure(lambda state: state.get_shape(), tf.nest.map_structure(lambda state: state.get_shape(),
alive_cache), alive_cache),
decoding_module.StateKeys.FINISHED_SEQ: decoding_module.StateKeys.FINISHED_SEQ:
tf.TensorShape( tf.TensorShape([batch_size, self.max_decode_length + 1]),
[batch_size, self.max_decode_length + 1]),
decoding_module.StateKeys.FINISHED_SCORES: decoding_module.StateKeys.FINISHED_SCORES:
tf.TensorShape([batch_size, 1]), tf.TensorShape([batch_size, 1]),
decoding_module.StateKeys.FINISHED_FLAGS: decoding_module.StateKeys.FINISHED_FLAGS:
...@@ -324,8 +325,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -324,8 +325,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module.StateKeys.ALIVE_LOG_PROBS: decoding_module.StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([None, 1]), tf.TensorShape([None, 1]),
decoding_module.StateKeys.ALIVE_CACHE: decoding_module.StateKeys.ALIVE_CACHE:
tf.nest.map_structure( tf.nest.map_structure(decoding_module.get_shape_keep_last_dim,
decoding_module.get_shape_keep_last_dim,
alive_cache), alive_cache),
decoding_module.StateKeys.FINISHED_SEQ: decoding_module.StateKeys.FINISHED_SEQ:
tf.TensorShape([None, None]), tf.TensorShape([None, None]),
...@@ -335,6 +335,22 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -335,6 +335,22 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
tf.TensorShape([None, 1]) tf.TensorShape([None, 1])
} }
if self.extra_cache_output:
state.update(
{decoding_module.StateKeys.INITIAL_OUTPUT_CACHE: alive_cache})
if self.padded_decode:
state_shape_invariants.update({
decoding_module.StateKeys.INITIAL_OUTPUT_CACHE:
tf.nest.map_structure(lambda state: state.get_shape(),
alive_cache)
})
else:
state_shape_invariants.update({
decoding_module.StateKeys.INITIAL_OUTPUT_CACHE:
tf.nest.map_structure(decoding_module.get_shape_keep_last_dim,
alive_cache),
})
return state, state_shape_invariants return state, state_shape_invariants
def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor, def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
...@@ -428,6 +444,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -428,6 +444,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
finished_scores) finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq) finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs) finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
if self.extra_cache_output:
return finished_seq, finished_scores, finished_state[
decoding_module.StateKeys.INITIAL_OUTPUT_CACHE]
return finished_seq, finished_scores return finished_seq, finished_scores
def _continue_search(self, state) -> tf.Tensor: def _continue_search(self, state) -> tf.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment