Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4bd36f18
Unverified
Commit
4bd36f18
authored
Sep 13, 2022
by
Joao Gante
Committed by
GitHub
Sep 13, 2022
Browse files
Generate: add model class validation (#18902)
parent
69df33f1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
106 additions
and
16 deletions
+106
-16
src/transformers/generation_flax_utils.py
src/transformers/generation_flax_utils.py
+31
-1
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+33
-7
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+35
-7
src/transformers/models/openai/modeling_openai.py
src/transformers/models/openai/modeling_openai.py
+4
-1
src/transformers/models/openai/modeling_tf_openai.py
src/transformers/models/openai/modeling_tf_openai.py
+3
-0
No files found.
src/transformers/generation_flax_utils.py
View file @
4bd36f18
...
...
@@ -36,6 +36,11 @@ from .generation_flax_logits_process import (
FlaxTopKLogitsWarper
,
FlaxTopPLogitsWarper
,
)
from
.models.auto
import
(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
.utils
import
ModelOutput
,
logging
...
...
@@ -161,6 +166,30 @@ class FlaxGenerationMixin:
"""
return
logits
def
_validate_model_class
(
self
):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if
not
hasattr
(
self
,
"prepare_inputs_for_generation"
):
generate_compatible_mappings
=
[
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
]
generate_compatible_classes
=
set
()
for
model_mapping
in
generate_compatible_mappings
:
supported_models
=
model_mapping
.
get
(
type
(
self
.
config
),
default
=
None
)
if
supported_models
is
not
None
:
generate_compatible_classes
.
add
(
supported_models
.
__name__
)
exception_message
=
(
f
"The current model class (
{
self
.
__class__
.
__name__
}
) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if
generate_compatible_classes
:
exception_message
+=
f
" Please use one of the following classes instead:
{
generate_compatible_classes
}
"
raise
TypeError
(
exception_message
)
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args
=
[]
...
...
@@ -281,7 +310,8 @@ class FlaxGenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# Validate model kwargs
# Validate the `.generate()` call
self
.
_validate_model_class
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# set init values
...
...
src/transformers/generation_tf_utils.py
View file @
4bd36f18
...
...
@@ -35,6 +35,12 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
)
from
.models.auto
import
(
TF_MODEL_FOR_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
.tf_utils
import
shape_list
,
stable_softmax
from
.utils
import
ModelOutput
,
logging
...
...
@@ -357,12 +363,6 @@ class TFGenerationMixin:
supports_xla_generation
=
True
def prepare_inputs_for_generation(self, inputs, **kwargs):
    """
    Default input preparation for the generate method; override in subclasses of [`TFPreTrainedModel`] to
    customize it.

    Returns a dictionary whose only entry maps `"input_ids"` to `inputs`. Any extra keyword arguments are accepted
    and ignored.
    """
    model_inputs = dict(input_ids=inputs)
    return model_inputs
def
_use_cache
(
self
,
outputs
,
use_cache
):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
use_cache
=
getattr
(
self
.
config
,
"use_cache"
,
False
)
...
...
@@ -1290,6 +1290,31 @@ class TFGenerationMixin:
else
:
return
logits
def
_validate_model_class
(
self
):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if
not
hasattr
(
self
,
"prepare_inputs_for_generation"
):
generate_compatible_mappings
=
[
TF_MODEL_FOR_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
]
generate_compatible_classes
=
set
()
for
model_mapping
in
generate_compatible_mappings
:
supported_models
=
model_mapping
.
get
(
type
(
self
.
config
),
default
=
None
)
if
supported_models
is
not
None
:
generate_compatible_classes
.
add
(
supported_models
.
__name__
)
exception_message
=
(
f
"The current model class (
{
self
.
__class__
.
__name__
}
) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if
generate_compatible_classes
:
exception_message
+=
f
" Please use one of the following classes instead:
{
generate_compatible_classes
}
"
raise
TypeError
(
exception_message
)
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
...
...
@@ -1508,7 +1533,8 @@ class TFGenerationMixin:
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
```"""
# 0. Validate model kwargs
# 0. Validate the `.generate()` call
self
.
_validate_model_class
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 1. Set generation parameters if not already defined
...
...
src/transformers/generation_utils.py
View file @
4bd36f18
...
...
@@ -51,6 +51,13 @@ from .generation_stopping_criteria import (
StoppingCriteriaList
,
validate_stopping_criteria
,
)
from
.models.auto
import
(
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
.pytorch_utils
import
torch_int_div
from
.utils
import
ModelOutput
,
logging
...
...
@@ -463,12 +470,6 @@ class GenerationMixin:
return
can_retrieve_inputs
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Default input preparation for the generate method; override in subclasses of [`PreTrainedModel`] to customize
    it.

    Returns a dictionary whose only entry maps `"input_ids"` to `input_ids`. Any extra keyword arguments are
    accepted and ignored.
    """
    model_inputs = dict(input_ids=input_ids)
    return model_inputs
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
...
...
@@ -840,6 +841,32 @@ class GenerationMixin:
return
transition_scores
def
_validate_model_class
(
self
):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if
not
hasattr
(
self
,
"prepare_inputs_for_generation"
):
generate_compatible_mappings
=
[
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
]
generate_compatible_classes
=
set
()
for
model_mapping
in
generate_compatible_mappings
:
supported_models
=
model_mapping
.
get
(
type
(
self
.
config
),
default
=
None
)
if
supported_models
is
not
None
:
generate_compatible_classes
.
add
(
supported_models
.
__name__
)
exception_message
=
(
f
"The current model class (
{
self
.
__class__
.
__name__
}
) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if
generate_compatible_classes
:
exception_message
+=
f
" Please use one of the following classes instead:
{
generate_compatible_classes
}
"
raise
TypeError
(
exception_message
)
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
...
...
@@ -1142,7 +1169,8 @@ class GenerationMixin:
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate model kwargs
# 0. Validate the `.generate()` call
self
.
_validate_model_class
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 1. Set generation parameters if not already defined
...
...
src/transformers/models/openai/modeling_openai.py
View file @
4bd36f18
...
...
@@ -20,7 +20,7 @@ import json
import
math
import
os
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -607,6 +607,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
attentions
=
transformer_outputs
.
attentions
,
)
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Builds the model inputs used by the generate method.

    Returns a dictionary whose only entry maps `"input_ids"` to `input_ids`. Any extra keyword arguments are
    accepted and ignored.
    """
    model_inputs = dict(input_ids=input_ids)
    return model_inputs
@
add_start_docstrings
(
"""
...
...
src/transformers/models/openai/modeling_tf_openai.py
View file @
4bd36f18
...
...
@@ -638,6 +638,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
return
TFCausalLMOutput
(
logits
=
output
.
logits
,
hidden_states
=
hs
,
attentions
=
attns
)
def prepare_inputs_for_generation(self, inputs, **kwargs):
    """
    Builds the model inputs used by the generate method.

    Returns a dictionary whose only entry maps `"input_ids"` to `inputs`. Any extra keyword arguments are accepted
    and ignored.
    """
    model_inputs = dict(input_ids=inputs)
    return model_inputs
@
add_start_docstrings
(
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment