chenpangpang / transformers · Commits

Unverified commit ddb4fda3, authored Mar 06, 2024 by Joao Gante, committed by GitHub on Mar 06, 2024
Generate: torch.compile-ready generation config preparation (#29443)
parent 9322576e

Showing 2 changed files with 60 additions and 33 deletions (+60 -33)

src/transformers/generation/utils.py  +59 -33
src/transformers/utils/__init__.py    +1 -0
src/transformers/generation/utils.py
@@ -34,7 +34,7 @@ from ..models.auto import (
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
-from ..utils import ModelOutput, is_accelerate_available, logging
+from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .candidate_generator import (
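For context, `is_torchdynamo_compiling` (newly pulled into the `transformers.utils` namespace, see the second file below) reports whether the current code is being traced by TorchDynamo. A minimal sketch of the check, assuming a PyTorch 2.x install; the exact guard logic in `import_utils.py` may differ:

def is_torchdynamo_compiling():
    # Returns True only while torch.compile / TorchDynamo is tracing this code.
    # Sketch only: the real helper also guards against torch being unavailable.
    try:
        import torch._dynamo

        return torch._dynamo.is_compiling()
    except Exception:
        return False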
@@ -1162,6 +1162,59 @@ class GenerationMixin:
                 UserWarning,
             )
 
+    def _prepare_generation_config(
+        self, generation_config: GenerationConfig, **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        """
+        Prepares the base generation config, then applies any generation configuration options from kwargs.
+        """
+        # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
+        # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
+        # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # three conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same);
+            # 3) the user must have set generation parameters in the model config.
+            # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
+            if (
+                not is_torchdynamo_compiling()
+                and self.generation_config._from_model_config
+                and self.generation_config._original_object_hash == hash(self.generation_config)
+                and self.config._has_non_default_generation_parameters()
+            ):
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
+        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
+        if is_torchdynamo_compiling():
+            model_kwargs = kwargs
+            generate_attributes_in_kwargs = [
+                key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
+            ]
+            if len(generate_attributes_in_kwargs) > 0:
+                raise ValueError(
+                    "`torch.compile` exception: all generation configuration attributes must be passed within a "
+                    f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
+                )
+        else:
+            generation_config = copy.deepcopy(generation_config)
+            model_kwargs = generation_config.update(**kwargs)
+
+        return generation_config, model_kwargs
+
     @torch.no_grad()
     def generate(
         self,
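The practical consequence of the new helper: when `generate` runs under `torch.compile`, every generation option must already live in a `GenerationConfig` instance, since loose kwargs would require the `copy.deepcopy`/`update` path that Dynamo cannot trace. A hedged sketch of the intended calling convention (the checkpoint name is just an example, and end-to-end compilability is not guaranteed by this commit alone):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint, any causal LM works the same way
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Compile-ready style: all generation options are carried by one GenerationConfig.
generation_config = GenerationConfig(
    max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.eos_token_id
)

inputs = tokenizer("Hello", return_tensors="pt")

# Works in eager mode, and is the only accepted style under compilation.
model.generate(**inputs, generation_config=generation_config)

# Under torch.compile, passing loose generation kwargs instead, e.g.
#     model.generate(**inputs, max_new_tokens=20)
# would trip the ValueError raised in _prepare_generation_config above.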
@@ -1260,44 +1313,17 @@ class GenerationMixin:
                 - [`~generation.GenerateEncoderDecoderOutput`],
                 - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        self._validate_model_class()
-
-        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
-        if generation_config is None:
-            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
-            # three conditions must be met
-            # 1) the generation config must have been created from the model config (`_from_model_config` field);
-            # 2) the generation config must have seen no modification since its creation (the hash is the same);
-            # 3) the user must have set generation parameters in the model config.
-            if (
-                self.generation_config._from_model_config
-                and self.generation_config._original_object_hash == hash(self.generation_config)
-                and self.config._has_non_default_generation_parameters()
-            ):
-                new_generation_config = GenerationConfig.from_model_config(self.config)
-                if new_generation_config != self.generation_config:
-                    warnings.warn(
-                        "You have modified the pretrained model configuration to control generation. This is a"
-                        " deprecated strategy to control generation and will be removed soon, in a future version."
-                        " Please use and modify the model generation configuration (see"
-                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
-                    )
-                    self.generation_config = new_generation_config
-            generation_config = self.generation_config
-
-        generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
-        self._validate_model_kwargs(model_kwargs.copy())
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+        self._validate_model_class()
+        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+        self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
             if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
                 synced_gpus = True
             else:
                 synced_gpus = False
 
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
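Outside of compilation, the refactor is intended to be behavior-preserving: the helper still deep-copies the config, applies kwargs via `GenerationConfig.update`, and hands back the unused kwargs as model kwargs, just as the old inline code in `generate` did. A small eager-mode sketch (the checkpoint name and the direct call to a private method are illustrative assumptions):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint

# Same contract as the old inline code in generate():
generation_config, model_kwargs = model._prepare_generation_config(None, max_new_tokens=8, foo="bar")

assert generation_config.max_new_tokens == 8  # consumed by the generation config
assert model_kwargs == {"foo": "bar"}         # left over as model kwargs for the forward pass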
src/transformers/utils/__init__.py
@@ -193,6 +193,7 @@ from .import_utils import (
     is_torchaudio_available,
     is_torchdistx_available,
     is_torchdynamo_available,
+    is_torchdynamo_compiling,
     is_torchvision_available,
     is_training_run_on_sagemaker,
     is_vision_available,