"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cb966e640b8b9d0f6e9c06c1655d078a917e5196"
Unverified Commit 6134b9b4 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Make `can_generate` as class method (#24299)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent e45bc143
...@@ -468,13 +468,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -468,13 +468,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# the state dict is unflattened to the match the format of model.params # the state dict is unflattened to the match the format of model.params
return unflatten_dict(state_sharded_dict, sep="/") return unflatten_dict(state_sharded_dict, sep="/")
def can_generate(self) -> bool: @classmethod
def can_generate(cls) -> bool:
""" """
Returns whether this model can generate sequences with `.generate()`. Returns: Returns whether this model can generate sequences with `.generate()`. Returns:
`bool`: Whether this model can generate sequences with `.generate()`. `bool`: Whether this model can generate sequences with `.generate()`.
""" """
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__): if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
return False return False
return True return True
......
...@@ -1328,7 +1328,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1328,7 +1328,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
pass # Layers may not have the same dimensions pass # Layers may not have the same dimensions
return output return output
def can_generate(self) -> bool: @classmethod
def can_generate(cls) -> bool:
""" """
Returns whether this model can generate sequences with `.generate()`. Returns whether this model can generate sequences with `.generate()`.
...@@ -1336,7 +1337,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1336,7 +1337,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
`bool`: Whether this model can generate sequences with `.generate()`. `bool`: Whether this model can generate sequences with `.generate()`.
""" """
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__): if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
return False return False
return True return True
......
...@@ -1174,7 +1174,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1174,7 +1174,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
""" """
return getattr(self, self.base_model_prefix, self) return getattr(self, self.base_model_prefix, self)
def can_generate(self) -> bool: @classmethod
def can_generate(cls) -> bool:
""" """
Returns whether this model can generate sequences with `.generate()`. Returns whether this model can generate sequences with `.generate()`.
...@@ -1182,7 +1183,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1182,7 +1183,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
`bool`: Whether this model can generate sequences with `.generate()`. `bool`: Whether this model can generate sequences with `.generate()`.
""" """
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__): if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
return False return False
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment