Commit a15ebc46 authored by A. Unique TensorFlower
Browse files

Allow passing an optional name for the decoding loop tensors.

PiperOrigin-RevId: 455384659
parent 3c77e654
...@@ -107,18 +107,18 @@ class SequenceBeamSearch(tf.Module): ...@@ -107,18 +107,18 @@ class SequenceBeamSearch(tf.Module):
max_decode_length, max_decode_length,
eos_id, eos_id,
padded_decode, padded_decode,
dtype=tf.float32): dtype=tf.float32,
decoding_name=None):
"""Initialize sequence beam search. """Initialize sequence beam search.
Args: Args:
symbols_to_logits_fn: A function to provide logits, which is the symbols_to_logits_fn: A function to provide logits, which is the interface
interface to the Transformer model. The passed in arguments are: ids -> to the Transformer model. The passed in arguments are: ids -> A tensor
A tensor with shape [batch_size * beam_size, index]. index -> A with shape [batch_size * beam_size, index]. index -> A scalar. cache ->
scalar. cache -> A nested dictionary of tensors [batch_size * A nested dictionary of tensors [batch_size * beam_size, ...]. The
beam_size, ...]. function must return a tuple of logits and the updated cache: logits ->
The function must return a tuple of logits and the updated cache: logits A tensor with shape [batch * beam_size, vocab_size]. updated cache -> A
-> A tensor with shape [batch * beam_size, vocab_size]. updated cache nested dictionary with the same structure as the input cache.
-> A nested dictionary with the same structure as the input cache.
vocab_size: An integer, the size of the vocabulary, used for topk vocab_size: An integer, the size of the vocabulary, used for topk
computation. computation.
beam_size: An integer, number of beams for beam search. beam_size: An integer, number of beams for beam search.
...@@ -130,6 +130,7 @@ class SequenceBeamSearch(tf.Module): ...@@ -130,6 +130,7 @@ class SequenceBeamSearch(tf.Module):
for beam search. for beam search.
dtype: A tensorflow data type used for score computation. The default is dtype: A tensorflow data type used for score computation. The default is
tf.float32. tf.float32.
decoding_name: an optional name for the decoding loop tensors.
""" """
self.symbols_to_logits_fn = symbols_to_logits_fn self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -139,6 +140,7 @@ class SequenceBeamSearch(tf.Module): ...@@ -139,6 +140,7 @@ class SequenceBeamSearch(tf.Module):
self.eos_id = eos_id self.eos_id = eos_id
self.padded_decode = padded_decode self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype) self.dtype = tf.as_dtype(dtype)
self.decoding_name = decoding_name
def search(self, initial_ids, initial_cache): def search(self, initial_ids, initial_cache):
"""Beam search for sequences with highest scores. """Beam search for sequences with highest scores.
...@@ -370,7 +372,8 @@ class SequenceBeamSearch(tf.Module): ...@@ -370,7 +372,8 @@ class SequenceBeamSearch(tf.Module):
_search_step, _search_step,
loop_vars=[state], loop_vars=[state],
shape_invariants=[state_shapes], shape_invariants=[state_shapes],
parallel_iterations=1)) parallel_iterations=1,
name=self.decoding_name))
finished_state = finished_state[0] finished_state = finished_state[0]
return self._process_finished_state(finished_state) return self._process_finished_state(finished_state)
...@@ -587,7 +590,8 @@ def sequence_beam_search(symbols_to_logits_fn, ...@@ -587,7 +590,8 @@ def sequence_beam_search(symbols_to_logits_fn,
max_decode_length, max_decode_length,
eos_id, eos_id,
padded_decode=False, padded_decode=False,
dtype="float32"): dtype="float32",
decoding_name=None):
"""Search for sequence of subtoken ids with the largest probability. """Search for sequence of subtoken ids with the largest probability.
Args: Args:
...@@ -612,13 +616,15 @@ def sequence_beam_search(symbols_to_logits_fn, ...@@ -612,13 +616,15 @@ def sequence_beam_search(symbols_to_logits_fn,
beam search. beam search.
dtype: A tensorflow data type used for score computation. The default is dtype: A tensorflow data type used for score computation. The default is
tf.float32. tf.float32.
decoding_name: an optional name for the decoding loop tensors.
Returns: Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length] Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size] sequence scores [batch_size, beam_size]
""" """
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, beam_size, alpha, sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, beam_size, alpha,
max_decode_length, eos_id, padded_decode, dtype) max_decode_length, eos_id, padded_decode, dtype,
decoding_name)
return sbs.search(initial_ids, initial_cache) return sbs.search(initial_ids, initial_cache)
......
...@@ -60,10 +60,12 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase): ...@@ -60,10 +60,12 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
y) y)
@parameterized.named_parameters([ @parameterized.named_parameters([
('padded_decode_true', True), ('padded_decode_true_with_name', True, 'decoding'),
('padded_decode_false', False), ('padded_decode_false_with_name', False, 'decoding'),
('padded_decode_true_without_name', True, None),
('padded_decode_false_without_name', False, None),
]) ])
def test_sequence_beam_search(self, padded_decode): def test_sequence_beam_search(self, padded_decode, name):
# batch_size*beam_size, max_decode_length, vocab_size # batch_size*beam_size, max_decode_length, vocab_size
probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2], probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
[0.1, 0.8, 0.1]], [0.1, 0.8, 0.1]],
...@@ -91,7 +93,8 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase): ...@@ -91,7 +93,8 @@ class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
max_decode_length=3, max_decode_length=3,
eos_id=9, eos_id=9,
padded_decode=padded_decode, padded_decode=padded_decode,
dtype=tf.float32) dtype=tf.float32,
decoding_name=name)
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions) self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy).""" """Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import abc import abc
from typing import Any, Callable, Dict, Tuple from typing import Any, Callable, Dict, Optional, Tuple
import tensorflow as tf import tensorflow as tf
...@@ -108,7 +108,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -108,7 +108,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
length_normalization_fn: Callable[[int, tf.DType], float], length_normalization_fn: Callable[[int, tf.DType], float],
dtype: tf.DType = tf.float32): dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None):
"""Initialize the Decoding Module. """Initialize the Decoding Module.
Args: Args:
...@@ -116,9 +117,11 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -116,9 +117,11 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
parameter. Function accepts input as length, dtype and returns float. parameter. Function accepts input as length, dtype and returns float.
dtype: A tensorflow data type used for score computation. The default is dtype: A tensorflow data type used for score computation. The default is
tf.float32. tf.float32.
decoding_name: an optional name for the decoding loop tensors.
""" """
self.length_normalization_fn = length_normalization_fn self.length_normalization_fn = length_normalization_fn
self.dtype = tf.as_dtype(dtype) self.dtype = tf.as_dtype(dtype)
self.decoding_name = decoding_name
def generate(self, def generate(self,
initial_ids: tf.Tensor, initial_ids: tf.Tensor,
...@@ -169,7 +172,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -169,7 +172,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
_generate_step, _generate_step,
loop_vars=[state], loop_vars=[state],
shape_invariants=[state_shapes], shape_invariants=[state_shapes],
parallel_iterations=1)) parallel_iterations=1,
name=self.decoding_name))
final_state = self._process_finished_state(finished_state[0]) final_state = self._process_finished_state(finished_state[0])
return final_state return final_state
...@@ -277,6 +281,3 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta): ...@@ -277,6 +281,3 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
return dtypes.float16.max return dtypes.float16.max
else: else:
raise AssertionError("Invalid dtype: %s" % self.dtype) raise AssertionError("Invalid dtype: %s" % self.dtype)
...@@ -162,7 +162,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -162,7 +162,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
top_p=1.0, top_p=1.0,
sample_temperature=0.0, sample_temperature=0.0,
enable_greedy: bool = True, enable_greedy: bool = True,
dtype: tf.DType = tf.float32): dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None):
"""Initialize sampling module.""" """Initialize sampling module."""
self.symbols_to_logits_fn = symbols_to_logits_fn self.symbols_to_logits_fn = symbols_to_logits_fn
self.length_normalization_fn = length_normalization_fn self.length_normalization_fn = length_normalization_fn
...@@ -176,8 +177,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta): ...@@ -176,8 +177,11 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self.sample_temperature = tf.convert_to_tensor( self.sample_temperature = tf.convert_to_tensor(
sample_temperature, dtype=tf.float32) sample_temperature, dtype=tf.float32)
self.enable_greedy = enable_greedy self.enable_greedy = enable_greedy
self.decoding_name = decoding_name
super(SamplingModule, self).__init__( super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype) length_normalization_fn=length_normalization_fn,
dtype=dtype,
decoding_name=decoding_name)
def _grow_alive_seq(self, def _grow_alive_seq(self,
state: Dict[str, Any], state: Dict[str, Any],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment