Commit edfe5df3 authored by A. Unique TensorFlower
Browse files

Allow passing an optional name for the decoding loop tensors.

PiperOrigin-RevId: 455384659
parent 9a783f15
......@@ -107,18 +107,18 @@ class SequenceBeamSearch(tf.Module):
max_decode_length,
eos_id,
padded_decode,
dtype=tf.float32):
dtype=tf.float32,
decoding_name=None):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are: ids ->
A tensor with shape [batch_size * beam_size, index]. index -> A
scalar. cache -> A nested dictionary of tensors [batch_size *
beam_size, ...].
The function must return a tuple of logits and the updated cache: logits
-> A tensor with shape [batch * beam_size, vocab_size]. updated cache
-> A nested dictionary with the same structure as the input cache.
symbols_to_logits_fn: A function to provide logits, which is the interface
to the Transformer model. The passed in arguments are: ids -> A tensor
with shape [batch_size * beam_size, index]. index -> A scalar. cache ->
A nested dictionary of tensors [batch_size * beam_size, ...]. The
function must return a tuple of logits and the updated cache: logits ->
A tensor with shape [batch * beam_size, vocab_size]. updated cache -> A
nested dictionary with the same structure as the input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
beam_size: An integer, number of beams for beam search.
......@@ -130,6 +130,7 @@ class SequenceBeamSearch(tf.Module):
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
"""
self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size
......@@ -139,6 +140,7 @@ class SequenceBeamSearch(tf.Module):
self.eos_id = eos_id
self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype)
self.decoding_name = decoding_name
def search(self, initial_ids, initial_cache):
"""Beam search for sequences with highest scores.
......@@ -370,7 +372,8 @@ class SequenceBeamSearch(tf.Module):
_search_step,
loop_vars=[state],
shape_invariants=[state_shapes],
parallel_iterations=1))
parallel_iterations=1,
name=self.decoding_name))
finished_state = finished_state[0]
return self._process_finished_state(finished_state)
......@@ -587,7 +590,8 @@ def sequence_beam_search(symbols_to_logits_fn,
max_decode_length,
eos_id,
padded_decode=False,
dtype="float32"):
dtype="float32",
decoding_name=None):
"""Search for sequence of subtoken ids with the largest probability.
Args:
......@@ -612,13 +616,15 @@ def sequence_beam_search(symbols_to_logits_fn,
beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, beam_size, alpha,
max_decode_length, eos_id, padded_decode, dtype)
max_decode_length, eos_id, padded_decode, dtype,
decoding_name)
return sbs.search(initial_ids, initial_cache)
......
......@@ -60,10 +60,12 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
y)
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
('padded_decode_true_with_name', True, 'decoding'),
('padded_decode_false_with_name', False, 'decoding'),
('padded_decode_true_without_name', True, None),
('padded_decode_false_without_name', False, None),
])
def test_sequence_beam_search(self, padded_decode):
def test_sequence_beam_search(self, padded_decode, name):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
[0.1, 0.8, 0.1]],
......@@ -91,7 +93,8 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
max_decode_length=3,
eos_id=9,
padded_decode=padded_decode,
dtype=tf.float32)
dtype=tf.float32,
decoding_name=name)
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
......
......@@ -15,7 +15,7 @@
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import abc
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
import tensorflow as tf
......@@ -108,7 +108,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self,
length_normalization_fn: Callable[[int, tf.DType], float],
dtype: tf.DType = tf.float32):
dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None):
"""Initialize the Decoding Module.
Args:
......@@ -116,9 +117,11 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
parameter. Function accepts input as length, dtype and returns float.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
"""
self.length_normalization_fn = length_normalization_fn
self.dtype = tf.as_dtype(dtype)
self.decoding_name = decoding_name
def generate(self,
initial_ids: tf.Tensor,
......@@ -169,7 +172,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
_generate_step,
loop_vars=[state],
shape_invariants=[state_shapes],
parallel_iterations=1))
parallel_iterations=1,
name=self.decoding_name))
final_state = self._process_finished_state(finished_state[0])
return final_state
......@@ -277,6 +281,3 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
return dtypes.float16.max
else:
raise AssertionError("Invalid dtype: %s" % self.dtype)
......@@ -162,7 +162,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
top_p=1.0,
sample_temperature=0.0,
enable_greedy: bool = True,
dtype: tf.DType = tf.float32):
dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None):
"""Initialize sampling module."""
self.symbols_to_logits_fn = symbols_to_logits_fn
self.length_normalization_fn = length_normalization_fn
......@@ -176,8 +177,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self.sample_temperature = tf.convert_to_tensor(
sample_temperature, dtype=tf.float32)
self.enable_greedy = enable_greedy
self.decoding_name = decoding_name
super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype)
length_normalization_fn=length_normalization_fn,
dtype=dtype,
decoding_name=decoding_name)
def _grow_alive_seq(self,
state: Dict[str, Any],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment