"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7ac3311e487541b6b8a6a43a39c23ea343da3545"
Commit 4f338ed4 authored by Gunnlaugur Thor Briem

Explicit config_class instead of module inspection

parent 6fe1cc08
@@ -101,20 +101,6 @@ class AutoConfig:
             "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
         )
 
-    @classmethod
-    def config_class_for_model_class(cls, model_class):
-        module = import_module(model_class.__module__)
-        return next(
-            (
-                module_attribute
-                for module_attribute_name in dir(module)
-                if module_attribute_name.endswith("Config")
-                for module_attribute in (getattr(module, module_attribute_name),)
-                if issubclass(module_attribute, PretrainedConfig)
-            ),
-            None,
-        )
-
     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
         for pattern, config_class in CONFIG_MAPPING.items():
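The removed helper guessed a model's config class by scanning the model's defining module for the first attribute whose name ends in "Config" and that subclasses PretrainedConfig. A minimal sketch of why that inspection is fragile (hypothetical namespaces, not code from this commit): the answer depends on which *Config classes happen to be visible in the module, not on the model class itself.

# Hypothetical sketch of the removed lookup, using plain dicts as stand-in
# module namespaces. dir() returns names in sorted order, so an unrelated
# import whose name sorts earlier can win.
from transformers import AlbertConfig, BertConfig, PretrainedConfig


def first_config_in_namespace(namespace):
    # Mirrors the removed logic: first "*Config" attribute that is a
    # PretrainedConfig subclass, or None if there is none.
    return next(
        (
            attr
            for name in sorted(namespace)  # mimics dir(module)
            if name.endswith("Config")
            for attr in (namespace[name],)
            if isinstance(attr, type) and issubclass(attr, PretrainedConfig)
        ),
        None,
    )


# A BERT module that also imports AlbertConfig would resolve to the wrong
# class, because "AlbertConfig" sorts before "BertConfig".
print(first_config_in_namespace({"BertConfig": BertConfig}))
print(first_config_in_namespace({"BertConfig": BertConfig, "AlbertConfig": AlbertConfig}))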
@@ -480,6 +480,8 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFAlbertMainLayer(tf.keras.layers.Layer):
+    config_class = AlbertConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
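Each TF main layer now names its config class directly (the same two-line addition repeats for BERT, CTRL, GPT-2, Transfo-XL, and XLNet below), so the keras_serializable wrapper can rebuild a typed config from the plain dict that Keras passes back during deserialization. A hedged usage sketch, assuming the flat module layout of transformers at this commit:

from transformers import AlbertConfig
from transformers.modeling_tf_albert import TFAlbertMainLayer

config = AlbertConfig()
layer = TFAlbertMainLayer(config)

# Keras serializes layer configs to plain dicts; the wrapped __init__
# converts a dict argument back through the explicit config_class.
restored = TFAlbertMainLayer(config.to_dict())
assert isinstance(restored._transformers_config, AlbertConfig)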
@@ -473,6 +473,8 @@ class TFBertNSPHead(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFBertMainLayer(tf.keras.layers.Layer):
+    config_class = BertConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
@@ -166,6 +166,8 @@ class TFEncoderLayer(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFCTRLMainLayer(tf.keras.layers.Layer):
+    config_class = CTRLConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
@@ -199,6 +199,8 @@ class TFBlock(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFGPT2MainLayer(tf.keras.layers.Layer):
+    config_class = GPT2Config
+
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
@@ -380,6 +380,8 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFTransfoXLMainLayer(tf.keras.layers.Layer):
+    config_class = TransfoXLConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
@@ -50,11 +50,13 @@ class TFModelUtilsMixin:
 def keras_serializable(cls):
     initializer = cls.__init__
+    config_class = getattr(cls, "config_class", None)
+    if config_class is None:
+        raise AttributeError("Must set `config_class` to use @keras_serializable")
+
     def wrapped_init(self, config, *args, **kwargs):
         if isinstance(config, dict):
-            from transformers import AutoConfig
-
-            config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
+            config = config_class.from_dict(config)
         initializer(self, config, *args, **kwargs)
         self._transformers_config = config
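Putting the pieces together, a simplified sketch of the decorator after this change (illustrative only, not the exact library code; the real implementation also wires up get_config so the layer round-trips through tf.keras serialization):

def keras_serializable(cls):
    """Simplified: let a Keras layer's __init__ accept a plain config dict."""
    initializer = cls.__init__

    # Resolve the config class once, at decoration time, and fail fast
    # instead of guessing via module inspection at deserialization time.
    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    def wrapped_init(self, config, *args, **kwargs):
        # Keras hands back a dict when reloading a saved model; rebuild
        # the typed config object before running the original __init__.
        if isinstance(config, dict):
            config = config_class.from_dict(config)
        initializer(self, config, *args, **kwargs)
        self._transformers_config = config

    cls.__init__ = wrapped_init
    return cls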
@@ -351,6 +351,8 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFXLNetMainLayer(tf.keras.layers.Layer):
+    config_class = XLNetConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions