Unverified Commit c130e67d authored by Suraj Patil, committed by GitHub

remove adjust_logits_during_generation method (#10087)

* add forced logits processors

* delete adjust_logits method

* add forced_eos_token_id argument in config

* add tests for forced logits processors

* update gen utils tests

* add forced option to tf generate

* remove adjust_logits method from tf models

* update adjust_logits for marian

* delete _force_token_id_to_be_generated method

* style

* import warnings

* pass max_length to _get_logits_processor

* set forced_eos_token_id to None

* set forced attributes in conf utils

* typo

* fix rag generate

* add forced_eos_token_id in rag config

* remove force_bos_token_to_be_generated from BartConfig

* remove _force_token_ids_generation from FSMT

* nit

* fix negative constant

* apply suggestions from code review
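Taken together, these commits replace the per-model `adjust_logits_during_generation` hook with two reusable processors, `ForcedBOSTokenLogitsProcessor` and `ForcedEOSTokenLogitsProcessor`, driven by the new `forced_bos_token_id` / `forced_eos_token_id` config fields. A minimal sketch of the forced-EOS idea (illustrative only, not the library's exact implementation):

import torch

class ForcedEOSSketch:
    """Sketch: force eos_token_id to be the token generated at the final step."""

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            # mask every token except eos so the last generated token is eos
            vocab_size = scores.shape[1]
            scores[:, [i for i in range(vocab_size) if i != self.eos_token_id]] = -float("inf")
            scores[:, self.eos_token_id] = 0
        return scores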
......@@ -1468,10 +1468,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits
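The TF hook removed above applied the same masking inline with `tf.where`; after this change the constraint lives in the shared generation utilities rather than in each model class. A hedged, model-agnostic sketch of that masking (the helper name and the LARGE_NEGATIVE stand-in are assumptions, not library code):

import tensorflow as tf

LARGE_NEGATIVE = -1e8  # assumed stand-in for the library's constant

def force_eos_at_last_step(logits, cur_len, max_length, vocab_size, eos_token_id):
    # keep only the eos logit at the final generation step, mirroring the removed hook
    if cur_len != max_length - 1:
        return logits
    vocab_range = tf.constant(range(vocab_size))
    return tf.where(vocab_range != eos_token_id, LARGE_NEGATIVE, logits)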
......@@ -84,6 +84,9 @@ class PegasusConfig(PretrainedConfig):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
forced_eos_token_id (:obj:`int`, `optional`, defaults to 1):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Example::
......@@ -127,6 +130,7 @@ class PegasusConfig(PretrainedConfig):
gradient_checkpointing=False,
pad_token_id=0,
eos_token_id=1,
forced_eos_token_id=1,
**kwargs
):
super().__init__(
......@@ -134,6 +138,7 @@ class PegasusConfig(PretrainedConfig):
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
......
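The new field is a plain config argument forwarded to `PretrainedConfig`, so it can be set or disabled at construction time. A hedged example of the Pegasus defaults introduced above:

from transformers import PegasusConfig

config = PegasusConfig()                                     # forced_eos_token_id defaults to 1, matching eos_token_id
config_no_forcing = PegasusConfig(forced_eos_token_id=None)  # disables the forced-EOS behaviour

print(config.forced_eos_token_id, config_no_forcing.forced_eos_token_id)  # 1 None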
......@@ -1327,16 +1327,6 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
......
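What the removed pair of methods did per model is now handled centrally: `generate()` builds a `ForcedEOSTokenLogitsProcessor` from `config.forced_eos_token_id` and `max_length`. A hedged sketch of the equivalent stand-alone usage (constructor signature taken from the tests later in this diff):

import torch
from transformers.generation_logits_process import (
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)

processors = LogitsProcessorList(
    [ForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=1)]
)

input_ids = torch.tensor([[0, 4, 7, 9]])  # cur_len == max_length - 1, so forcing kicks in
scores = torch.zeros(1, 10)               # dummy next-token logits over a 10-token vocabulary
scores = processors(input_ids, scores)    # every column except eos_token_id is now -inf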
......@@ -1483,10 +1483,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits
......@@ -74,6 +74,9 @@ RAG_CONFIG_DOC = r"""
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
"""
......@@ -110,6 +113,7 @@ class RagConfig(PretrainedConfig):
do_marginalize=False,
output_retrieved=False,
use_cache=True,
forced_eos_token_id=None,
**kwargs
):
super().__init__(
......@@ -117,6 +121,7 @@ class RagConfig(PretrainedConfig):
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
is_encoder_decoder=is_encoder_decoder,
prefix=prefix,
vocab_size=vocab_size,
......@@ -161,6 +166,9 @@ class RagConfig(PretrainedConfig):
self.use_cache = use_cache
if self.forced_eos_token_id is None:
self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
@classmethod
def from_question_encoder_generator_configs(
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
......
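Because RAG wraps a generator model, its config does not hard-code a value; when `forced_eos_token_id` is not passed it falls back to the generator's setting via the `getattr` above. A hedged illustration (the sub-configs are chosen for the example, not taken from this diff):

from transformers import BartConfig, DPRConfig, RagConfig

question_encoder_config = DPRConfig()
generator_config = BartConfig(forced_eos_token_id=2)

rag_config = RagConfig.from_question_encoder_generator_configs(
    question_encoder_config, generator_config
)
print(rag_config.forced_eos_token_id)  # 2, inherited from the generator config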
......@@ -1089,9 +1089,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
def set_retriever(self, retriever: RagRetriever):
self.rag.retriever = retriever
def adjust_logits_during_generation(self, logits, cur_len, max_length):
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
......@@ -1313,6 +1310,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
decoder_start_token_id=None,
n_docs=None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
**model_kwargs
):
"""
......@@ -1403,6 +1402,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
forced_bos_token_id (:obj:`int`, `optional`):
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
......@@ -1498,7 +1503,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder_input_ids=context_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
......
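Callers of `RagTokenForGeneration.generate()` can now pass the forced tokens straight through to the shared logits-processor machinery. A hedged usage sketch (checkpoint name and dummy-index retriever setup are illustrative, not part of this diff):

from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# encode the question with the question-encoder tokenizer held by RagTokenizer
inputs = tokenizer.question_encoder("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(
    input_ids=inputs["input_ids"],
    max_length=20,
    forced_eos_token_id=model.config.generator.eos_token_id,  # force EOS once max_length is hit
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))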
......@@ -28,6 +28,8 @@ if is_torch_available():
from transformers.generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
......@@ -393,3 +395,44 @@ class LogitsProcessorTest(unittest.TestCase):
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
)
)
def test_forced_bos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
# check that all scores are -inf except the bos_token_id score
input_ids = ids_tensor((batch_size, 1), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertTrue(torch.isneginf(scores[:, bos_token_id + 1 :]).all())
self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id should be zero
# check that bos_token_id is not forced if current length is greater than 1
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
def test_forced_eos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all())
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length is not reached
input_ids = ids_tensor((batch_size, 3), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
......@@ -26,6 +26,8 @@ if is_torch_available():
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
from transformers.generation_beam_search import BeamSearchScorer
from transformers.generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
......@@ -70,7 +72,14 @@ class GenerationTesterMixin:
return config, input_ids, attention_mask, max_length
@staticmethod
def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None):
def _get_logits_processor_and_kwargs(
input_length,
eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
diversity_penalty=None,
):
process_kwargs = {
"min_length": input_length + 1,
"bad_words_ids": [[1, 0]],
......@@ -92,6 +101,18 @@ class GenerationTesterMixin:
if eos_token_id is not None
else []
)
+ (
[
ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
]
if forced_bos_token_id is not None
else []
)
+ (
[ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
if forced_eos_token_id is not None
else []
)
+ [
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
......@@ -182,13 +203,17 @@ class GenerationTesterMixin:
output_hidden_states=False,
return_dict_in_generate=False,
):
if model.config.is_encoder_decoder:
max_length = 4
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], model.config.eos_token_id
input_ids.shape[-1],
eos_token_id=model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
kwargs = {}
if model.config.is_encoder_decoder:
max_length = 4
output_generate = model.generate(
input_ids,
......@@ -544,14 +569,19 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], model.config.eos_token_id
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
if model.config.is_encoder_decoder:
max_length = 4
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
model=model,
......@@ -586,14 +616,18 @@ class GenerationTesterMixin:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 4
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], model.config.eos_token_id
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
if model.config.is_encoder_decoder:
max_length = 4
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
......@@ -630,14 +664,19 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 4
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
# check `generate()` and `beam_search()` are equal
......@@ -684,13 +723,19 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
if model.config.is_encoder_decoder:
max_length = 4
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
output_generate, output_beam_search = self._beam_search_generate(
model=model,
......@@ -732,19 +777,24 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
if not hasattr(config, "use_cache"):
# only relevant if model has "use_cache"
return
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 4
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
config.use_cache = True
......@@ -780,6 +830,7 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
......@@ -819,6 +870,7 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
......@@ -892,16 +944,22 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 4
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
diversity_penalty=2.0,
)
model = model_class(config).to(torch_device).eval()
# check `generate()` and `group_beam_search()` are equal
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
output_generate, output_group_beam_search = self._group_beam_search_generate(
model=model,
......@@ -943,16 +1001,22 @@ class GenerationTesterMixin:
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 4
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
diversity_penalty=2.0,
)
num_return_sequences = 1
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
)
......
......@@ -46,6 +46,7 @@ class SimpleSummarizationPipelineTests(unittest.TestCase):
decoder_attention_heads=1,
max_length=4,
min_length=1,
forced_eos_token_id=None,
)
model = BartForConditionalGeneration(config)
# Bias output towards L
......