Unverified commit c3c39f7e authored by Patrick von Platen, committed by GitHub

[Flax] Add Beam Search (#12131)



* fix_torch_device_generate_test

* remove @

* push new logit processors

* add processors

* save first working version

* save intermediate

* finish

* make style

* make fix-copies

* finish

* Update tests/test_modeling_flax_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 802ffaff
......@@ -186,6 +186,15 @@ generation.
.. autoclass:: transformers.FlaxTopKLogitsWarper
:members: __call__
.. autoclass:: transformers.FlaxForcedBOSTokenLogitsProcessor
:members: __call__
.. autoclass:: transformers.FlaxForcedEOSTokenLogitsProcessor
:members: __call__
.. autoclass:: transformers.FlaxMinLengthLogitsProcessor
:members: __call__
StoppingCriteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -1486,9 +1486,12 @@ else:
# FLAX-backed objects
if is_flax_available():
_import_structure["generation_flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
......@@ -2814,9 +2817,12 @@ if TYPE_CHECKING:
if is_flax_available():
from .generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
......
......@@ -81,16 +81,18 @@ class FlaxLogitsProcessorList(list):
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
) -> jax_xla.DeviceArray:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if len(function_args) > 3:
assert all(
arg in kwargs for arg in list(function_args.keys())[2:]
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
scores = processor(input_ids, scores, **kwargs)
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
scores = processor(input_ids, scores)
scores = processor(input_ids, scores, cur_len)
return scores
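For illustration, a minimal usage sketch of the updated signature (the shapes, token values, and processor choices below are hypothetical, not part of this diff): every processor now receives `cur_len` explicitly, and extra keyword arguments are forwarded only to processors whose `__call__` declares more than three parameters.

import jax.numpy as jnp
from transformers.generation_flax_logits_process import (
    FlaxLogitsProcessorList,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
)

# hypothetical batch of 2 sequences over a 10-token vocabulary
input_ids = jnp.ones((2, 5), dtype=jnp.int32)
scores = jnp.zeros((2, 10))

processors = FlaxLogitsProcessorList(
    [FlaxTemperatureLogitsWarper(temperature=0.7), FlaxTopKLogitsWarper(top_k=3)]
)

# cur_len is passed as a concrete argument so the whole chain can be wrapped in jax.jit
scores = processors(input_ids, scores, cur_len=5)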
......@@ -109,7 +111,9 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
self.temperature = temperature
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
scores = scores / self.temperature
return scores
......@@ -137,7 +141,9 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
mask_scores = jnp.full_like(scores, self.filter_value)
......@@ -177,7 +183,9 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
batch_size, vocab_size = scores.shape
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
......@@ -190,3 +198,94 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
return next_scores
class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
r"""
:class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the first generated token.
Args:
bos_token_id (:obj:`int`):
The id of the token to force as the first generated token.
"""
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
new_scores = jnp.full(scores.shape, -float("inf"))
apply_penalty = 1 - jnp.bool_(cur_len - 1)
scores = jnp.where(
apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores
)
return scores
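A small sketch of the forcing behaviour, assuming the JAX version contemporary with this commit (where `jax.ops.index_update` still exists); shapes and token ids are made up for the example:

import jax.numpy as jnp
from transformers.generation_flax_logits_process import FlaxForcedBOSTokenLogitsProcessor

processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=0)
input_ids = jnp.zeros((2, 1), dtype=jnp.int32)
scores = jnp.zeros((2, 5))  # uniform dummy logits

forced = processor(input_ids, scores, cur_len=1)      # apply_penalty == 1: only token 0 keeps a finite score
untouched = processor(input_ids, scores, cur_len=3)   # apply_penalty == 0: scores pass through unchanged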
class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
r"""
:class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the last generated token when
:obj:`max_length` is reached.
Args:
max_length (:obj:`int`):
The maximum length of the sequence to be generated.
eos_token_id (:obj:`int`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
"""
def __init__(self, max_length: int, eos_token_id: int):
self.max_length = max_length
self.eos_token_id = eos_token_id
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
new_scores = jnp.full(scores.shape, -float("inf"))
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
scores = jnp.where(
apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores
)
return scores
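Analogously for the EOS processor, the penalty only fires when `cur_len == max_length - 1`, i.e. when the token generated next is the final one; the values below are illustrative only:

import jax.numpy as jnp
from transformers.generation_flax_logits_process import FlaxForcedEOSTokenLogitsProcessor

processor = FlaxForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=2)
input_ids = jnp.zeros((2, 4), dtype=jnp.int32)
scores = jnp.zeros((2, 6))

forced = processor(input_ids, scores, cur_len=4)      # last position: only the EOS column stays finite
untouched = processor(input_ids, scores, cur_len=3)   # earlier steps: scores unchanged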
class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
r"""
:class:`~transformers.FlaxLogitsProcessor` enforcing a minimum length by setting the EOS probability to 0.
Args:
min_length (:obj:`int`):
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
# create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
scores = jnp.where(
apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
)
return scores
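A short sketch of the min-length processor with assumed parameters: below `min_length` the EOS column is pushed to -inf, afterwards the scores are returned as-is.

import jax.numpy as jnp
from transformers.generation_flax_logits_process import FlaxMinLengthLogitsProcessor

processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=2)
input_ids = jnp.zeros((2, 5), dtype=jnp.int32)
scores = jnp.zeros((2, 6))

blocked = processor(input_ids, scores, cur_len=5)   # below min_length: blocked[:, 2] == -inf
free = processor(input_ids, scores, cur_len=15)     # past min_length: scores unchanged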
(Collapsed file diff not shown.)
......@@ -2,6 +2,24 @@
from ..file_utils import requires_backends
class FlaxForcedBOSTokenLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxForcedEOSTokenLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......@@ -25,6 +43,15 @@ class FlaxLogitsWarper:
requires_backends(self, ["flax"])
class FlaxMinLengthLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxTemperatureLogitsWarper:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
......@@ -28,7 +28,10 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
......@@ -57,8 +60,8 @@ class LogitsProcessorTest(unittest.TestCase):
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1)
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1)
# uniform distribution stays uniform
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
......@@ -83,7 +86,7 @@ class LogitsProcessorTest(unittest.TestCase):
top_k_warp = FlaxTopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
scores = top_k_warp(input_ids, ramp_logits, cur_len=None)
# check that correct tokens are filtered
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
......@@ -94,7 +97,7 @@ class LogitsProcessorTest(unittest.TestCase):
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits)
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
......@@ -108,7 +111,7 @@ class LogitsProcessorTest(unittest.TestCase):
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p_warp = FlaxTopPLogitsWarper(0.7)
filtered_dist = np.exp(top_p_warp(input_ids, dist))
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
# dist should be filtered to keep the minimum number of values whose sum is >= 0.7
# exp (-inf) => 0
......@@ -125,15 +128,128 @@ class LogitsProcessorTest(unittest.TestCase):
# make sure at least 2 tokens are kept
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 20), vocab_size=20)
cur_len = 5
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
scores = self._get_uniform_logits(batch_size, vocab_size)
cur_len = 15
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores_before_min_length).any())
def test_forced_bos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
# check that all scores are -inf except the bos_token_id score
input_ids = ids_tensor((batch_size, 1), vocab_size=20)
cur_len = 1
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all())
self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id should be zero
# check that bos_token_id is not forced if current length is greater than 1
cur_len = 3
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores).any())
def test_forced_eos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
cur_len = 4
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all())
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length is not reached
cur_len = 3
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores).any())
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 2
bos_token_id = 1
max_length = 15
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.copy()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.copy()
# instantiate all dist processors
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
cur_len = 10
# no processor list
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
# with processor list
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# scores should be equal
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
def test_processor_list_jitted(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 2
bos_token_id = 1
max_length = 15
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
......@@ -147,14 +263,36 @@ class LogitsProcessorTest(unittest.TestCase):
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
cur_len = 10
# no processor list
scores = temp_dist_warp(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
def run_no_processor_list(input_ids, scores, cur_len):
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
return scores
# with processor list
processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
scores_comp = processor(input_ids, scores_comp)
def run_processor_list(input_ids, scores, cur_len):
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
)
scores = processor(input_ids, scores, cur_len=cur_len)
return scores
jitted_run_no_processor_list = jax.jit(run_no_processor_list)
jitted_run_processor_list = jax.jit(run_processor_list)
scores = jitted_run_no_processor_list(input_ids, scores, cur_len)
scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len)
# scores should be equal
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
......
......@@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin:
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
config.num_beams = 2
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
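Outside the test mixin, the new beam search path is selected by `generate` whenever `num_beams > 1` and `do_sample=False`. A hedged end-to-end sketch follows; the checkpoint name and generation settings are assumptions, not part of this commit:

from transformers import BartTokenizer, FlaxBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # hypothetical checkpoint
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("Beam search now works in Flax", return_tensors="np")

# num_beams > 1 with do_sample=False routes generation through the new Flax beam search
outputs = model.generate(inputs["input_ids"], num_beams=2, max_length=20, do_sample=False)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))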
def test_sample_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
......@@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin:
config.temperature = 0.8
config.top_k = 10
config.top_p = 0.3
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_greedy_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.num_beams = 2
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
......@@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.num_beams = 2
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
(Collapsed file diff not shown.)