Commit b15a86fc authored by Hongkun Yu, committed by A. Unique TensorFlower

Make two beam search util functions public. They are helpful.

PiperOrigin-RevId: 341953641
parent 026367f1
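
For context, a minimal usage sketch of the two newly public helpers follows. The import path is an assumption (only the module name beam_search appears in the test diff below); the shapes follow the docstrings in the diff.

import tensorflow as tf
from official.nlp.modeling.ops import beam_search  # import path assumed

# A batch of 4 partial sequences of length 10: [batch_size, length].
ids = tf.zeros([4, 10], tf.int32)

# expand_to_beam_size tiles in a beam axis at position 1:
# [4, 10] -> [4, 3, 10].
tiled = beam_search.expand_to_beam_size(ids, 3)

# flatten_beam_dim merges batch and beam into one leading axis, so the
# model can score every beam as an ordinary batch entry:
# [4, 3, 10] -> [12, 10].
flat = beam_search.flatten_beam_dim(tiled)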
@@ -188,8 +188,8 @@ class SequenceBeamSearch(tf.Module):
             tf.slice(alive_seq, [0, 0, i], [batch_size, self.beam_size, 1]),
             [batch_size * self.beam_size, -1])
       else:
-        flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
-      flat_cache = tf.nest.map_structure(_flatten_beam_dim, alive_cache)
+        flat_ids = flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
+      flat_cache = tf.nest.map_structure(flatten_beam_dim, alive_cache)

       flat_logits, flat_cache = self.symbols_to_logits_fn(
           flat_ids, i, flat_cache)
@@ -404,7 +404,7 @@ class SequenceBeamSearch(tf.Module):
     cur_index = tf.constant(0)

     # Create alive sequence with shape [batch_size, beam_size, 1]
-    alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
+    alive_seq = expand_to_beam_size(initial_ids, self.beam_size)
     alive_seq = tf.expand_dims(alive_seq, axis=2)
     if self.padded_decode:
       alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
@@ -419,7 +419,7 @@ class SequenceBeamSearch(tf.Module):
     # Expand all values stored in the dictionary to the beam size, so that each
     # beam has a separate cache.
     alive_cache = tf.nest.map_structure(
-        lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache)
+        lambda t: expand_to_beam_size(t, self.beam_size), initial_cache)

     # Initialize tensor storing finished sequences with filler values.
     finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
@@ -588,7 +588,7 @@ def _length_normalization(alpha, length, dtype=tf.float32):
   return tf.pow(((5. + tf.cast(length, dtype)) / 6.), alpha)


-def _expand_to_beam_size(tensor, beam_size):
+def expand_to_beam_size(tensor, beam_size):
   """Tiles a given tensor by beam_size.

   Args:
@@ -605,6 +605,21 @@ def _expand_to_beam_size(tensor, beam_size):
   return tf.tile(tensor, tile_dims)


+def flatten_beam_dim(tensor):
+  """Reshapes first two dimensions into a single dimension.
+
+  Args:
+    tensor: Tensor to reshape of shape [A, B, ...]
+
+  Returns:
+    Reshaped tensor of shape [A*B, ...]
+  """
+  shape = _shape_list(tensor)
+  shape[0] *= shape[1]
+  shape.pop(1)  # Remove beam dim
+  return tf.reshape(tensor, shape)
+
+
 def _shape_list(tensor):
   """Return a list of the tensor's shape, and ensure no None values in list."""
   # Get statically known shape (may contain None's for unknown dimensions)
@@ -630,21 +645,6 @@ def _get_shape_keep_last_dim(tensor):
   return tf.TensorShape(shape_list)


-def _flatten_beam_dim(tensor):
-  """Reshapes first two dimensions in to single dimension.
-
-  Args:
-    tensor: Tensor to reshape of shape [A, B, ...]
-
-  Returns:
-    Reshaped tensor of shape [A*B, ...]
-  """
-  shape = _shape_list(tensor)
-  shape[0] *= shape[1]
-  shape.pop(1)  # Remove beam dim
-  return tf.reshape(tensor, shape)
-
-
 def _unflatten_beam_dim(tensor, batch_size, beam_size):
   """Reshapes first dimension back to [batch_size, beam_size].
...
@@ -24,7 +24,7 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
   def test_expand_to_beam_size(self):
     x = tf.ones([7, 4, 2, 5])
-    x = beam_search._expand_to_beam_size(x, 3)
+    x = beam_search.expand_to_beam_size(x, 3)
     shape = tf.shape(x)
     self.assertAllEqual([7, 3, 4, 2, 5], shape)
@@ -36,7 +36,7 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
   def test_flatten_beam_dim(self):
     x = tf.ones([7, 4, 2, 5])
-    x = beam_search._flatten_beam_dim(x)
+    x = beam_search.flatten_beam_dim(x)
     self.assertAllEqual([28, 2, 5], tf.shape(x))

   def test_unflatten_beam_dim(self):
...
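
As the first hunk shows, the decoding loop flattens alive_seq and the cache so that symbols_to_logits_fn sees a plain batch of size batch_size * beam_size; the inverse reshape, _unflatten_beam_dim, stays private in this commit. A hedged sketch of that round trip, using a stand-in unflatten helper rather than the private one:

import tensorflow as tf

def unflatten(tensor, batch_size, beam_size):
  # Stand-in for the private _unflatten_beam_dim: [A*B, ...] -> [A, B, ...].
  shape = tf.shape(tensor)
  return tf.reshape(
      tensor, tf.concat([[batch_size, beam_size], shape[1:]], axis=0))

x = tf.ones([4, 3, 2, 5])          # [batch_size, beam_size, ...]
flat = tf.reshape(x, [12, 2, 5])   # what flatten_beam_dim produces
back = unflatten(flat, 4, 3)       # restores [4, 3, 2, 5]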