Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bff4313b
Unverified
Commit
bff4313b
authored
Aug 04, 2023
by
Joao Gante
Committed by
GitHub
Aug 04, 2023
Browse files
Generate: get generation mode as an enum (#25292)
parent
fab1a0aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
65 deletions
+68
-65
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+68
-65
No files found.
src/transformers/generation/utils.py
View file @
bff4313b
...
...
@@ -33,7 +33,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
..utils
import
ModelOutput
,
logging
from
..utils
import
ExplicitEnum
,
ModelOutput
,
logging
from
.beam_constraints
import
DisjunctiveConstraint
,
PhrasalConstraint
from
.beam_search
import
BeamScorer
,
BeamSearchScorer
,
ConstrainedBeamSearchScorer
from
.configuration_utils
import
GenerationConfig
...
...
@@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti
GenerateOutput
=
Union
[
GreedySearchOutput
,
SampleOutput
,
BeamSearchOutput
,
BeamSampleOutput
,
ContrastiveSearchOutput
]
class
GenerationMode
(
ExplicitEnum
):
"""
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
"""
# Non-beam methods
CONTRASTIVE_SEARCH
=
"contrastive_search"
GREEDY_SEARCH
=
"greedy_search"
SAMPLE
=
"sample"
ASSISTED_GENERATION
=
"assisted_generation"
# Beam methods
BEAM_SEARCH
=
"beam_search"
BEAM_SAMPLE
=
"beam_sample"
CONSTRAINED_BEAM_SEARCH
=
"constrained_beam_search"
GROUP_BEAM_SEARCH
=
"group_beam_search"
class
GenerationMixin
:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
...
...
@@ -829,6 +846,46 @@ class GenerationMixin:
warpers
.
append
(
LogitNormalization
())
return
warpers
def
_get_generation_mode
(
self
,
generation_config
:
GenerationConfig
,
assistant_model
:
Optional
[
"PreTrainedModel"
]
)
->
GenerationMode
:
"""
Returns the generation mode triggered by a [`GenerationConfig`] instance.
"""
if
generation_config
.
constraints
is
not
None
or
generation_config
.
force_words_ids
is
not
None
:
generation_mode
=
GenerationMode
.
CONSTRAINED_BEAM_SEARCH
elif
generation_config
.
num_beams
==
1
:
if
generation_config
.
do_sample
is
False
:
if
(
generation_config
.
top_k
is
not
None
and
generation_config
.
top_k
>
1
and
generation_config
.
penalty_alpha
is
not
None
and
generation_config
.
penalty_alpha
>
0
):
generation_mode
=
GenerationMode
.
CONTRASTIVE_SEARCH
else
:
generation_mode
=
GenerationMode
.
GREEDY_SEARCH
else
:
generation_mode
=
GenerationMode
.
SAMPLE
else
:
if
generation_config
.
num_beam_groups
>
1
:
generation_mode
=
GenerationMode
.
GROUP_BEAM_SEARCH
elif
generation_config
.
do_sample
is
True
:
generation_mode
=
GenerationMode
.
BEAM_SAMPLE
else
:
generation_mode
=
GenerationMode
.
BEAM_SEARCH
# Assisted generation may extend some generation modes
if
assistant_model
is
not
None
:
if
generation_mode
in
(
"greedy_search"
,
"sample"
):
generation_mode
=
GenerationMode
.
ASSISTED_GENERATION
else
:
raise
ValueError
(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
return
generation_mode
def
_get_logits_processor
(
self
,
generation_config
:
GenerationConfig
,
...
...
@@ -1422,65 +1479,11 @@ class GenerationMixin:
)
# 7. determine generation mode
is_constraint_gen_mode
=
(
generation_config
.
constraints
is
not
None
or
generation_config
.
force_words_ids
is
not
None
)
is_contrastive_search_gen_mode
=
(
(
generation_config
.
num_beams
==
1
)
and
generation_config
.
top_k
is
not
None
and
generation_config
.
top_k
>
1
and
generation_config
.
do_sample
is
False
and
generation_config
.
penalty_alpha
is
not
None
and
generation_config
.
penalty_alpha
>
0
)
is_greedy_gen_mode
=
(
(
generation_config
.
num_beams
==
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
False
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_sample_gen_mode
=
(
(
generation_config
.
num_beams
==
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
True
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_beam_gen_mode
=
(
(
generation_config
.
num_beams
>
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
False
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode
=
(
(
generation_config
.
num_beams
>
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
True
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_group_beam_gen_mode
=
(
(
generation_config
.
num_beams
>
1
)
and
(
generation_config
.
num_beam_groups
>
1
)
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_assisted_gen_mode
=
False
if
assistant_model
is
not
None
:
if
not
(
is_greedy_gen_mode
or
is_sample_gen_mode
):
raise
ValueError
(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
is_assisted_gen_mode
=
True
generation_mode
=
self
.
_get_generation_mode
(
generation_config
,
assistant_model
)
if
generation_config
.
num_beam_groups
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_beam_groups` has to be smaller or equal to `num_beams`"
)
if
is_group_beam_gen_mode
and
generation_config
.
do_sample
is
True
:
if
generation_mode
==
GenerationMode
.
GROUP_BEAM_SEARCH
and
generation_config
.
do_sample
is
True
:
raise
ValueError
(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
...
...
@@ -1515,7 +1518,7 @@ class GenerationMixin:
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
# 10. go into different generation modes
if
is_assisted_gen_mode
:
if
generation_mode
==
GenerationMode
.
ASSISTED_GENERATION
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
"num_return_sequences has to be 1 when doing assisted generate, "
...
...
@@ -1553,7 +1556,7 @@ class GenerationMixin:
streamer
=
streamer
,
**
model_kwargs
,
)
if
is_greedy_gen_mode
:
if
generation_mode
==
GenerationMode
.
GREEDY_SEARCH
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
"num_return_sequences has to be 1 when doing greedy search, "
...
...
@@ -1574,7 +1577,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_cont
ra
s
ti
ve_search_gen_mode
:
elif
gene
rati
on_mode
==
GenerationMode
.
CONTRASTIVE_SEARCH
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
"num_return_sequences has to be 1 when doing contrastive search, "
...
...
@@ -1599,7 +1602,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_sample_gen_mode
:
elif
generation_mode
==
GenerationMode
.
SAMPLE
:
# 11. prepare logits warper
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
...
...
@@ -1626,7 +1629,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_beam_gen_mode
:
elif
generation_mode
==
GenerationMode
.
BEAM_SEARCH
:
if
generation_config
.
num_return_sequences
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
...
...
@@ -1664,7 +1667,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_beam_sample_gen_mode
:
elif
generation_mode
==
GenerationMode
.
BEAM_SAMPLE
:
# 11. prepare logits warper
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
...
...
@@ -1703,7 +1706,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_group_beam_gen_mode
:
elif
generation_mode
==
GenerationMode
.
GROUP_BEAM_SEARCH
:
if
generation_config
.
num_return_sequences
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
...
...
@@ -1754,7 +1757,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_constraint_gen_mode
:
elif
generation_mode
==
GenerationMode
.
CONSTRAINED_BEAM_SEARCH
:
if
generation_config
.
num_return_sequences
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment