Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bff4313b
Unverified
Commit
bff4313b
authored
Aug 04, 2023
by
Joao Gante
Committed by
GitHub
Aug 04, 2023
Browse files
Generate: get generation mode as an enum (#25292)
parent
fab1a0aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
65 deletions
+68
-65
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+68
-65
No files found.
src/transformers/generation/utils.py
View file @
bff4313b
...
...
@@ -33,7 +33,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
..utils
import
ModelOutput
,
logging
from
..utils
import
ExplicitEnum
,
ModelOutput
,
logging
from
.beam_constraints
import
DisjunctiveConstraint
,
PhrasalConstraint
from
.beam_search
import
BeamScorer
,
BeamSearchScorer
,
ConstrainedBeamSearchScorer
from
.configuration_utils
import
GenerationConfig
...
...
@@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti
GenerateOutput
=
Union
[
GreedySearchOutput
,
SampleOutput
,
BeamSearchOutput
,
BeamSampleOutput
,
ContrastiveSearchOutput
]
class
GenerationMode
(
ExplicitEnum
):
"""
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
"""
# Non-beam methods
CONTRASTIVE_SEARCH
=
"contrastive_search"
GREEDY_SEARCH
=
"greedy_search"
SAMPLE
=
"sample"
ASSISTED_GENERATION
=
"assisted_generation"
# Beam methods
BEAM_SEARCH
=
"beam_search"
BEAM_SAMPLE
=
"beam_sample"
CONSTRAINED_BEAM_SEARCH
=
"constrained_beam_search"
GROUP_BEAM_SEARCH
=
"group_beam_search"
class
GenerationMixin
:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
...
...
@@ -829,6 +846,46 @@ class GenerationMixin:
warpers
.
append
(
LogitNormalization
())
return
warpers
def
_get_generation_mode
(
self
,
generation_config
:
GenerationConfig
,
assistant_model
:
Optional
[
"PreTrainedModel"
]
)
->
GenerationMode
:
"""
Returns the generation mode triggered by a [`GenerationConfig`] instance.
"""
if
generation_config
.
constraints
is
not
None
or
generation_config
.
force_words_ids
is
not
None
:
generation_mode
=
GenerationMode
.
CONSTRAINED_BEAM_SEARCH
elif
generation_config
.
num_beams
==
1
:
if
generation_config
.
do_sample
is
False
:
if
(
generation_config
.
top_k
is
not
None
and
generation_config
.
top_k
>
1
and
generation_config
.
penalty_alpha
is
not
None
and
generation_config
.
penalty_alpha
>
0
):
generation_mode
=
GenerationMode
.
CONTRASTIVE_SEARCH
else
:
generation_mode
=
GenerationMode
.
GREEDY_SEARCH
else
:
generation_mode
=
GenerationMode
.
SAMPLE
else
:
if
generation_config
.
num_beam_groups
>
1
:
generation_mode
=
GenerationMode
.
GROUP_BEAM_SEARCH
elif
generation_config
.
do_sample
is
True
:
generation_mode
=
GenerationMode
.
BEAM_SAMPLE
else
:
generation_mode
=
GenerationMode
.
BEAM_SEARCH
# Assisted generation may extend some generation modes
if
assistant_model
is
not
None
:
if
generation_mode
in
(
"greedy_search"
,
"sample"
):
generation_mode
=
GenerationMode
.
ASSISTED_GENERATION
else
:
raise
ValueError
(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
return
generation_mode
def
_get_logits_processor
(
self
,
generation_config
:
GenerationConfig
,
...
...
@@ -1422,65 +1479,11 @@ class GenerationMixin:
)
# 7. determine generation mode
is_constraint_gen_mode
=
(
generation_config
.
constraints
is
not
None
or
generation_config
.
force_words_ids
is
not
None
)
is_contrastive_search_gen_mode
=
(
(
generation_config
.
num_beams
==
1
)
and
generation_config
.
top_k
is
not
None
and
generation_config
.
top_k
>
1
and
generation_config
.
do_sample
is
False
and
generation_config
.
penalty_alpha
is
not
None
and
generation_config
.
penalty_alpha
>
0
)
is_greedy_gen_mode
=
(
(
generation_config
.
num_beams
==
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
False
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_sample_gen_mode
=
(
(
generation_config
.
num_beams
==
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
True
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_beam_gen_mode
=
(
(
generation_config
.
num_beams
>
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
False
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode
=
(
(
generation_config
.
num_beams
>
1
)
and
(
generation_config
.
num_beam_groups
==
1
)
and
generation_config
.
do_sample
is
True
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_group_beam_gen_mode
=
(
(
generation_config
.
num_beams
>
1
)
and
(
generation_config
.
num_beam_groups
>
1
)
and
not
is_constraint_gen_mode
and
not
is_contrastive_search_gen_mode
)
is_assisted_gen_mode
=
False
if
assistant_model
is
not
None
:
if
not
(
is_greedy_gen_mode
or
is_sample_gen_mode
):
raise
ValueError
(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)
is_assisted_gen_mode
=
True
generation_mode
=
self
.
_get_generation_mode
(
generation_config
,
assistant_model
)
if
generation_config
.
num_beam_groups
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_beam_groups` has to be smaller or equal to `num_beams`"
)
if
is_group_beam_gen_mode
and
generation_config
.
do_sample
is
True
:
if
generation_mode
==
GenerationMode
.
GROUP_BEAM_SEARCH
and
generation_config
.
do_sample
is
True
:
raise
ValueError
(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
...
...
@@ -1515,7 +1518,7 @@ class GenerationMixin:
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
# 10. go into different generation modes
if
is_assisted_gen_mode
:
if
generation_mode
==
GenerationMode
.
ASSISTED_GENERATION
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
"num_return_sequences has to be 1 when doing assisted generate, "
...
...
@@ -1553,7 +1556,7 @@ class GenerationMixin:
streamer
=
streamer
,
**
model_kwargs
,
)
if
is_greedy_gen_mode
:
if
generation_mode
==
GenerationMode
.
GREEDY_SEARCH
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
"num_return_sequences has to be 1 when doing greedy search, "
...
...
@@ -1574,7 +1577,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_cont
ra
s
ti
ve_search_gen_mode
:
elif
gene
rati
on_mode
==
GenerationMode
.
CONTRASTIVE_SEARCH
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
"num_return_sequences has to be 1 when doing contrastive search, "
...
...
@@ -1599,7 +1602,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_sample_gen_mode
:
elif
generation_mode
==
GenerationMode
.
SAMPLE
:
# 11. prepare logits warper
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
...
...
@@ -1626,7 +1629,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_beam_gen_mode
:
elif
generation_mode
==
GenerationMode
.
BEAM_SEARCH
:
if
generation_config
.
num_return_sequences
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
...
...
@@ -1664,7 +1667,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_beam_sample_gen_mode
:
elif
generation_mode
==
GenerationMode
.
BEAM_SAMPLE
:
# 11. prepare logits warper
logits_warper
=
self
.
_get_logits_warper
(
generation_config
)
...
...
@@ -1703,7 +1706,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_group_beam_gen_mode
:
elif
generation_mode
==
GenerationMode
.
GROUP_BEAM_SEARCH
:
if
generation_config
.
num_return_sequences
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
...
...
@@ -1754,7 +1757,7 @@ class GenerationMixin:
**
model_kwargs
,
)
elif
is_constraint_gen_mode
:
elif
generation_mode
==
GenerationMode
.
CONSTRAINED_BEAM_SEARCH
:
if
generation_config
.
num_return_sequences
>
generation_config
.
num_beams
:
raise
ValueError
(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment