chenpangpang/transformers, commit 07998ef3 (unverified)
Authored Aug 29, 2023 by Joao Gante; committed by GitHub on Aug 29, 2023
Parent: 8c75cfda

Generate: models with custom `generate()` return `True` in `can_generate()` (#25838)
Showing 5 changed files with 9 additions and 27 deletions (+9, -27):

  src/transformers/modeling_flax_utils.py                 +3  -2
  src/transformers/modeling_tf_utils.py                   +3  -2
  src/transformers/modeling_utils.py                      +3  -2
  src/transformers/models/bark/modeling_bark.py           +0  -14
  src/transformers/models/speecht5/modeling_speecht5.py   +0  -7
src/transformers/modeling_flax_utils.py

@@ -475,8 +475,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         Returns whether this model can generate sequences with `.generate()`. Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
             return False
         return True
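The check above relies on `str()` of a function exposing its `__qualname__`, which names the class that defined it: an inherited method prints as `<function GenerationMixin.generate at 0x...>`, while an overridden one prints the subclass's name instead. The stand-alone sketch below (stub classes, not the real transformers mixin) illustrates how the updated two-part condition only reports "cannot generate" when neither `prepare_inputs_for_generation` nor `generate` is overridden:

# Stand-alone sketch of the detection logic; `GenerationMixin` here is a stub,
# not the real transformers mixin.
class GenerationMixin:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError

    def generate(self, *args, **kwargs):
        raise NotImplementedError


class PlainModel(GenerationMixin):
    # Overrides neither hook -> cannot generate.
    pass


class CustomGenerateModel(GenerationMixin):
    # Custom `generate`, the case this commit fixes (e.g. Bark, SpeechT5).
    def generate(self, *args, **kwargs):
        return "audio"


def can_generate(cls) -> bool:
    # Mirrors the updated check: str(cls.method) contains the defining class's
    # qualified name, e.g. "<function GenerationMixin.generate at 0x...>".
    if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
        return False
    return True


print(can_generate(PlainModel))           # False
print(can_generate(CustomGenerateModel))  # True (would have been False before this commit)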
src/transformers/modeling_tf_utils.py

@@ -1307,8 +1307,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
             return False
         return True
src/transformers/modeling_utils.py

@@ -1216,8 +1216,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
             return False
         return True
src/transformers/models/bark/modeling_bark.py

@@ -1231,13 +1231,6 @@ class BarkFineModel(BarkPreTrainedModel):
             attentions=all_self_attentions,
         )
 
-    def can_generate(self) -> bool:
-        """
-        Returns True. Despite being an autoencoder, BarkFineModel shares some characteristics with generative models
-        due to the way audio are generated.
-        """
-        return True
-
     def generate(
         self,
         coarse_output: torch.Tensor,

@@ -1594,10 +1587,3 @@ class BarkModel(BarkPreTrainedModel):
             self.codec_model_hook.offload()
 
         return audio
-
-    def can_generate(self) -> bool:
-        """
-        Returns True. Despite not having a `self.generate` method, this model can `generate` and thus needs a
-        BarkGenerationConfig.
-        """
-        return True
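With the base-class check now also inspecting `generate`, the overrides deleted above become redundant: `BarkModel` and `BarkFineModel` define their own `generate()`, so the inherited `can_generate()` already reports True. A quick check, assuming a transformers build that contains this commit (with torch installed) and the classmethod form shown in the `modeling_utils.py` hunk:

# Hedged example: requires transformers (with this commit) and torch installed.
from transformers import BarkFineModel, BarkModel

# `can_generate` is a classmethod, so no weights need to be loaded.
print(BarkModel.can_generate())      # True, via the inherited check on `generate`
print(BarkFineModel.can_generate())  # True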
src/transformers/models/speecht5/modeling_speecht5.py

@@ -2779,13 +2779,6 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
             encoder_attentions=outputs.encoder_attentions,
         )
 
-    def can_generate(self) -> bool:
-        """
-        Returns True. This model can `generate` and must therefore have this property set to True in order to be used
-        in the TTS pipeline.
-        """
-        return True
-
     @torch.no_grad()
     def generate(
         self,
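The removed SpeechT5 docstring notes that `can_generate()` must report True for the model to be usable for text-to-speech. Since `SpeechT5ForTextToSpeech` keeps its custom `generate()` (shown just below the deleted block), the inherited check now covers it. The snippet below is an illustrative gating pattern only, not the pipeline's actual code:

# Illustrative only: gate on `can_generate()` before dispatching to `.generate()`.
# Assumes transformers with this commit and torch installed.
from transformers import SpeechT5ForTextToSpeech

assert SpeechT5ForTextToSpeech.can_generate()  # True: the class defines `generate`

def generate_or_raise(model, **generate_kwargs):
    # `can_generate` is a classmethod, so it also works on instances.
    if not model.can_generate():
        raise TypeError(f"{type(model).__name__} has no usable `generate()`")
    return model.generate(**generate_kwargs)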