chenpangpang / transformers / Commits / 4bd36f18

Unverified commit 4bd36f18, authored Sep 13, 2022 by Joao Gante, committed by GitHub on Sep 13, 2022.
Generate: add model class validation (#18902)
Parent: 69df33f1

Showing 5 changed files with 106 additions and 16 deletions:
src/transformers/generation_flax_utils.py (+31 −1)
src/transformers/generation_tf_utils.py (+33 −7)
src/transformers/generation_utils.py (+35 −7)
src/transformers/models/openai/modeling_openai.py (+4 −1)
src/transformers/models/openai/modeling_tf_openai.py (+3 −0)
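The heart of the change, repeated across all three framework mixins below, is a lookup from the model's config class to the generate-capable head class: the auto mappings imported from .models.auto are keyed by config class, so model_mapping.get(type(self.config), default=None) returns the class to recommend, or None. A minimal sketch of that lookup outside the mixin, assuming the mapping's get() accepts the default keyword as the diff itself does; everything beyond the mapping name is illustrative:

    # Illustrative sketch (not part of the commit): resolve a config class to
    # the generate-compatible model class, the way _validate_model_class does.
    from transformers import GPT2Config
    from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING

    config = GPT2Config()
    # The mapping is keyed by config *class*; get() returns the registered head
    # class (here, GPT2LMHeadModel) or None when no causal-LM head exists.
    supported = MODEL_FOR_CAUSAL_LM_MAPPING.get(type(config), default=None)
    print(supported.__name__ if supported is not None else "no causal LM head registered")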
src/transformers/generation_flax_utils.py

@@ -36,6 +36,11 @@ from .generation_flax_logits_process import (
     FlaxTopKLogitsWarper,
     FlaxTopPLogitsWarper,
 )
+from .models.auto import (
+    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
+)
 from .utils import ModelOutput, logging
...
@@ -161,6 +166,30 @@ class FlaxGenerationMixin:
         """
         return logits

+    def _validate_model_class(self):
+        """
+        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
+        right class to use.
+        """
+        if not hasattr(self, "prepare_inputs_for_generation"):
+            generate_compatible_mappings = [
+                FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+                FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
+                FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+            ]
+            generate_compatible_classes = set()
+            for model_mapping in generate_compatible_mappings:
+                supported_models = model_mapping.get(type(self.config), default=None)
+                if supported_models is not None:
+                    generate_compatible_classes.add(supported_models.__name__)
+            exception_message = (
+                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
+                "it doesn't have a language model head."
+            )
+            if generate_compatible_classes:
+                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
+            raise TypeError(exception_message)
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
...
@@ -281,7 +310,8 @@ class FlaxGenerationMixin:
         >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ```"""
-        # Validate model kwargs
+        # Validate the `.generate()` call
+        self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())

         # set init values
...
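To see what the Flax check buys in practice, the hypothetical snippet below (not from the commit) calls .generate() on a headless model; with this change the call fails immediately with a TypeError naming the right class, instead of breaking somewhere deeper in generation:

    # Hedged sketch: assumes a transformers version containing this commit and
    # the Flax extras installed; the exact error wording may differ.
    import jax.numpy as jnp
    from transformers import FlaxGPT2Model  # bare GPT-2 body, no LM head

    model = FlaxGPT2Model.from_pretrained("gpt2")
    input_ids = jnp.ones((1, 4), dtype="i4")
    try:
        model.generate(input_ids)
    except TypeError as err:
        print(err)  # should point at FlaxGPT2LMHeadModel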
src/transformers/generation_tf_utils.py

@@ -35,6 +35,12 @@ from .generation_tf_logits_process import (
     TFTopKLogitsWarper,
     TFTopPLogitsWarper,
 )
+from .models.auto import (
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+    TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
+)
 from .tf_utils import shape_list, stable_softmax
 from .utils import ModelOutput, logging
...
@@ -357,12 +363,6 @@ class TFGenerationMixin:
     supports_xla_generation = True

-    def prepare_inputs_for_generation(self, inputs, **kwargs):
-        """
-        Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
-        """
-        return {"input_ids": inputs}
-
     def _use_cache(self, outputs, use_cache):
         """During generation, decide whether to pass the `past` variable to the next forward pass."""
         use_cache = getattr(self.config, "use_cache", False)
...
@@ -1290,6 +1290,31 @@ class TFGenerationMixin:
         else:
             return logits

+    def _validate_model_class(self):
+        """
+        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
+        right class to use.
+        """
+        if not hasattr(self, "prepare_inputs_for_generation"):
+            generate_compatible_mappings = [
+                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+                TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
+                TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+                TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            ]
+            generate_compatible_classes = set()
+            for model_mapping in generate_compatible_mappings:
+                supported_models = model_mapping.get(type(self.config), default=None)
+                if supported_models is not None:
+                    generate_compatible_classes.add(supported_models.__name__)
+            exception_message = (
+                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
+                "it doesn't have a language model head."
+            )
+            if generate_compatible_classes:
+                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
+            raise TypeError(exception_message)
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # Excludes arguments that are handled before calling any model function
...
@@ -1508,7 +1533,8 @@ class TFGenerationMixin:
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
-        # 0. Validate model kwargs
+        # 0. Validate the `.generate()` call
+        self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())

         # 1. Set generation parameters if not already defined
...
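The renamed step-0 comment reflects that two checks now run before anything else: the new class validation and the existing kwargs validation. The docstring above promises that argument typos are caught too; a hypothetical TF call showing that half (the exception type and wording are assumptions, not guaranteed by this diff):

    # Hedged sketch: a typo'd generate() argument ("max_lenght") is forwarded
    # as a model kwarg, so _validate_model_kwargs rejects it up front.
    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = tokenizer("Hello", return_tensors="tf").input_ids
    try:
        model.generate(input_ids, max_lenght=20)  # typo: should be max_length
    except ValueError as err:
        print(err)  # lists the unused model kwargs, surfacing the typo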
src/transformers/generation_utils.py

@@ -51,6 +51,13 @@ from .generation_stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
+from .models.auto import (
+    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+    MODEL_FOR_VISION_2_SEQ_MAPPING,
+)
 from .pytorch_utils import torch_int_div
 from .utils import ModelOutput, logging
...
@@ -463,12 +470,6 @@ class GenerationMixin:
         return can_retrieve_inputs

-    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
-        """
-        return {"input_ids": input_ids}
-
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
...
@@ -840,6 +841,32 @@ class GenerationMixin:
         return transition_scores

+    def _validate_model_class(self):
+        """
+        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
+        right class to use.
+        """
+        if not hasattr(self, "prepare_inputs_for_generation"):
+            generate_compatible_mappings = [
+                MODEL_FOR_CAUSAL_LM_MAPPING,
+                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
+                MODEL_FOR_VISION_2_SEQ_MAPPING,
+                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            ]
+            generate_compatible_classes = set()
+            for model_mapping in generate_compatible_mappings:
+                supported_models = model_mapping.get(type(self.config), default=None)
+                if supported_models is not None:
+                    generate_compatible_classes.add(supported_models.__name__)
+            exception_message = (
+                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
+                "it doesn't have a language model head."
+            )
+            if generate_compatible_classes:
+                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
+            raise TypeError(exception_message)
+
     def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # Excludes arguments that are handled before calling any model function
...
@@ -1142,7 +1169,8 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
-        # 0. Validate model kwargs
+        # 0. Validate the `.generate()` call
+        self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())

         # 1. Set generation parameters if not already defined
...
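The PyTorch version is the template for the other two and targets the canonical failure mode behind this PR: loading a bare AutoModel and calling .generate(). A hedged sketch of the new behavior (the usage is hypothetical; the error wording comes from the diff, and the suggested class set is inferred):

    # Hedged sketch: with this commit, a headless PyTorch model fails fast.
    from transformers import AutoModel

    model = AutoModel.from_pretrained("gpt2")  # GPT2Model: no LM head
    try:
        model.generate(max_length=5)
    except TypeError as err:
        print(err)
    # Expected along the lines of:
    # The current model class (GPT2Model) is not compatible with `.generate()`,
    # as it doesn't have a language model head. Please use one of the following
    # classes instead: {'GPT2LMHeadModel'}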
src/transformers/models/openai/modeling_openai.py

@@ -20,7 +20,7 @@ import json
 import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from torch import nn
...
@@ -607,6 +607,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )

+    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
+        return {"input_ids": input_ids}
+

 @add_start_docstrings(
     """
...
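With the mixin-level default removed, prepare_inputs_for_generation doubles as the opt-in marker that _validate_model_class probes with hasattr(), which is why the two OpenAI GPT head models, which previously leaned on the removed default, get minimal implementations restored here and in the TF file below. A custom model faces the same contract; a sketch with illustrative names (nothing in it is from the commit):

    # Illustrative: a model class that should support .generate() must now
    # define prepare_inputs_for_generation itself; its mere presence is what
    # _validate_model_class checks via hasattr().
    from typing import Any, Dict
    import torch

    class MyTinyCausalLM(torch.nn.Module):  # hypothetical custom model
        def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
            # Minimal echo of the OpenAI GPT implementation in this commit.
            return {"input_ids": input_ids}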
src/transformers/models/openai/modeling_tf_openai.py

@@ -638,6 +638,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
         return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)

+    def prepare_inputs_for_generation(self, inputs, **kwargs):
+        return {"input_ids": inputs}
+

 @add_start_docstrings(
     """
...