Unverified Commit 9442b3ce authored by Kevin Bondzio, committed by GitHub

Add soft length regulation for sequence generation (#15245)



* add possibility to softly regulate length when using sampling method in model.generate() function

* fix test config, fix formatting

* fix rag integration, fix docstyling

* fix wrong docstring

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* change test according to new param

* fix formatting

* fix test case

* fix doc style

* move start_length calculation to LogitsProcessor

* remove unused import

* fix small errors

* fix test

* Update src/transformers/generation_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/generation_utils.py

* Update src/transformers/generation_utils.py

* fix docstring, add type in model rag

* fix docstrings

* introduce seq_length variable for cleaner code

* fix black formatting

* add input_ids_seq_length to modeling_rag

* add input_ids_seq_length to test

* retrigger checks

* retrigger checks
Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.local>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.fritz.box>
parent 322c8533
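
For context, here is a minimal usage sketch of the feature this commit adds. The model, prompt, and penalty values are illustrative, not part of the commit; only the `exponential_decay_length_penalty` argument comes from this change.

```python
# Soft length regulation: once `start_index` tokens have been generated, the
# eos_token_id score is multiplied by decay_factor ** (steps past the start),
# nudging sampling toward ending the sequence without hard-truncating it.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The secret to a good espresso is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    max_length=64,
    exponential_decay_length_penalty=(10, 1.05),  # (start_index, decay_factor)
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```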
@@ -295,6 +295,7 @@ class PretrainedConfig(PushToHubMixin):
         self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
         self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
         self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
+        self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional, Tuple

 import numpy as np
 import torch
@@ -647,3 +647,32 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
         scores[scores == float("inf")] = torch.finfo(scores.dtype).max
         return scores
+
+
+class ExponentialDecayLengthPenalty(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] that exponentially increases the score of the `eos_token_id` after `regulation_start` has
+    been reached.
+
+    Args:
+        exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
+            This tuple shall consist of `(start_index, decay_factor)`, where `start_index` indicates where the
+            penalty starts and `decay_factor` represents the factor of exponential decay.
+        eos_token_id (`int`):
+            The id of the *end-of-sequence* token.
+        input_ids_seq_length (`int`):
+            The length of the input sequence.
+    """
+
+    def __init__(self, exponential_decay_length_penalty: Tuple[int, float], eos_token_id: int, input_ids_seq_length: int):
+        self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
+        self.regulation_factor = exponential_decay_length_penalty[1]
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        cur_len = input_ids.shape[-1]
+        if cur_len > self.regulation_start:
+            scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
+                self.regulation_factor, cur_len - self.regulation_start
+            )
+        return scores
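
To make the arithmetic concrete, here is a small self-contained sketch of the processor above (all values are illustrative): with a prompt of length 2 and a start index of 5, `regulation_start` is 7, so once more than 7 tokens exist the eos score is scaled by `decay_factor ** (cur_len - 7)`.

```python
import torch

from transformers.generation_logits_process import ExponentialDecayLengthPenalty

processor = ExponentialDecayLengthPenalty(
    exponential_decay_length_penalty=(5, 1.1),  # start after 5 generated tokens, factor 1.1
    eos_token_id=0,
    input_ids_seq_length=2,  # prompt length, so regulation_start = 5 + 2 = 7
)

input_ids = torch.randint(1, 20, (1, 10))  # 10 tokens so far, i.e. past regulation_start
scores = torch.zeros(1, 20)
scores[0, 0] = 1.0  # pre-penalty eos score

out = processor(input_ids, scores)
print(out[0, 0].item())  # 1.0 * 1.1 ** (10 - 7) ≈ 1.331
```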
@@ -28,6 +28,7 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
 from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .generation_logits_process import (
     EncoderNoRepeatNGramLogitsProcessor,
+    ExponentialDecayLengthPenalty,
     ForcedBOSTokenLogitsProcessor,
     ForcedEOSTokenLogitsProcessor,
     HammingDiversityLogitsProcessor,
@@ -667,6 +668,7 @@ class GenerationMixin:
         repetition_penalty: float,
         no_repeat_ngram_size: int,
         encoder_no_repeat_ngram_size: int,
+        input_ids_seq_length: int,
         encoder_input_ids: torch.LongTensor,
         bad_words_ids: List[List[int]],
         min_length: int,
@@ -679,6 +681,7 @@
         num_beam_groups: int,
         diversity_penalty: float,
         remove_invalid_values: bool,
+        exponential_decay_length_penalty: Tuple[int, float],
         logits_processor: Optional[LogitsProcessorList],
     ) -> LogitsProcessorList:
         """
@@ -710,6 +713,11 @@
         remove_invalid_values = (
             remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
         )
+        exponential_decay_length_penalty = (
+            exponential_decay_length_penalty
+            if exponential_decay_length_penalty is not None
+            else self.config.exponential_decay_length_penalty
+        )

         # instantiate processors list
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
@@ -743,6 +751,10 @@
             processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
         if remove_invalid_values is True:
             processors.append(InfNanRemoveLogitsProcessor())
+        if exponential_decay_length_penalty is not None:
+            processors.append(
+                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
+            )
         processors = self._merge_criteria_processor_list(processors, logits_processor)
         return processors
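
For reference, the list returned here is a `LogitsProcessorList`, which is itself callable and applies each processor to the scores in order at every decoding step. A simplified sketch of that chaining, with illustrative values:

```python
import torch

from transformers.generation_logits_process import (
    ExponentialDecayLengthPenalty,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=5, eos_token_id=0),
        ExponentialDecayLengthPenalty((3, 1.2), eos_token_id=0, input_ids_seq_length=1),
    ]
)

input_ids = torch.randint(1, 20, (1, 8))  # 8 tokens: past min_length and past the decay start
scores = torch.ones(1, 20)
scores = processors(input_ids, scores)  # each processor rewrites `scores` in turn
print(scores[0, 0].item())  # eos scaled by 1.2 ** (8 - 4) ≈ 2.0736
```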
@@ -858,6 +870,7 @@
         forced_eos_token_id: Optional[int] = None,
         remove_invalid_values: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
+        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
         **model_kwargs,
     ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
         r"""
@@ -1003,6 +1016,11 @@
                 crash. Note that using `remove_invalid_values` can slow down generation.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
+                This tuple adds an exponentially increasing length penalty after a certain number of tokens have
+                been generated. The tuple shall consist of `(start_index, decay_factor)`, where `start_index`
+                indicates where the penalty starts and `decay_factor` represents the factor of exponential decay.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                 is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
@@ -1152,10 +1170,12 @@
             # if decoder-only then inputs_tensor has to be `input_ids`
             input_ids = inputs_tensor

+        input_ids_seq_length = input_ids.shape[-1]
+
         # 5. Prepare `max_length` depending on other stopping criteria
         # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
         if max_length is None and max_new_tokens is not None:
-            max_length = max_new_tokens + input_ids.shape[-1]
+            max_length = max_new_tokens + input_ids_seq_length
         elif max_length is not None and max_new_tokens is not None:
             # Both are set, this is odd, raise a warning
             warnings.warn(
@@ -1167,10 +1187,10 @@
         # default to config if still None
         max_length = max_length if max_length is not None else self.config.max_length

-        if input_ids.shape[-1] >= max_length:
+        if input_ids_seq_length >= max_length:
             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
-                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}. "
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. "
                 "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
             )
@@ -1202,6 +1222,7 @@
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
             encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            input_ids_seq_length=input_ids_seq_length,
             encoder_input_ids=inputs_tensor,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
@@ -1214,6 +1235,7 @@
             num_beam_groups=num_beam_groups,
             diversity_penalty=diversity_penalty,
             remove_invalid_values=remove_invalid_values,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
             logits_processor=logits_processor,
         )
@@ -15,7 +15,7 @@
 """RAG model implementation."""

 from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -1405,6 +1405,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         forced_bos_token_id: Optional[int] = None,
         forced_eos_token_id: Optional[int] = None,
         remove_invalid_values: Optional[bool] = None,
+        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
         **model_kwargs
     ):
         """
@@ -1534,6 +1535,11 @@
         remove_invalid_values = (
             remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
         )
+        exponential_decay_length_penalty = (
+            exponential_decay_length_penalty
+            if exponential_decay_length_penalty is not None
+            else self.config.exponential_decay_length_penalty
+        )

         # retrieve docs
         if self.retriever is not None and context_input_ids is None:
@@ -1577,6 +1583,7 @@
             dtype=torch.long,
             device=next(self.parameters()).device,
         )
+        input_ids_seq_length = input_ids.shape[-1]

         last_hidden_state = encoder_outputs["last_hidden_state"]

         def extend_enc_output(tensor, num_beams=None):
@@ -1603,6 +1610,7 @@
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
             encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            input_ids_seq_length=input_ids_seq_length,
             encoder_input_ids=context_input_ids,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
@@ -1615,6 +1623,7 @@
             num_beam_groups=num_beam_groups,
             diversity_penalty=diversity_penalty,
             remove_invalid_values=remove_invalid_values,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
             logits_processor=logits_processor,
         )
@@ -28,6 +28,7 @@ if is_torch_available():
     from transformers.generation_logits_process import (
         EncoderNoRepeatNGramLogitsProcessor,
+        ExponentialDecayLengthPenalty,
         ForcedBOSTokenLogitsProcessor,
         ForcedEOSTokenLogitsProcessor,
         HammingDiversityLogitsProcessor,
@@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase):
                 atol=1e-6,
             )
         )
+
+    def test_exponential_decay_length_penalty(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+        penalty_start = 5
+        penalty_factor = 1.1
+
+        input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size)
+        input_ids_seq_length = input_ids.shape[-1]
+
+        length_decay_processor = ExponentialDecayLengthPenalty(
+            exponential_decay_length_penalty=(penalty_start, penalty_factor),
+            eos_token_id=eos_token_id,
+            input_ids_seq_length=input_ids_seq_length,
+        )
+
+        # check that penalty is not applied before start
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_start = length_decay_processor(input_ids, scores.clone())
+        self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist())
+
+        # check that penalty is applied after start, on every row of the batch
+        input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_after_start = length_decay_processor(input_ids, scores.clone())
+        self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all())
@@ -82,6 +82,7 @@ config_common_kwargs = {
     "eos_token_id": 8,
     "sep_token_id": 9,
     "decoder_start_token_id": 10,
+    "exponential_decay_length_penalty": (5, 1.01),
     "task_specific_params": {"translation": "some_params"},
     "problem_type": "regression",
 }
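
Because `PretrainedConfig` now accepts the kwarg (first hunk of this commit), the penalty can also be set once on a model's config and picked up by `generate()` as a default, via the fallback logic in the generation_utils hunk above. A sketch, with an arbitrary model class:

```python
from transformers import GPT2Config, GPT2LMHeadModel

# Kwargs unknown to GPT2Config fall through to PretrainedConfig.__init__,
# which now pops `exponential_decay_length_penalty` (default: None).
config = GPT2Config(exponential_decay_length_penalty=(5, 1.01))
model = GPT2LMHeadModel(config)  # randomly initialized, for illustration only

# generate() falls back to config.exponential_decay_length_penalty when the
# argument is not passed explicitly.
assert model.config.exponential_decay_length_penalty == (5, 1.01)
```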