Unverified Commit 130e1542 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: faster `can_generate` check on TF and Flax (#23398)

parent 2922e394
......@@ -474,7 +474,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__):
return False
return True
......
......@@ -1243,7 +1243,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
if "GenerationMixin" in str(self.prepare_inputs_for_generation.__func__):
return False
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment