Unverified Commit d5334651 authored by Guillaume "Vermeille" Sanchez, committed by GitHub

add CFG for .generate() (#24654)

parent a6e6b1c6
...@@ -65,6 +65,7 @@ else:
"EncoderNoRepeatNGramLogitsProcessor",
"ExponentialDecayLengthPenalty",
"LogitNormalization",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
...@@ -188,6 +189,7 @@ if TYPE_CHECKING:
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
...
...@@ -15,7 +15,7 @@
import inspect
import math
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
...@@ -1334,3 +1334,119 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
return scores
class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""Logits processor for Classifier-Free Guidance (CFG). The processors
computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
the `unconditional_ids` branch.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
Args:
guidance_scale (`float`):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
the last token of the prompt.
unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, **optional**):
Attention mask for unconditional_ids.
model (`PreTrainedModel`):
The model computing the unconditional scores. Supposedly the same as the one computing the conditional
scores. Both models must use the same tokenizer.
smooth_factor (`float`, **optional**):
The interpolation weight for CFG Rescale. 1 means no rescaling, 0 reduces to the conditional scores without
CFG. Turn it lower if the output degenerates.
use_cache (`bool`, **optional**):
Whether to cache key/values during the negative prompt forward pass.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
The dragon flew over Paris, France, landing in Lyon, a city of a few million. Dragon-flying was a new form of
transport, and the dragon was the first in Europe.
>>> # with a negative prompt
>>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
>>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
people and injuring more than 350.
```
"""
def __init__(
self,
guidance_scale: float,
model,
unconditional_ids: Optional[torch.LongTensor] = None,
unconditional_attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = True,
):
self.guidance_scale = guidance_scale
self.model = model
self.unconditional_context = {
"input_ids": unconditional_ids,
"attention_mask": unconditional_attention_mask,
"use_cache": use_cache,
"past_key_values": None,
"first_pass": True,
}
def get_unconditional_logits(self, input_ids):
if self.unconditional_context["first_pass"]:
if self.unconditional_context["input_ids"] is None:
self.unconditional_context["input_ids"] = input_ids[:, -1:]
if self.unconditional_context["attention_mask"] is None:
self.unconditional_context["attention_mask"] = torch.ones_like(
self.unconditional_context["input_ids"], dtype=torch.long
)
input_ids = self.unconditional_context["input_ids"]
attention_mask = self.unconditional_context["attention_mask"]
self.unconditional_context["first_pass"] = False
else:
attention_mask = torch.cat(
[
self.unconditional_context["attention_mask"],
torch.ones_like(input_ids[:, -1:], dtype=torch.long),
],
dim=1,
)
if not self.unconditional_context["use_cache"]:
input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
else:
input_ids = input_ids[:, -1:]
self.unconditional_context["input_ids"] = input_ids
self.unconditional_context["attention_mask"] = attention_mask
out = self.model(
input_ids,
attention_mask=attention_mask,
use_cache=self.unconditional_context["use_cache"],
past_key_values=self.unconditional_context["past_key_values"],
)
self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
return out.logits
def __call__(self, input_ids, scores):
scores = torch.nn.functional.log_softmax(scores, dim=-1)
if self.guidance_scale == 1:
return scores
logits = self.get_unconditional_logits(input_ids)
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
return out
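For reference, the combination performed in `__call__` above is plain classifier-free guidance in log-probability space: both score tensors are log-softmaxed and then linearly extrapolated with `guidance_scale`. The standalone sketch below reproduces that arithmetic on hypothetical tensors (the values and vocabulary size are made up, not part of the diff) so the effect of the scale is easy to inspect.

```python
import torch
import torch.nn.functional as F

# Hypothetical next-token scores for a 4-token vocabulary (batch size 1).
cond_scores = torch.tensor([[2.0, 0.5, 1.0, -1.0]])    # prompt-conditional logits
uncond_scores = torch.tensor([[1.5, 1.5, 0.0, -0.5]])  # unconditional / negative-prompt logits
guidance_scale = 1.5

# Same formula as UnbatchedClassifierFreeGuidanceLogitsProcessor.__call__:
#   log p_cfg = log p_uncond + guidance_scale * (log p_cond - log p_uncond)
log_cond = F.log_softmax(cond_scores, dim=-1)
log_uncond = F.log_softmax(uncond_scores, dim=-1)
cfg_scores = guidance_scale * (log_cond - log_uncond) + log_uncond

print(cfg_scores)
# Tokens favoured by the conditional distribution but not by the unconditional one
# are boosted; guidance_scale == 1 would simply return log_cond unchanged.
```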
...@@ -38,7 +38,6 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig
from .logits_process import (
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
...@@ -64,6 +63,7 @@ from .logits_process import (
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
...@@ -893,6 +893,9 @@ class GenerationMixin:
encoder_input_ids: torch.LongTensor,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
logits_processor: Optional[LogitsProcessorList],
model_kwargs: Optional[Dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
...@@ -901,6 +904,16 @@ class GenerationMixin:
# instantiate processors list
processors = LogitsProcessorList()
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
processors.append(
UnbatchedClassifierFreeGuidanceLogitsProcessor(
generation_config.guidance_scale,
self,
unconditional_ids=negative_prompt_ids,
unconditional_attention_mask=negative_prompt_attention_mask,
use_cache=model_kwargs["use_cache"],
)
)
if generation_config.sequence_bias is not None:
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
...@@ -998,8 +1011,6 @@ class GenerationMixin:
)
if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
...@@ -1251,6 +1262,8 @@ class GenerationMixin:
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
...@@ -1308,6 +1321,11 @@ class GenerationMixin:
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
The negative prompt needed for some processors such as CFG. The batch size must match the input batch
size. This is an experimental feature, subject to breaking API changes in future versions.
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Attention mask for `negative_prompt_ids`.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
...@@ -1511,6 +1529,9 @@ class GenerationMixin:
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# 9. prepare stopping criteria
...
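Taken together, the `utils.py` changes mean CFG is driven entirely through `generate()` keyword arguments: `guidance_scale` switches the processor on, while `negative_prompt_ids` and `negative_prompt_attention_mask` feed its unconditional branch. A minimal end-to-end sketch (reusing the `gpt2` checkpoint and prompts from the docstring example; the scale and token budget are illustrative) might look like:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
negative = tokenizer(["A very happy event happened,"], return_tensors="pt")

out = model.generate(
    **prompt,
    max_new_tokens=32,
    guidance_scale=2.0,                                         # > 1 enables the CFG processor
    negative_prompt_ids=negative["input_ids"],                  # unconditional / negative branch
    negative_prompt_attention_mask=negative["attention_mask"],
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```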
...@@ -51,6 +51,7 @@ if is_torch_available():
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
...@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
def test_classifier_free_guidance(self):
class Namespace(dict):
pass
logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])
def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
out = Namespace()
out.logits = logits_uncond
out.past_key_values = None
return out
def lsm(x):
return torch.nn.functional.log_softmax(x, dim=-1)
# explicit unconditional prompt + attention mask
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
)
out = cfg(input_ids, logits_cond)[0, -1]
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
# explicit unconditional prompt
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
out = cfg(input_ids, logits_cond)[0, -1]
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
# all implicit
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
out = cfg(input_ids, logits_cond)[0, -1]
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
...@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
],
)
@slow
def test_cfg_mixin(self):
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
input["input_ids"] = input["input_ids"].to(torch_device)
input["attention_mask"] = input["attention_mask"].to(torch_device)
outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
],
)
neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
neg["input_ids"] = neg["input_ids"].to(torch_device)
neg["attention_mask"] = neg["attention_mask"].to(torch_device)
outputs = model.generate(
**input,
max_new_tokens=32,
guidance_scale=1.5,
negative_prompt_ids=neg["input_ids"],
negative_prompt_attention_mask=neg["attention_mask"],
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
],
)
@slow
def test_constrained_beam_search_example_translation_mixin(self):
# PT-only test: TF doesn't have constrained beam search
...