Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4bd36f18
Unverified
Commit
4bd36f18
authored
Sep 13, 2022
by
Joao Gante
Committed by
GitHub
Sep 13, 2022
Browse files
Generate: add model class validation (#18902)
parent
69df33f1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
106 additions
and
16 deletions
+106
-16
src/transformers/generation_flax_utils.py
src/transformers/generation_flax_utils.py
+31
-1
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+33
-7
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+35
-7
src/transformers/models/openai/modeling_openai.py
src/transformers/models/openai/modeling_openai.py
+4
-1
src/transformers/models/openai/modeling_tf_openai.py
src/transformers/models/openai/modeling_tf_openai.py
+3
-0
No files found.
src/transformers/generation_flax_utils.py
View file @
4bd36f18
...
@@ -36,6 +36,11 @@ from .generation_flax_logits_process import (
...
@@ -36,6 +36,11 @@ from .generation_flax_logits_process import (
FlaxTopKLogitsWarper
,
FlaxTopKLogitsWarper
,
FlaxTopPLogitsWarper
,
FlaxTopPLogitsWarper
,
)
)
from
.models.auto
import
(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
.utils
import
ModelOutput
,
logging
from
.utils
import
ModelOutput
,
logging
...
@@ -161,6 +166,30 @@ class FlaxGenerationMixin:
...
@@ -161,6 +166,30 @@ class FlaxGenerationMixin:
"""
"""
return
logits
return
logits
def
_validate_model_class
(
self
):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if
not
hasattr
(
self
,
"prepare_inputs_for_generation"
):
generate_compatible_mappings
=
[
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
]
generate_compatible_classes
=
set
()
for
model_mapping
in
generate_compatible_mappings
:
supported_models
=
model_mapping
.
get
(
type
(
self
.
config
),
default
=
None
)
if
supported_models
is
not
None
:
generate_compatible_classes
.
add
(
supported_models
.
__name__
)
exception_message
=
(
f
"The current model class (
{
self
.
__class__
.
__name__
}
) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if
generate_compatible_classes
:
exception_message
+=
f
" Please use one of the following classes instead:
{
generate_compatible_classes
}
"
raise
TypeError
(
exception_message
)
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args
=
[]
unused_model_args
=
[]
...
@@ -281,7 +310,8 @@ class FlaxGenerationMixin:
...
@@ -281,7 +310,8 @@ class FlaxGenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
```"""
# Validate model kwargs
# Validate the `.generate()` call
self
.
_validate_model_class
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# set init values
# set init values
...
...
src/transformers/generation_tf_utils.py
View file @
4bd36f18
...
@@ -35,6 +35,12 @@ from .generation_tf_logits_process import (
...
@@ -35,6 +35,12 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper
,
TFTopKLogitsWarper
,
TFTopPLogitsWarper
,
TFTopPLogitsWarper
,
)
)
from
.models.auto
import
(
TF_MODEL_FOR_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
.tf_utils
import
shape_list
,
stable_softmax
from
.tf_utils
import
shape_list
,
stable_softmax
from
.utils
import
ModelOutput
,
logging
from
.utils
import
ModelOutput
,
logging
...
@@ -357,12 +363,6 @@ class TFGenerationMixin:
...
@@ -357,12 +363,6 @@ class TFGenerationMixin:
supports_xla_generation
=
True
supports_xla_generation
=
True
def
prepare_inputs_for_generation
(
self
,
inputs
,
**
kwargs
):
"""
Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
"""
return
{
"input_ids"
:
inputs
}
def
_use_cache
(
self
,
outputs
,
use_cache
):
def
_use_cache
(
self
,
outputs
,
use_cache
):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
use_cache
=
getattr
(
self
.
config
,
"use_cache"
,
False
)
use_cache
=
getattr
(
self
.
config
,
"use_cache"
,
False
)
...
@@ -1290,6 +1290,31 @@ class TFGenerationMixin:
...
@@ -1290,6 +1290,31 @@ class TFGenerationMixin:
else
:
else
:
return
logits
return
logits
def
_validate_model_class
(
self
):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if
not
hasattr
(
self
,
"prepare_inputs_for_generation"
):
generate_compatible_mappings
=
[
TF_MODEL_FOR_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
]
generate_compatible_classes
=
set
()
for
model_mapping
in
generate_compatible_mappings
:
supported_models
=
model_mapping
.
get
(
type
(
self
.
config
),
default
=
None
)
if
supported_models
is
not
None
:
generate_compatible_classes
.
add
(
supported_models
.
__name__
)
exception_message
=
(
f
"The current model class (
{
self
.
__class__
.
__name__
}
) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if
generate_compatible_classes
:
exception_message
+=
f
" Please use one of the following classes instead:
{
generate_compatible_classes
}
"
raise
TypeError
(
exception_message
)
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
# Excludes arguments that are handled before calling any model function
...
@@ -1508,7 +1533,8 @@ class TFGenerationMixin:
...
@@ -1508,7 +1533,8 @@ class TFGenerationMixin:
# generate sequences without allowing bad_words to be generated
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
```"""
```"""
# 0. Validate model kwargs
# 0. Validate the `.generate()` call
self
.
_validate_model_class
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 1. Set generation parameters if not already defined
# 1. Set generation parameters if not already defined
...
...
src/transformers/generation_utils.py
View file @
4bd36f18
...
@@ -51,6 +51,13 @@ from .generation_stopping_criteria import (
...
@@ -51,6 +51,13 @@ from .generation_stopping_criteria import (
StoppingCriteriaList
,
StoppingCriteriaList
,
validate_stopping_criteria
,
validate_stopping_criteria
,
)
)
from
.models.auto
import
(
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
from
.pytorch_utils
import
torch_int_div
from
.pytorch_utils
import
torch_int_div
from
.utils
import
ModelOutput
,
logging
from
.utils
import
ModelOutput
,
logging
...
@@ -463,12 +470,6 @@ class GenerationMixin:
...
@@ -463,12 +470,6 @@ class GenerationMixin:
return
can_retrieve_inputs
return
can_retrieve_inputs
def
prepare_inputs_for_generation
(
self
,
input_ids
:
torch
.
LongTensor
,
**
kwargs
)
->
Dict
[
str
,
Any
]:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
"""
return
{
"input_ids"
:
input_ids
}
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
"""
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
...
@@ -840,6 +841,32 @@ class GenerationMixin:
...
@@ -840,6 +841,32 @@ class GenerationMixin:
return
transition_scores
return
transition_scores
def
_validate_model_class
(
self
):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if
not
hasattr
(
self
,
"prepare_inputs_for_generation"
):
generate_compatible_mappings
=
[
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
]
generate_compatible_classes
=
set
()
for
model_mapping
in
generate_compatible_mappings
:
supported_models
=
model_mapping
.
get
(
type
(
self
.
config
),
default
=
None
)
if
supported_models
is
not
None
:
generate_compatible_classes
.
add
(
supported_models
.
__name__
)
exception_message
=
(
f
"The current model class (
{
self
.
__class__
.
__name__
}
) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
)
if
generate_compatible_classes
:
exception_message
+=
f
" Please use one of the following classes instead:
{
generate_compatible_classes
}
"
raise
TypeError
(
exception_message
)
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
# Excludes arguments that are handled before calling any model function
...
@@ -1142,7 +1169,8 @@ class GenerationMixin:
...
@@ -1142,7 +1169,8 @@ class GenerationMixin:
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
```"""
# 0. Validate model kwargs
# 0. Validate the `.generate()` call
self
.
_validate_model_class
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 1. Set generation parameters if not already defined
# 1. Set generation parameters if not already defined
...
...
src/transformers/models/openai/modeling_openai.py
View file @
4bd36f18
...
@@ -20,7 +20,7 @@ import json
...
@@ -20,7 +20,7 @@ import json
import
math
import
math
import
os
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -607,6 +607,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -607,6 +607,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
attentions
=
transformer_outputs
.
attentions
,
attentions
=
transformer_outputs
.
attentions
,
)
)
def
prepare_inputs_for_generation
(
self
,
input_ids
:
torch
.
LongTensor
,
**
kwargs
)
->
Dict
[
str
,
Any
]:
return
{
"input_ids"
:
input_ids
}
@
add_start_docstrings
(
@
add_start_docstrings
(
"""
"""
...
...
src/transformers/models/openai/modeling_tf_openai.py
View file @
4bd36f18
...
@@ -638,6 +638,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
...
@@ -638,6 +638,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
return
TFCausalLMOutput
(
logits
=
output
.
logits
,
hidden_states
=
hs
,
attentions
=
attns
)
return
TFCausalLMOutput
(
logits
=
output
.
logits
,
hidden_states
=
hs
,
attentions
=
attns
)
def
prepare_inputs_for_generation
(
self
,
inputs
,
**
kwargs
):
return
{
"input_ids"
:
inputs
}
@
add_start_docstrings
(
@
add_start_docstrings
(
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment