"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "94306352f489c7c2a8dc18af89e2efe0a76a5159"
Unverified commit 367fdf33, authored by Konstantin Kotik, committed by GitHub

`MinNewTokensLengthLogitsProcessor` for `.generate` method #20814 (#20892)



* feat: add min new length logit processor

* test: add min new length logit processor

* docs: add MinNewTokensLengthLogitsProcessor

* feat: import MinNewTokensLengthLogitsProcessor

* fix: update pytorch dummy objects

* refactor & fix: rename attributes and var and get rid of dynamic attribute

* tests: align test with new interface

* docs: fix typo

* docs: minor clarification

* Empty-Commit

* empty commit

* run automated quality edits
Co-authored-by: Joao Gante <joao@huggingface.co>
parent 4fd89e49
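
The commit adds a logits processor that blocks the EOS token until a minimum number of *new* (non-prompt) tokens has been generated. Below is a minimal usage sketch, not part of the diff, showing how the processor can be passed to `.generate` via `logits_processor`; the `gpt2` checkpoint, prompt, and `min_new_tokens=10` value are illustrative assumptions.

```python
# Illustrative sketch (not part of this commit): force at least 10 freshly
# generated tokens before EOS may end the sequence. The "gpt2" checkpoint
# and the prompt are arbitrary choices.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import LogitsProcessorList, MinNewTokensLengthLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("A short prompt", return_tensors="pt").input_ids

logits_processor = LogitsProcessorList(
    [
        MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=input_ids.shape[-1],  # every prompt token is skipped
            min_new_tokens=10,                          # EOS is masked until 10 new tokens exist
            eos_token_id=model.config.eos_token_id,
        )
    ]
)

output = model.generate(input_ids, logits_processor=logits_processor, max_new_tokens=30)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```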
@@ -116,6 +116,9 @@ generation.
[[autodoc]] MinLengthLogitsProcessor
    - __call__
+[[autodoc]] MinNewTokensLengthLogitsProcessor
+    - __call__
[[autodoc]] TemperatureLogitsWarper
    - __call__
...
@@ -886,6 +886,7 @@ else:
"MaxLengthCriteria",
"MaxTimeCriteria",
"MinLengthLogitsProcessor",
+"MinNewTokensLengthLogitsProcessor",
"NoBadWordsLogitsProcessor",
"NoRepeatNGramLogitsProcessor",
"PhrasalConstraint",
...
@@ -4140,6 +4141,7 @@ if TYPE_CHECKING:
MaxLengthCriteria,
MaxTimeCriteria,
MinLengthLogitsProcessor,
+MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PhrasalConstraint,
...
@@ -51,6 +51,7 @@ else:
"LogitsProcessorList",
"LogitsWarper",
"MinLengthLogitsProcessor",
+"MinNewTokensLengthLogitsProcessor",
"NoBadWordsLogitsProcessor",
"NoRepeatNGramLogitsProcessor",
"PrefixConstrainedLogitsProcessor",
...
@@ -171,6 +172,7 @@ if TYPE_CHECKING:
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
+MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
...
@@ -121,6 +121,42 @@ class MinLengthLogitsProcessor(LogitsProcessor):
        return scores

+class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
+
+    Args:
+        prompt_length_to_skip (`int`):
+            The input tokens length.
+        min_new_tokens (`int`):
+            The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
+        eos_token_id (`int`):
+            The id of the *end-of-sequence* token.
+    """
+
+    def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
+        for arg_name, arg_value in [
+            ("prompt_length_to_skip", prompt_length_to_skip),
+            ("min_new_tokens", min_new_tokens),
+            ("eos_token_id", eos_token_id),
+        ]:
+            if not isinstance(arg_value, int) or arg_value < 0:
+                raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
+
+        self.prompt_length_to_skip = prompt_length_to_skip
+        self.min_new_tokens = min_new_tokens
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
+        if new_tokens_length < self.min_new_tokens:
+            scores[:, self.eos_token_id] = -float("inf")
+
+        return scores
+
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
...
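
For a quick sense of the processor's behaviour in isolation, here is a hedged sketch (arbitrary shapes and values, separate from the unit test added further down): the EOS column stays at `-inf` until `min_new_tokens` tokens have been generated past the prompt.

```python
# Illustrative sketch only: vocab size, batch size, and token values are arbitrary.
import torch

from transformers.generation import MinNewTokensLengthLogitsProcessor

eos_token_id = 0
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=5, min_new_tokens=3, eos_token_id=eos_token_id
)

# 5 prompt tokens + 1 newly generated token -> only 1 new token so far, EOS is masked
input_ids = torch.randint(1, 8, (2, 6))
scores = torch.zeros(2, 8)  # dummy logits over a vocab of 8
print(processor(input_ids, scores)[:, eos_token_id])  # tensor([-inf, -inf])

# 5 prompt tokens + 3 new tokens -> constraint satisfied, EOS scores are left untouched
input_ids = torch.randint(1, 8, (2, 8))
scores = torch.zeros(2, 8)
print(processor(input_ids, scores)[:, eos_token_id])  # tensor([0., 0.])
```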
@@ -199,6 +199,13 @@ class MinLengthLogitsProcessor(metaclass=DummyObject):
        requires_backends(self, ["torch"])

+class MinNewTokensLengthLogitsProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
class NoBadWordsLogitsProcessor(metaclass=DummyObject):
    _backends = ["torch"]
...
@@ -36,6 +36,7 @@ if is_torch_available():
        LogitNormalization,
        LogitsProcessorList,
        MinLengthLogitsProcessor,
+        MinNewTokensLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
        PrefixConstrainedLogitsProcessor,
...
@@ -72,6 +73,54 @@ class LogitsProcessorTest(unittest.TestCase):
        scores_before_min_length = min_dist_processor(input_ids, scores)
        self.assertFalse(torch.isinf(scores_before_min_length).any())

+    def test_new_min_length_dist_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+
+        # check that a prompt-only input is handled: no new tokens yet, so the min new-tokens length applies
+        input_ids = ids_tensor((batch_size, 5), vocab_size=20)
+        new_min_dist_processor = MinNewTokensLengthLogitsProcessor(
+            prompt_length_to_skip=input_ids.shape[-1], min_new_tokens=3, eos_token_id=eos_token_id
+        )
+
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = new_min_dist_processor(input_ids, scores)
+        self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
+
+        # the stored prompt length to skip is 5, so the first 5 tokens of every sequence count as prompt
+        self.assertTrue(new_min_dist_processor.prompt_length_to_skip == 5)
+
+        # check that the min new-tokens length is applied at length 2
+        input_ids = ids_tensor((batch_size, 2), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = new_min_dist_processor(input_ids, scores)
+        self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
+
+        # check that the min new-tokens length is applied at length 6 (only 1 new token so far)
+        input_ids = ids_tensor((batch_size, 6), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = new_min_dist_processor(input_ids, scores)
+        self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
+
+        # check that the min new-tokens length is applied at length 7 (only 2 new tokens so far)
+        input_ids = ids_tensor((batch_size, 7), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = new_min_dist_processor(input_ids, scores)
+        self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), batch_size * [-float("inf")])
+
+        # check that the min new-tokens length is no longer applied at length 8
+        input_ids = ids_tensor((batch_size, 8), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = new_min_dist_processor(input_ids, scores)
+        self.assertFalse(torch.isinf(scores_before_min_length).any())
+
+        # check that the min new-tokens length is no longer applied at length 15
+        input_ids = ids_tensor((batch_size, 15), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = new_min_dist_processor(input_ids, scores)
+        self.assertFalse(torch.isinf(scores_before_min_length).any())
+
    def test_temperature_dist_warper(self):
        input_ids = None
        length = 20
...