Unverified commit 9442b3ce authored by Kevin Bondzio, committed by GitHub

Add soft length regulation for sequence generation (#15245)



* add possibility to softly regulate length when using sampling method in model.generate() function

* fix test config, fix formatting

* fix rag integration, fix doc styling

* fix wrong docstring

* change param to tuple, add test

* fix old param in rag_model, remove unused import

* change test according to new param

* fix formatting

* fix test case

* fix doc style

* move start_length calculation to LogitsProcessor

* remove unused import

* fix small errors

* fix test

* Update src/transformers/generation_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/generation_utils.py

* Update src/transformers/generation_utils.py

* fix docstring, add type in model rag

* fix docstrings

* introduce seq_length variable for cleaner code

* fix black formatting

* add input_ids_seq_length to modeling_rag

* add input_ids_seq_length to test

* retrigger checks

* retrigger checks
Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.local>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Kevin Bondzio <kev@AIM-LAP-02.fritz.box>
parent 322c8533
@@ -295,6 +295,7 @@ class PretrainedConfig(PushToHubMixin):
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
+        self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
......
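For readers skimming the diff: because the new attribute lives on `PretrainedConfig`, the soft length penalty can also be set once as a model-wide default instead of being passed on every call. A minimal sketch, assuming a GPT-2 checkpoint purely for illustration (the values are not from this PR):

```python
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Illustrative default: 10 tokens after the prompt ends, the EOS score is
# multiplied by 1.05 ** (current_length - start) at every decoding step.
# generate() falls back to this config value when the keyword argument is
# not passed (see the fallback added in generation_utils.py further below).
model.config.exponential_decay_length_penalty = (10, 1.05)
```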
@@ -15,7 +15,7 @@
import inspect
import math
-from typing import Callable, Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional, Tuple

import numpy as np
import torch

@@ -647,3 +647,32 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
        scores[scores == float("inf")] = torch.finfo(scores.dtype).max

        return scores
+
+
+class ExponentialDecayLengthPenalty(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] that exponentially increases the score of the `eos_token_id` after `regulation_start` has
+    been reached.
+
+    Args:
+        exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
+            This tuple shall consist of `(start_index, decay_factor)`, where `start_index` indicates where the
+            penalty starts and `decay_factor` represents the factor of exponential decay.
+        eos_token_id (`int`):
+            The id of the *end-of-sequence* token.
+        input_ids_seq_length (`int`):
+            The length of the input sequence.
+    """
+
+    def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
+        self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
+        self.regulation_factor = exponential_decay_length_penalty[1]
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
+        cur_len = input_ids.shape[-1]
+        if cur_len > self.regulation_start:
+            scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
+                self.regulation_factor, cur_len - self.regulation_start
+            )
+        return scores
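To make the effect of the class added above concrete, here is a minimal standalone sketch; the token id, lengths, and decay factor are illustrative only:

```python
import torch

from transformers import LogitsProcessorList
from transformers.generation_logits_process import ExponentialDecayLengthPenalty

eos_token_id = 2
# Penalty starts 3 tokens after a 4-token prompt, i.e. regulation_start = 7.
processors = LogitsProcessorList(
    [ExponentialDecayLengthPenalty((3, 1.5), eos_token_id, input_ids_seq_length=4)]
)

input_ids = torch.randint(0, 10, (1, 10))  # current length 10 > regulation_start
scores = torch.full((1, 10), -1.0)
scores[0, eos_token_id] = 2.0

new_scores = processors(input_ids, scores)
# The EOS score was multiplied by 1.5 ** (10 - 7) = 3.375, so 2.0 -> 6.75;
# all other scores are left untouched.
print(new_scores[0, eos_token_id])
```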
@@ -28,6 +28,7 @@ from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
+    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,

@@ -667,6 +668,7 @@ class GenerationMixin:
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
+        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,

@@ -679,6 +681,7 @@ class GenerationMixin:
        num_beam_groups: int,
        diversity_penalty: float,
        remove_invalid_values: bool,
+        exponential_decay_length_penalty: Tuple,
        logits_processor: Optional[LogitsProcessorList],
    ) -> LogitsProcessorList:
        """

@@ -710,6 +713,11 @@ class GenerationMixin:
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
+        exponential_decay_length_penalty = (
+            exponential_decay_length_penalty
+            if exponential_decay_length_penalty is not None
+            else self.config.exponential_decay_length_penalty
+        )

        # instantiate processors list
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files

@@ -743,6 +751,10 @@ class GenerationMixin:
            processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        if remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
+        if exponential_decay_length_penalty is not None:
+            processors.append(
+                ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
+            )
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        return processors

@@ -858,6 +870,7 @@ class GenerationMixin:
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
+        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
        r"""

@@ -1003,6 +1016,11 @@ class GenerationMixin:
                crash. Note that using `remove_invalid_values` can slow down generation.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
+                This tuple adds an exponentially increasing length penalty after a certain number of tokens have
+                been generated. The tuple shall consist of `(start_index, decay_factor)`, where `start_index`
+                indicates where the penalty starts and `decay_factor` represents the factor of exponential decay.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs

@@ -1152,10 +1170,12 @@ class GenerationMixin:
            # if decoder-only then inputs_tensor has to be `input_ids`
            input_ids = inputs_tensor

+        input_ids_seq_length = input_ids.shape[-1]
+
        # 5. Prepare `max_length` depending on other stopping criteria
        # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
        if max_length is None and max_new_tokens is not None:
-            max_length = max_new_tokens + input_ids.shape[-1]
+            max_length = max_new_tokens + input_ids_seq_length
        elif max_length is not None and max_new_tokens is not None:
            # Both are set, this is odd, raise a warning
            warnings.warn(

@@ -1167,10 +1187,10 @@ class GenerationMixin:
        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length

-        if input_ids.shape[-1] >= max_length:
+        if input_ids_seq_length >= max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
-                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}. "
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. "
                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
            )

@@ -1202,6 +1222,7 @@ class GenerationMixin:
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,

@@ -1214,6 +1235,7 @@ class GenerationMixin:
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
        )
......
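End to end, the new `generate()` argument is used like any other generation flag. A minimal usage sketch with sampling, where the checkpoint, prompt, and numbers are illustrative only:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Starting 15 tokens after the prompt, the EOS score is multiplied by
# 1.05 ** (current_length - start) at each step, softly regulating length.
outputs = model.generate(
    **inputs,
    do_sample=True,
    max_length=60,
    exponential_decay_length_penalty=(15, 1.05),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```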
@@ -15,7 +15,7 @@
"""RAG model implementation."""

from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn

@@ -1405,6 +1405,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
+        exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
        **model_kwargs
    ):
        """

@@ -1534,6 +1535,11 @@ class RagTokenForGeneration(RagPreTrainedModel):
        remove_invalid_values = (
            remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
        )
+        exponential_decay_length_penalty = (
+            exponential_decay_length_penalty
+            if exponential_decay_length_penalty is not None
+            else self.config.exponential_decay_length_penalty
+        )

        # retrieve docs
        if self.retriever is not None and context_input_ids is None:

@@ -1577,6 +1583,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
+        input_ids_seq_length = input_ids.shape[-1]

        last_hidden_state = encoder_outputs["last_hidden_state"]

        def extend_enc_output(tensor, num_beams=None):

@@ -1603,6 +1610,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=context_input_ids,
            bad_words_ids=bad_words_ids,
            min_length=min_length,

@@ -1615,6 +1623,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            remove_invalid_values=remove_invalid_values,
+            exponential_decay_length_penalty=exponential_decay_length_penalty,
            logits_processor=logits_processor,
        )
......
@@ -28,6 +28,7 @@ if is_torch_available():
    from transformers.generation_logits_process import (
        EncoderNoRepeatNGramLogitsProcessor,
+        ExponentialDecayLengthPenalty,
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
        HammingDiversityLogitsProcessor,

@@ -504,3 +505,35 @@ class LogitsProcessorTest(unittest.TestCase):
                atol=1e-6,
            )
        )
+
+    def test_exponential_decay_length_penalty(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+
+        penalty_start = 5
+        penalty_factor = 1.1
+
+        input_ids = ids_tensor((batch_size, 2), vocab_size=vocab_size)
+        input_ids_seq_length = input_ids.shape[-1]
+
+        length_decay_processor = ExponentialDecayLengthPenalty(
+            exponential_decay_length_penalty=(penalty_start, penalty_factor),
+            eos_token_id=eos_token_id,
+            input_ids_seq_length=input_ids_seq_length,
+        )
+
+        # check that the penalty is not applied before start
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_start = length_decay_processor(input_ids, scores)
+        self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist())
+
+        # check that the penalty is applied after start
+        input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_after_start = length_decay_processor(input_ids, scores)
+        self.assertTrue(
+            torch.gt(
+                scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
+            ).all()
+        )
@@ -82,6 +82,7 @@ config_common_kwargs = {
    "eos_token_id": 8,
    "sep_token_id": 9,
    "decoder_start_token_id": 10,
+    "exponential_decay_length_penalty": (5, 1.01),
    "task_specific_params": {"translation": "some_params"},
    "problem_type": "regression",
}
......