Unverified Commit 07998ef3 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: models with custom `generate()` return `True` in `can_generate()` (#25838)

parent 8c75cfda
...@@ -475,8 +475,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -475,8 +475,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
Returns whether this model can generate sequences with `.generate()`. Returns: Returns whether this model can generate sequences with `.generate()`. Returns:
`bool`: Whether this model can generate sequences with `.generate()`. `bool`: Whether this model can generate sequences with `.generate()`.
""" """
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation): # Alternatively, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False return False
return True return True
......
...@@ -1307,8 +1307,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1307,8 +1307,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Returns: Returns:
`bool`: Whether this model can generate sequences with `.generate()`. `bool`: Whether this model can generate sequences with `.generate()`.
""" """
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation): # Alternatively, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False return False
return True return True
......
...@@ -1216,8 +1216,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1216,8 +1216,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Returns: Returns:
`bool`: Whether this model can generate sequences with `.generate()`. `bool`: Whether this model can generate sequences with `.generate()`.
""" """
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation): # Alternatively, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False return False
return True return True
......
...@@ -1231,13 +1231,6 @@ class BarkFineModel(BarkPreTrainedModel): ...@@ -1231,13 +1231,6 @@ class BarkFineModel(BarkPreTrainedModel):
attentions=all_self_attentions, attentions=all_self_attentions,
) )
def can_generate(self) -> bool:
"""
Returns True. Despite being an autoencoder, BarkFineModel shares some characteristics with generative models
due to the way audio is generated.
"""
return True
def generate( def generate(
self, self,
coarse_output: torch.Tensor, coarse_output: torch.Tensor,
...@@ -1594,10 +1587,3 @@ class BarkModel(BarkPreTrainedModel): ...@@ -1594,10 +1587,3 @@ class BarkModel(BarkPreTrainedModel):
self.codec_model_hook.offload() self.codec_model_hook.offload()
return audio return audio
def can_generate(self) -> bool:
"""
Returns True. Despite not having a `self.generate` method, this model can `generate` and thus needs a
BarkGenerationConfig.
"""
return True
...@@ -2779,13 +2779,6 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): ...@@ -2779,13 +2779,6 @@ class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
encoder_attentions=outputs.encoder_attentions, encoder_attentions=outputs.encoder_attentions,
) )
def can_generate(self) -> bool:
"""
Returns True. This model can `generate` and must therefore have this property set to True in order to be used
in the TTS pipeline.
"""
return True
@torch.no_grad() @torch.no_grad()
def generate( def generate(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment