Unverified Commit 6134b9b4 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Make `can_generate` as class method (#24299)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent e45bc143
......@@ -468,13 +468,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# the state dict is unflattened to the match the format of model.params
return unflatten_dict(state_sharded_dict, sep="/")
def can_generate(self) -> bool:
@classmethod
def can_generate(cls) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`. Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__):
if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
return False
return True
......
......@@ -1328,7 +1328,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
pass # Layers may not have the same dimensions
return output
def can_generate(self) -> bool:
@classmethod
def can_generate(cls) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
......@@ -1336,7 +1337,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__):
if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
return False
return True
......
......@@ -1174,7 +1174,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
return getattr(self, self.base_model_prefix, self)
def can_generate(self) -> bool:
@classmethod
def can_generate(cls) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
......@@ -1182,7 +1183,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__):
if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
return False
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment