Commit 5b9feb6d authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 332953126
parent 20a0716c
......@@ -14,12 +14,13 @@
# ==============================================================================
"""Test beam search helper methods."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.modeling.ops import beam_search
class BeamSearchHelperTests(tf.test.TestCase):
class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):
def test_expand_to_beam_size(self):
x = tf.ones([7, 4, 2, 5])
......@@ -67,6 +68,41 @@ class BeamSearchHelperTests(tf.test.TestCase):
[[[4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]],
y)
@parameterized.named_parameters([
('padded_decode_true', True),
('padded_decode_false', False),
])
def test_sequence_beam_search(self, padded_decode):
# batch_size*beam_size, max_decode_length, vocab_size
probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.5, 0.3, 0.2],
[0.1, 0.8, 0.1]],
[[0.1, 0.8, 0.1], [0.3, 0.4, 0.3],
[0.2, 0.1, 0.7]]])
# batch_size, max_decode_length, num_heads, embed_size per head
x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
if __name__ == "__main__":
def _get_test_symbols_to_logits_fn():
"""Test function that returns logits for next token."""
def symbols_to_logits_fn(_, i, cache):
logits = tf.cast(probabilities[:, i, :], tf.float32)
return logits, cache
return symbols_to_logits_fn
predictions, _ = beam_search.sequence_beam_search(
symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
initial_ids=tf.zeros([1], dtype=tf.int32),
initial_cache=cache,
vocab_size=3,
beam_size=2,
alpha=0.6,
max_decode_length=3,
eos_id=9,
padded_decode=padded_decode,
dtype=tf.float32)
self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment