chenpangpang / transformers · Commits

Unverified commit d5334651, authored Aug 06, 2023 by Guillaume "Vermeille" Sanchez, committed by GitHub on Aug 06, 2023
Parent: a6e6b1c6

add CFG for .generate() (#24654)

Showing 5 changed files with 235 additions and 4 deletions (+235 -4):
    src/transformers/generation/__init__.py        +2    -0
    src/transformers/generation/logits_process.py  +117  -1
    src/transformers/generation/utils.py           +24   -3
    tests/generation/test_logits_process.py        +52   -0
    tests/generation/test_utils.py                 +40   -0
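Taken together, these changes wire CFG into `model.generate()`: a `guidance_scale` greater than 1 enables it, and `negative_prompt_ids` / `negative_prompt_attention_mask` optionally replace the implicit unconditional branch. A minimal usage sketch assembled from the docstring examples further down (checkpoint and decoding settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")

# CFG kicks in whenever guidance_scale > 1.
out = model.generate(inputs["input_ids"], guidance_scale=1.5, max_new_tokens=32)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])

# Optionally steer away from a negative prompt instead of the implicit
# unconditional branch (by default, the last prompt token).
neg = tokenizer(["A very happy event happened,"], return_tensors="pt")
out = model.generate(
    inputs["input_ids"],
    guidance_scale=2,
    negative_prompt_ids=neg["input_ids"],
    max_new_tokens=32,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```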
src/transformers/generation/__init__.py

@@ -65,6 +65,7 @@ else:
         "EncoderNoRepeatNGramLogitsProcessor",
         "ExponentialDecayLengthPenalty",
         "LogitNormalization",
+        "UnbatchedClassifierFreeGuidanceLogitsProcessor",
     ]
     _import_structure["stopping_criteria"] = [
         "MaxNewTokensCriteria",

@@ -188,6 +189,7 @@ if TYPE_CHECKING:
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
+        UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
     from .stopping_criteria import (
         MaxLengthCriteria,
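With these exports in place, the processor is part of the public generation namespace; a one-line sanity check, assuming an install that includes this commit:

```python
from transformers.generation import UnbatchedClassifierFreeGuidanceLogitsProcessor
```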
src/transformers/generation/logits_process.py

@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Dict, Iterable, List, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -1334,3 +1334,119 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
             scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")

         return scores

The rest of the hunk is new code:

class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
    r"""Logits processor for Classifier-Free Guidance (CFG). The processor computes a weighted average across scores
    from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`. The
    unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.

    See [the paper](https://arxiv.org/abs/2306.17806) for more information.

    Args:
        guidance_scale (`float`):
            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
            A higher guidance scale encourages the model to generate samples that are more closely linked to the
            input prompt, usually at the expense of poorer quality.
        model (`PreTrainedModel`):
            The model computing the unconditional scores. Supposedly the same as the one computing the conditional
            scores. Both models must use the same tokenizer.
        unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default
            to the last token of the prompt.
        unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Attention mask for `unconditional_ids`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to cache key/values during the negative prompt forward pass.

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    'The dragon flew over Paris, France, landing in Lyon, a city of a few million. Dragon-flying was a new form of '
    'transport, and the dragon was the first in Europe.'

    >>> # with a negative prompt
    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    'The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127 '
    'people and injuring more than 350.'
    ```
    """

    def __init__(
        self,
        guidance_scale: float,
        model,
        unconditional_ids: Optional[torch.LongTensor] = None,
        unconditional_attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = True,
    ):
        self.guidance_scale = guidance_scale
        self.model = model
        self.unconditional_context = {
            "input_ids": unconditional_ids,
            "attention_mask": unconditional_attention_mask,
            "use_cache": use_cache,
            "past_key_values": None,
            "first_pass": True,
        }

    def get_unconditional_logits(self, input_ids):
        if self.unconditional_context["first_pass"]:
            # First step: fall back to the last prompt token if no negative prompt was given.
            if self.unconditional_context["input_ids"] is None:
                self.unconditional_context["input_ids"] = input_ids[:, -1:]
            if self.unconditional_context["attention_mask"] is None:
                self.unconditional_context["attention_mask"] = torch.ones_like(
                    self.unconditional_context["input_ids"], dtype=torch.long
                )
            input_ids = self.unconditional_context["input_ids"]
            attention_mask = self.unconditional_context["attention_mask"]
            self.unconditional_context["first_pass"] = False
        else:
            # Later steps: extend the mask by one, then feed either the full running
            # sequence (no cache) or only the newest token (with cache).
            attention_mask = torch.cat(
                [
                    self.unconditional_context["attention_mask"],
                    torch.ones_like(input_ids[:, -1:], dtype=torch.long),
                ],
                dim=1,
            )
            if not self.unconditional_context["use_cache"]:
                input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
            else:
                input_ids = input_ids[:, -1:]
            self.unconditional_context["input_ids"] = input_ids
            self.unconditional_context["attention_mask"] = attention_mask

        out = self.model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=self.unconditional_context["use_cache"],
            past_key_values=self.unconditional_context["past_key_values"],
        )
        self.unconditional_context["past_key_values"] = out.get("past_key_values", None)

        return out.logits

    def __call__(self, input_ids, scores):
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        if self.guidance_scale == 1:
            return scores

        logits = self.get_unconditional_logits(input_ids)

        unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
        return out
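The `__call__` method above is the entire CFG rule: both distributions are renormalized with `log_softmax`, then the conditional scores are extrapolated away from the unconditional ones. A self-contained sketch of the same arithmetic on toy tensors (the function name is illustrative, not part of the commit):

```python
import torch
import torch.nn.functional as F


def cfg_combine(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, scale: float) -> torch.Tensor:
    """uncond + scale * (cond - uncond), with both sides renormalized in log space."""
    cond = F.log_softmax(cond_logits, dim=-1)
    uncond = F.log_softmax(uncond_logits, dim=-1)
    return scale * (cond - uncond) + uncond


cond = torch.tensor([[1.0, 1.0, 1.0]])    # next-token logits with the prompt
uncond = torch.tensor([[1.0, 0.0, 1.5]])  # next-token logits without it
print(cfg_combine(cond, uncond, scale=1.5))
# scale=1 returns the conditional distribution unchanged, mirroring the
# early exit in __call__ above.
```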
src/transformers/generation/utils.py

@@ -38,7 +38,6 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .configuration_utils import GenerationConfig
 from .logits_process import (
-    ClassifierFreeGuidanceLogitsProcessor,
     EncoderNoRepeatNGramLogitsProcessor,
     EncoderRepetitionPenaltyLogitsProcessor,
     EpsilonLogitsWarper,

@@ -64,6 +63,7 @@ from .logits_process import (
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
+    UnbatchedClassifierFreeGuidanceLogitsProcessor,
 )
 from .stopping_criteria import (
     MaxLengthCriteria,
@@ -893,6 +893,9 @@ class GenerationMixin:
         encoder_input_ids: torch.LongTensor,
         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
         logits_processor: Optional[LogitsProcessorList],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -901,6 +904,16 @@ class GenerationMixin:
         # instantiate processors list
         processors = LogitsProcessorList()

+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            processors.append(
+                UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                    generation_config.guidance_scale,
+                    self,
+                    unconditional_ids=negative_prompt_ids,
+                    unconditional_attention_mask=negative_prompt_attention_mask,
+                    use_cache=model_kwargs["use_cache"],
+                )
+            )
         if generation_config.sequence_bias is not None:
             processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
@@ -998,8 +1011,6 @@ class GenerationMixin:
             )
         if generation_config.forced_decoder_ids is not None:
             processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
-        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
-            processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
         processors = self._merge_criteria_processor_list(processors, logits_processor)
         # `LogitNormalization` should always be the last logit processor, when present
         if generation_config.renormalize_logits is True:
@@ -1251,6 +1262,8 @@ class GenerationMixin:
         synced_gpus: Optional[bool] = None,
         assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -1308,6 +1321,11 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                The negative prompt needed for some processors such as CFG. The batch size must match the input
+                batch size. This is an experimental feature, subject to breaking API changes in future versions.
+            negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Attention mask for `negative_prompt_ids`.
             kwargs (`Dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -1511,6 +1529,9 @@ class GenerationMixin:
             encoder_input_ids=inputs_tensor,
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
         )

         # 9. prepare stopping criteria
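`generate()` now builds this processor internally from `guidance_scale`, but the same plumbing can be exercised by hand: construct the processor yourself and merge it in through the `logits_processor` argument, which `_merge_criteria_processor_list` combines with the defaults. A hedged sketch (assumes the `gpt2` checkpoint; note the processor is stateful, so build a fresh instance per `generate()` call):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import LogitsProcessorList, UnbatchedClassifierFreeGuidanceLogitsProcessor

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer(["The dragon flew over Paris,"], return_tensors="pt")
neg = tokenizer(["France,"], return_tensors="pt")

# Equivalent to what _get_logits_processor builds from guidance_scale and
# negative_prompt_ids; the model itself scores the unconditional branch.
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, model, unconditional_ids=neg["input_ids"])
out = model.generate(**inputs, max_new_tokens=32, logits_processor=LogitsProcessorList([cfg]))
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```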
tests/generation/test_logits_process.py

@@ -51,6 +51,7 @@ if is_torch_available():
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
+        UnbatchedClassifierFreeGuidanceLogitsProcessor,
     )
@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
         self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
         self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))

The rest of the hunk is the new test:

    def test_classifier_free_guidance(self):
        class Namespace(dict):
            pass

        logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
        logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])

        def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
            out = Namespace()
            out.logits = logits_uncond
            out.past_key_values = None
            return out

        def lsm(x):
            return torch.nn.functional.log_softmax(x, dim=-1)

        # explicit unconditional prompt + attention mask
        input_ids = torch.LongTensor([[0]])
        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
            1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
        )
        out = cfg(input_ids, logits_cond)[0, -1]

        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

        self.assertAlmostEqual(out[0].item(), res[0].item())
        self.assertAlmostEqual(out[1].item(), res[1].item())
        self.assertAlmostEqual(out[2].item(), res[2].item())

        # explicit unconditional prompt
        input_ids = torch.LongTensor([[0]])
        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
        out = cfg(input_ids, logits_cond)[0, -1]

        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

        self.assertAlmostEqual(out[0].item(), res[0].item())
        self.assertAlmostEqual(out[1].item(), res[1].item())
        self.assertAlmostEqual(out[2].item(), res[2].item())

        # all implicit
        input_ids = torch.LongTensor([[0]])
        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
        out = cfg(input_ids, logits_cond)[0, -1]

        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

        self.assertAlmostEqual(out[0].item(), res[0].item())
        self.assertAlmostEqual(out[1].item(), res[1].item())
        self.assertAlmostEqual(out[2].item(), res[2].item())
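For intuition, the expected values in these assertions can be reproduced by hand; a quick check of the first case, using the toy numbers from the test above:

```python
import torch
import torch.nn.functional as F

cond = F.log_softmax(torch.tensor([1.0, 1.0, 1.0]), dim=-1)    # [-1.0986, -1.0986, -1.0986]
uncond = F.log_softmax(torch.tensor([1.0, 0.0, 1.5]), dim=-1)  # [-1.1041, -2.1041, -0.6041]
print(uncond + 1.5 * (cond - uncond))  # tensor([-1.0959, -0.5959, -1.3459])
```

Token 1, which the unconditional branch disfavors while the conditional distribution is flat, gains the most log-probability: CFG amplifies exactly what the prompt changed.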
tests/generation/test_utils.py
@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
             ],
         )

The rest of the hunk is the new test:

    @slow
    def test_cfg_mixin(self):
        model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
        input["input_ids"] = input["input_ids"].to(torch_device)
        input["attention_mask"] = input["attention_mask"].to(torch_device)

        outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
        generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        self.assertListEqual(
            generated_text,
            [
                "The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
                'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
            ],
        )

        neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
        neg["input_ids"] = neg["input_ids"].to(torch_device)
        neg["attention_mask"] = neg["attention_mask"].to(torch_device)

        outputs = model.generate(
            **input,
            max_new_tokens=32,
            guidance_scale=1.5,
            negative_prompt_ids=neg["input_ids"],
            negative_prompt_attention_mask=neg["attention_mask"],
        )
        generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        self.assertListEqual(
            generated_text,
            [
                'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
                'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
            ],
        )

    @slow
    def test_constrained_beam_search_example_translation_mixin(self):
        # PT-only test: TF doesn't have constrained beam search