Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
700d48fb
Unverified
Commit
700d48fb
authored
Mar 06, 2024
by
Joao Gante
Committed by
GitHub
Mar 06, 2024
Browse files
Generate: get generation mode from the generation config instance 🧼 (#29441)
parent
41f7b7ae
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
65 deletions
+103
-65
docs/source/en/main_classes/text_generation.md
docs/source/en/main_classes/text_generation.md
+3
-0
src/transformers/generation/__init__.py
src/transformers/generation/__init__.py
+2
-2
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+77
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+3
-60
tests/generation/test_configuration_utils.py
tests/generation/test_configuration_utils.py
+18
-0
No files found.
docs/source/en/main_classes/text_generation.md
View file @
700d48fb
...
...
@@ -37,6 +37,9 @@ like token streaming.
-
from_pretrained
-
from_model_config
-
save_pretrained
-
update
-
validate
-
get_generation_mode
## GenerationMixin
...
...
src/transformers/generation/__init__.py
View file @
700d48fb
...
...
@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
_import_structure
=
{
"configuration_utils"
:
[
"GenerationConfig"
],
"configuration_utils"
:
[
"GenerationConfig"
,
"GenerationMode"
],
"streamers"
:
[
"TextIteratorStreamer"
,
"TextStreamer"
],
}
...
...
@@ -172,7 +172,7 @@ else:
]
if
TYPE_CHECKING
:
from
.configuration_utils
import
GenerationConfig
from
.configuration_utils
import
GenerationConfig
,
GenerationMode
from
.streamers
import
TextIteratorStreamer
,
TextStreamer
try
:
...
...
src/transformers/generation/configuration_utils.py
View file @
700d48fb
...
...
@@ -18,12 +18,13 @@ import copy
import
json
import
os
import
warnings
from
typing
import
Any
,
Dict
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
,
Union
from
..
import
__version__
from
..configuration_utils
import
PretrainedConfig
from
..utils
import
(
GENERATION_CONFIG_NAME
,
ExplicitEnum
,
PushToHubMixin
,
cached_file
,
download_url
,
...
...
@@ -33,10 +34,31 @@ from ..utils import (
)
if
TYPE_CHECKING
:
from
..modeling_utils
import
PreTrainedModel
logger
=
logging
.
get_logger
(
__name__
)
METADATA_FIELDS
=
(
"_from_model_config"
,
"_commit_hash"
,
"_original_object_hash"
,
"transformers_version"
)
class GenerationMode(ExplicitEnum):
    """
    Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
    """

    # Non-beam methods
    CONTRASTIVE_SEARCH = "contrastive_search"
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    GROUP_BEAM_SEARCH = "group_beam_search"
class
GenerationConfig
(
PushToHubMixin
):
# no-format
r
"""
...
...
@@ -376,13 +398,65 @@ class GenerationConfig(PushToHubMixin):
def __repr__(self):
    """Return the class name followed by the config's JSON serialization (metadata fields omitted)."""
    json_repr = self.to_json_string(ignore_metadata=True)
    return f"{self.__class__.__name__} {json_repr}"
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
    """
    Returns the generation mode triggered by the [`GenerationConfig`] instance.

    Arg:
        assistant_model (`PreTrainedModel`, *optional*):
            The assistant model to be used for assisted generation. If set, the generation mode will be
            assisted generation.

    Returns:
        `GenerationMode`: The generation mode triggered by the instance.
    """
    # TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
    # property and part of the `__repr__`
    if self.constraints is not None or self.force_words_ids is not None:
        generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
    elif self.num_beams == 1:
        if self.do_sample is False:
            if (
                self.top_k is not None
                and self.top_k > 1
                and self.penalty_alpha is not None
                and self.penalty_alpha > 0
            ):
                generation_mode = GenerationMode.CONTRASTIVE_SEARCH
            else:
                generation_mode = GenerationMode.GREEDY_SEARCH
        else:
            generation_mode = GenerationMode.SAMPLE
    else:
        if self.num_beam_groups > 1:
            generation_mode = GenerationMode.GROUP_BEAM_SEARCH
        elif self.do_sample is True:
            generation_mode = GenerationMode.BEAM_SAMPLE
        else:
            generation_mode = GenerationMode.BEAM_SEARCH

    # Assisted generation may extend some generation modes
    if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
        # Compare against enum members rather than raw strings: the raw-string form only worked
        # because `ExplicitEnum` subclasses `str`, and would silently break if that ever changed.
        if generation_mode in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
            generation_mode = GenerationMode.ASSISTED_GENERATION
        else:
            raise ValueError(
                "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                "is only supported with Greedy Search and Sample."
            )
    return generation_mode
def
validate
(
self
,
is_init
=
False
):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
model, such as parameters related to the generation length.
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
other inputs and/or the model, such as parameters related to the generation length.
Arg:
is_init (`bool`, *optional*, defaults to `False`):
Whether the validation is performed during the initialization of the instance.
"""
# Validation of individual attributes
...
...
src/transformers/generation/utils.py
View file @
700d48fb
...
...
@@ -34,7 +34,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
..utils
import
ExplicitEnum
,
ModelOutput
,
is_accelerate_available
,
logging
from
..utils
import
ModelOutput
,
is_accelerate_available
,
logging
from
.beam_constraints
import
DisjunctiveConstraint
,
PhrasalConstraint
from
.beam_search
import
BeamScorer
,
BeamSearchScorer
,
ConstrainedBeamSearchScorer
from
.candidate_generator
import
(
...
...
@@ -45,7 +45,7 @@ from .candidate_generator import (
_prepare_attention_mask
,
_prepare_token_type_ids
,
)
from
.configuration_utils
import
GenerationConfig
from
.configuration_utils
import
GenerationConfig
,
GenerationMode
from
.logits_process
import
(
EncoderNoRepeatNGramLogitsProcessor
,
EncoderRepetitionPenaltyLogitsProcessor
,
...
...
@@ -325,23 +325,6 @@ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDec
GenerateOutput
=
Union
[
GenerateNonBeamOutput
,
GenerateBeamOutput
]
class GenerationMode(ExplicitEnum):
    """
    Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
    """

    # Non-beam methods
    CONTRASTIVE_SEARCH = "contrastive_search"
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    GROUP_BEAM_SEARCH = "group_beam_search"
class
GenerationMixin
:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
...
...
@@ -764,46 +747,6 @@ class GenerationMixin:
warpers
.
append
(
LogitNormalization
())
return
warpers
def _get_generation_mode(
    self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
) -> GenerationMode:
    """
    Returns the generation mode triggered by a [`GenerationConfig`] instance.
    """
    cfg = generation_config
    if cfg.constraints is not None or cfg.force_words_ids is not None:
        mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
    elif cfg.num_beams == 1:
        # Single-beam: sample, contrastive search, or plain greedy decoding.
        if cfg.do_sample is not False:
            mode = GenerationMode.SAMPLE
        elif (
            cfg.top_k is not None
            and cfg.top_k > 1
            and cfg.penalty_alpha is not None
            and cfg.penalty_alpha > 0
        ):
            mode = GenerationMode.CONTRASTIVE_SEARCH
        else:
            mode = GenerationMode.GREEDY_SEARCH
    elif cfg.num_beam_groups > 1:
        mode = GenerationMode.GROUP_BEAM_SEARCH
    elif cfg.do_sample is True:
        mode = GenerationMode.BEAM_SAMPLE
    else:
        mode = GenerationMode.BEAM_SEARCH

    # Assisted generation may extend some generation modes
    if assistant_model is not None or cfg.prompt_lookup_num_tokens is not None:
        if mode not in ("greedy_search", "sample"):
            raise ValueError(
                "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                "is only supported with Greedy Search and Sample."
            )
        mode = GenerationMode.ASSISTED_GENERATION
    return mode
def
_get_logits_processor
(
self
,
generation_config
:
GenerationConfig
,
...
...
@@ -1474,7 +1417,7 @@ class GenerationMixin:
self
.
_validate_generated_length
(
generation_config
,
input_ids_length
,
has_default_max_length
)
# 7. determine generation mode
generation_mode
=
self
.
_get_
generation_
mode
(
generation_
config
,
assistant_model
)
generation_mode
=
generation_
config
.
get_
generation_
mode
(
assistant_model
)
if
streamer
is
not
None
and
(
generation_config
.
num_beams
>
1
):
raise
ValueError
(
...
...
tests/generation/test_configuration_utils.py
View file @
700d48fb
...
...
@@ -24,6 +24,7 @@ from parameterized import parameterized
from
requests.exceptions
import
HTTPError
from
transformers
import
AutoConfig
,
GenerationConfig
from
transformers.generation
import
GenerationMode
from
transformers.testing_utils
import
TOKEN
,
USER
,
is_staging_test
...
...
@@ -202,6 +203,23 @@ class GenerationConfigTest(unittest.TestCase):
self
.
assertEqual
(
len
(
captured_warnings
),
0
)
self
.
assertTrue
(
len
(
os
.
listdir
(
tmp_dir
))
==
1
)
def test_generation_mode(self):
    """Tests that the `get_generation_mode` method is working as expected."""
    # Default config decodes greedily.
    self.assertEqual(GenerationConfig().get_generation_mode(), GenerationMode.GREEDY_SEARCH)

    # Turning on sampling switches the mode.
    self.assertEqual(GenerationConfig(do_sample=True).get_generation_mode(), GenerationMode.SAMPLE)

    # More than one beam triggers beam search.
    self.assertEqual(GenerationConfig(num_beams=2).get_generation_mode(), GenerationMode.BEAM_SEARCH)

    # top_k > 1 with a positive penalty_alpha and no sampling triggers contrastive search.
    contrastive_config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
    self.assertEqual(contrastive_config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)

    # Passing an assistant model upgrades greedy decoding to assisted generation.
    self.assertEqual(
        GenerationConfig().get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION
    )
@
is_staging_test
class
ConfigPushToHubTester
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment