Unverified Commit 811c4c9f authored by Shu Takayama's avatar Shu Takayama Committed by GitHub
Browse files

fix bug: register_for_auto_class should be defined on TFPreTrainedModel...

fix bug: register_for_auto_class should be defined on TFPreTrainedModel instead of TFSequenceSummary (#18607)
parent ee407024
......@@ -2541,6 +2541,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token
)
@classmethod
def register_for_auto_class(cls, auto_class="TFAutoModel"):
"""
Register this class with a given auto class. This should only be used for custom models as the ones in the
library are already mapped with an auto class.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
The auto class to register this new model with.
"""
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
import transformers.models.auto as auto_module
if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} is not a valid auto class.")
cls._auto_class = auto_class
class TFConv1D(tf.keras.layers.Layer):
"""
......@@ -2795,32 +2821,6 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
@classmethod
def register_for_auto_class(cls, auto_class="TFAutoModel"):
"""
Register this class with a given auto class. This should only be used for custom models as the ones in the
library are already mapped with an auto class.
<Tip warning={true}>
This API is experimental and may have some slight breaking changes in the next releases.
</Tip>
Args:
auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
The auto class to register this new model with.
"""
if not isinstance(auto_class, str):
auto_class = auto_class.__name__
import transformers.models.auto as auto_module
if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} is not a valid auto class.")
cls._auto_class = auto_class
def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment