Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ddb4fda3
Unverified
Commit
ddb4fda3
authored
Mar 06, 2024
by
Joao Gante
Committed by
GitHub
Mar 06, 2024
Browse files
Generate: torch.compile-ready generation config preparation (#29443)
parent
9322576e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
33 deletions
+60
-33
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+59
-33
src/transformers/utils/__init__.py
src/transformers/utils/__init__.py
+1
-0
No files found.
src/transformers/generation/utils.py
View file @
ddb4fda3
...
...
@@ -34,7 +34,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
..utils
import
ModelOutput
,
is_accelerate_available
,
logging
from
..utils
import
ModelOutput
,
is_accelerate_available
,
is_torchdynamo_compiling
,
logging
from
.beam_constraints
import
DisjunctiveConstraint
,
PhrasalConstraint
from
.beam_search
import
BeamScorer
,
BeamSearchScorer
,
ConstrainedBeamSearchScorer
from
.candidate_generator
import
(
...
...
@@ -1162,6 +1162,59 @@ class GenerationMixin:
UserWarning
,
)
def
_prepare_generation_config
(
self
,
generation_config
:
GenerationConfig
,
**
kwargs
:
Dict
)
->
Tuple
[
GenerationConfig
,
Dict
]:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs.
"""
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if
(
not
is_torchdynamo_compiling
()
and
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
)
and
self
.
config
.
_has_non_default_generation_parameters
()
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
if
is_torchdynamo_compiling
():
model_kwargs
=
kwargs
generate_attributes_in_kwargs
=
[
key
for
key
,
value
in
kwargs
.
items
()
if
getattr
(
generation_config
,
key
,
None
)
!=
value
]
if
len
(
generate_attributes_in_kwargs
)
>
0
:
raise
ValueError
(
"`torch.compile` exception: all generation configuration attributes must be passed within a "
f
"`generation_config` instance passed to `generate` (found:
{
generate_attributes_in_kwargs
}
)."
)
else
:
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
return
generation_config
,
model_kwargs
@
torch
.
no_grad
()
def
generate
(
self
,
...
...
@@ -1260,44 +1313,17 @@ class GenerationMixin:
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self
.
_validate_model_class
()
generation_config
,
model_kwargs
=
self
.
_prepare_generation_config
(
generation_config
,
**
kwargs
)
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 2. Set generation parameters if not already defined
if
synced_gpus
is
None
:
if
is_deepspeed_zero3_enabled
()
and
dist
.
get_world_size
()
>
1
:
synced_gpus
=
True
else
:
synced_gpus
=
False
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self
.
_validate_model_class
()
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
if
(
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
)
and
self
.
config
.
_has_non_default_generation_parameters
()
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
generation_config
=
copy
.
deepcopy
(
generation_config
)
model_kwargs
=
generation_config
.
update
(
**
kwargs
)
# All unused kwargs must be model kwargs
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 2. Set generation parameters if not already defined
logits_processor
=
logits_processor
if
logits_processor
is
not
None
else
LogitsProcessorList
()
stopping_criteria
=
stopping_criteria
if
stopping_criteria
is
not
None
else
StoppingCriteriaList
()
...
...
src/transformers/utils/__init__.py
View file @
ddb4fda3
...
...
@@ -193,6 +193,7 @@ from .import_utils import (
is_torchaudio_available
,
is_torchdistx_available
,
is_torchdynamo_available
,
is_torchdynamo_compiling
,
is_torchvision_available
,
is_training_run_on_sagemaker
,
is_vision_available
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment