Commit df2e30cd, authored by Bruce Fontaine and committed by A. Unique TensorFlower
Browse files

Move _gather_beams into the SequenceBeamSearch class.

PiperOrigin-RevId: 373191989
parent a82f0b56
......@@ -218,7 +218,7 @@ class SequenceBeamSearch(tf.Module):
# Extract the alive sequences that generate the highest log probabilities
# after being extended.
topk_beam_indices = topk_indices // self.vocab_size
topk_seq, new_cache = _gather_beams([alive_seq, new_cache],
topk_seq, new_cache = self._gather_beams([alive_seq, new_cache],
topk_beam_indices, batch_size,
beams_to_keep)
......@@ -259,9 +259,10 @@ class SequenceBeamSearch(tf.Module):
new_log_probs += tf.cast(new_finished_flags,
self.dtype) * -inf(self.dtype)
top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
[new_seq, new_log_probs, new_cache], new_log_probs, batch_size,
self.beam_size)
_, topk_indexes = tf.nn.top_k(new_log_probs, k=self.beam_size)
top_alive_seq, top_alive_log_probs, top_alive_cache = (
self._gather_beams([new_seq, new_log_probs, new_cache],
topk_indexes, batch_size, self.beam_size))
return {
_StateKeys.ALIVE_SEQ: top_alive_seq,
......@@ -316,9 +317,10 @@ class SequenceBeamSearch(tf.Module):
finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1)
# Return the finished sequences with the best scores.
_, topk_indexes = tf.nn.top_k(finished_scores, k=self.beam_size)
top_finished_seq, top_finished_scores, top_finished_flags = (
_gather_topk_beams([finished_seq, finished_scores, finished_flags],
finished_scores, batch_size, self.beam_size))
self._gather_beams([finished_seq, finished_scores, finished_flags],
topk_indexes, batch_size, self.beam_size))
return {
_StateKeys.FINISHED_SEQ: top_finished_seq,
......@@ -538,6 +540,43 @@ class SequenceBeamSearch(tf.Module):
not_at_max_decode_length,
tf.logical_not(worst_finished_score_better_than_best_alive_score))
@staticmethod
def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
  """Select beams from every tensor in a nested structure.

  A "beam" is a single search state; beam search explores several such states
  in parallel for each batch element. For every batch element, this picks the
  beams named by `beam_indices` out of each tensor in `nested`.

  Args:
    nested: Nested structure (tensor, list, tuple or dict) of tensors with
      shape [batch_size, beam_size, ...].
    beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Every
      value must lie in [0, beam_size); values may repeat.
    batch_size: int size of batch.
    new_beam_size: int number of beams to pull from each nested tensor.

  Returns:
    Nested structure of tensors with shape [batch_size, new_beam_size, ...].
  """
  # Row i of `batch_ids` is filled with i, e.g. [[0, 0, 0, 0], [1, 1, 1, 1], ...];
  # it supplies the batch coordinate for tf.gather_nd.
  flat_batch_ids = tf.range(batch_size * new_beam_size) // new_beam_size
  batch_ids = tf.reshape(flat_batch_ids, [batch_size, new_beam_size])

  # Pair each batch coordinate with its beam coordinate. The result has shape
  # [batch_size, new_beam_size, 2]; the last axis holds the (i, j) pair.
  gather_coords = tf.stack([batch_ids, beam_indices], axis=2)

  def _pick(state):
    return tf.gather_nd(state, gather_coords)

  return tf.nest.map_structure(_pick, nested)
def sequence_beam_search(symbols_to_logits_fn,
initial_ids,
......@@ -663,46 +702,3 @@ def _unflatten_beam_dim(tensor, batch_size, beam_size):
shape = _shape_list(tensor)
new_shape = [batch_size, beam_size] + shape[1:]
return tf.reshape(tensor, new_shape)
def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
  """Select beams from every tensor in a nested structure.

  A "beam" is a single search state; beam search explores several such states
  in parallel for each batch element. For every batch element, this picks the
  beams named by `beam_indices` out of each tensor in `nested`.

  Args:
    nested: Nested structure (tensor, list, tuple or dict) of tensors with
      shape [batch_size, beam_size, ...].
    beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Every
      value must lie in [0, beam_size); values may repeat.
    batch_size: int size of batch.
    new_beam_size: int number of beams to pull from each nested tensor.

  Returns:
    Nested structure of tensors with shape [batch_size, new_beam_size, ...].
  """
  # Row i of `batch_ids` is filled with i, e.g. [[0, 0, 0, 0], [1, 1, 1, 1], ...];
  # it supplies the batch coordinate for tf.gather_nd.
  flat_batch_ids = tf.range(batch_size * new_beam_size) // new_beam_size
  batch_ids = tf.reshape(flat_batch_ids, [batch_size, new_beam_size])

  # Pair each batch coordinate with its beam coordinate. The result has shape
  # [batch_size, new_beam_size, 2]; the last axis holds the (i, j) pair.
  gather_coords = tf.stack([batch_ids, beam_indices], axis=2)

  def _pick(state):
    return tf.gather_nd(state, gather_coords)

  return tf.nest.map_structure(_pick, nested)
def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
  """Gather the `beam_size` highest-scoring beams from a nested structure.

  Ranks beams with tf.nn.top_k over `score_or_log_prob` and hands the
  resulting indices to `_gather_beams` to select them from `nested`.
  """
  # Only the positions of the top entries are needed, not their values.
  best_beam_indices = tf.nn.top_k(score_or_log_prob, k=beam_size).indices
  return _gather_beams(nested, best_beam_indices, batch_size, beam_size)
......@@ -54,16 +54,7 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
# [16 17 18 19]
# [20 21 22 23]]]
y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
self.assertAllEqual(
[[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
y)
def test_gather_topk_beams(self):
x = tf.reshape(tf.range(24), [2, 3, 4])
x_scores = [[0, 1, 1], [1, 0, 1]]
y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
y = beam_search.SequenceBeamSearch._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
self.assertAllEqual(
[[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
y)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment