"docs/vscode:/vscode.git/clone" did not exist on "2c7b26f5083becb429bdae4c919feca28fdf5699"
Unverified Commit 07998ef3 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: models with custom `generate()` return `True` in `can_generate()` (#25838)

parent 8c75cfda
......@@ -475,8 +475,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
Returns whether this model can generate sequences with `.generate()`. Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
# Alternativelly, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False
return True
......
......@@ -1307,8 +1307,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
# Alternativelly, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False
return True
......
......@@ -1216,8 +1216,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
# Alternativelly, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False
return True
......
......@@ -1231,13 +1231,6 @@ class BarkFineModel(BarkPreTrainedModel):
attentions=all_self_attentions,
)
def can_generate(self) -> bool:
"""
Returns True. Despite being an autoencoder, BarkFineModel shares some characteristics with generative models
due to the way audio are generated.
"""
return True
def generate(
self,
coarse_output: torch.Tensor,
......@@ -1594,10 +1587,3 @@ class BarkModel(BarkPreTrainedModel):
self.codec_model_hook.offload()
return audio
def can_generate(self) -> bool:
"""
Returns True. Despite not having a `self.generate` method, this model can `generate` and thus needs a
BarkGenerationConfig.
"""
return True
......@@ -2779,13 +2779,6 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
encoder_attentions=outputs.encoder_attentions,
)
def can_generate(self) -> bool:
"""
Returns True. This model can `generate` and must therefore have this property set to True in order to be used
in the TTS pipeline.
"""
return True
@torch.no_grad()
def generate(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment