Commit 913640d4 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285765110
parent 722d9e57
@@ -323,13 +323,16 @@ class SequenceBeamSearch(object):
       new state dictionary.
     """
     # Grow alive sequences by one token.
-    new_seq, new_log_probs, new_cache = self._grow_alive_seq(state)
+    new_seq, new_log_probs, topk_ids, new_cache = self._grow_alive_seq(state)
+    new_finished_flags = tf.equal(topk_ids, self.eos_id)
     # Collect top beam_size alive sequences
-    alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache)
+    alive_state = self._get_new_alive_state(new_seq, new_log_probs,
+                                            new_finished_flags, new_cache)
 
     # Combine newly finished sequences with existing finished sequences, and
     # collect the top k scoring sequences.
-    finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)
+    finished_state = self._get_new_finished_state(state, new_seq, new_log_probs,
+                                                  new_finished_flags)
 
     # Increment loop index and create new state dictionary
     new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
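A minimal, hedged sketch of the idea in the hunk above (not the repository's code; the shapes and EOS id are illustrative assumptions): _grow_alive_seq now also returns topk_ids, the tokens it just appended, so the finished flags are computed once here and passed to both state-update helpers instead of each helper recomputing them from new_seq[:, :, -1].

# Hypothetical standalone example; runs in TF2 eager mode.
import tensorflow as tf

EOS_ID = 1  # assumed EOS token id, for illustration only

# Grown sequences: shape [batch_size=1, 2 * beam_size=4, cur_index + 1=3].
new_seq = tf.constant([[[7, 4, 5], [7, 4, 1], [7, 8, 9], [7, 8, 1]]], tf.int32)
# Tokens appended in this step, i.e. the last position of each sequence.
topk_ids = new_seq[:, :, -1]

# Computed once and threaded through both helpers after this change.
new_finished_flags = tf.equal(topk_ids, EOS_ID)

# Equivalent to the expression each helper previously recomputed on its own.
recomputed = tf.equal(new_seq[:, :, -1], EOS_ID)
print(bool(tf.reduce_all(tf.equal(new_finished_flags, recomputed))))  # True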
@@ -407,18 +410,20 @@ class SequenceBeamSearch(object):
                                              tf.expand_dims(topk_ids, axis=0))
       topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
     else:
-      topk_ids = tf.expand_dims(topk_ids, axis=2)
-      topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
-    return topk_seq, topk_log_probs, new_cache
+      topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
+    return topk_seq, topk_log_probs, topk_ids, new_cache
 
-  def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
+  def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
+                           new_cache):
     """Gather the top k sequences that are still alive.
 
     Args:
       new_seq: New sequences generated by growing the current alive sequences
         int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
-      new_log_probs: Log probabilities of new sequences
-        float32 tensor with shape [batch_size, beam_size]
+      new_log_probs: Log probabilities of new sequences float32 tensor with
+        shape [batch_size, beam_size]
+      new_finished_flags: A boolean Tensor indicating which sequences have just
+        finished (reached EOS) inside the beam.
       new_cache: Dict of cached values for each sequence.
 
     Returns:
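The first change in the hunk above folds the expand_dims into the concat. A small sketch with made-up shapes (not the repository code) showing that the one-line form appends the new token ids as the last position of each sequence, exactly like the previous two-step form:

import tensorflow as tf

topk_seq = tf.constant([[[7, 4], [7, 8]]], tf.int32)  # [batch=1, beams=2, len=2]
topk_ids = tf.constant([[5, 1]], tf.int32)            # [batch=1, beams=2]

# Previous form: expand, then concatenate.
expanded = tf.expand_dims(topk_ids, axis=2)           # [1, 2, 1]
old_result = tf.concat([topk_seq, expanded], axis=2)  # [1, 2, 3]

# New form: a single expression with the same result.
new_result = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

print(bool(tf.reduce_all(tf.equal(old_result, new_result))))  # True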
@@ -428,7 +433,6 @@ class SequenceBeamSearch(object):
          Dict cache storing decoder states for top alive sequences}
     """
     # To prevent finished sequences from being considered, set log probs to -inf
-    new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
     new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
 
     top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
@@ -441,15 +445,18 @@ class SequenceBeamSearch(object):
         _StateKeys.ALIVE_CACHE: top_alive_cache
     }
 
-  def _get_new_finished_state(self, state, new_seq, new_log_probs):
+  def _get_new_finished_state(self, state, new_seq, new_log_probs,
+                              new_finished_flags):
     """Combine new and old finished sequences, and gather the top k sequences.
 
     Args:
       state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive sequences
        int32 tensor with shape [batch_size, beam_size, i + 1]
-      new_log_probs: Log probabilities of new sequences
-        float32 tensor with shape [batch_size, beam_size]
+      new_log_probs: Log probabilities of new sequences float32 tensor with
+        shape [batch_size, beam_size]
+      new_finished_flags: A boolean Tensor indicating which sequences have just
+        finished (reached EOS) inside the beam.
 
     Returns:
       Dictionary with finished keys from _StateKeys:
@@ -476,7 +483,6 @@ class SequenceBeamSearch(object):
     new_scores = new_log_probs / length_norm
 
     # Set the scores of the still-alive seq in new_seq to large negative values.
-    new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
     new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
                    -inf(self.dtype))
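A hedged sketch of the scoring step in the last hunk, with stand-ins for pieces outside this diff: the length normalization below uses the common ((5 + length) / 6) ** alpha penalty purely as an assumption, and NEG_INF stands in for the -inf(dtype) helper. It shows how the precomputed new_finished_flags masks still-alive candidates out of the finished set:

import tensorflow as tf

NEG_INF = -1.0e9          # stand-in for -inf(self.dtype)
alpha, length = 0.6, 4    # assumed length-penalty strength and current length
length_norm = ((5.0 + length) / 6.0) ** alpha

new_log_probs = tf.constant([[-2.0, -3.0, -4.0, -5.0]], tf.float32)
new_finished_flags = tf.constant([[False, True, False, True]])

# Length-normalized scores for the newly grown candidates.
new_scores = new_log_probs / length_norm
# Candidates that did not just emit EOS must not compete for finished slots,
# so their scores are pushed to a large negative value.
new_scores += (1. - tf.cast(new_finished_flags, tf.float32)) * NEG_INF
print(new_scores.numpy())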