"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fa01127a677103f359597656c9d995d92b517f71"
Unverified commit 9cea3e7b authored by Arthur, committed by GitHub

[`MptConfig`] support from pretrained args (#25116)



* support from pretrained args

* draft addition of tests

* update test

* use parent assertTrue

* Update src/transformers/models/mpt/configuration_mpt.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent a1c4954d
src/transformers/models/mpt/configuration_mpt.py

@@ -101,6 +101,23 @@ class MptAttentionConfig(PretrainedConfig):
                 f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
             )
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        if config_dict.get("model_type") == "mpt":
+            config_dict = config_dict["attn_config"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
 
 class MptConfig(PretrainedConfig):
     """
@@ -180,6 +197,7 @@ class MptConfig(PretrainedConfig):
         "hidden_size": "d_model",
         "num_hidden_layers": "n_layers",
     }
+    is_composition = True
 
     def __init__(
         self,
@@ -204,6 +222,7 @@ class MptConfig(PretrainedConfig):
         initializer_range=0.02,
         **kwargs,
     ):
+        self.attn_config = attn_config
         self.d_model = d_model
         self.n_heads = n_heads
         self.n_layers = n_layers
@@ -222,20 +241,25 @@ class MptConfig(PretrainedConfig):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_cache = use_cache
         self.initializer_range = initializer_range
+        super().__init__(**kwargs)
+
+    @property
+    def attn_config(self):
+        return self._attn_config
+
+    @attn_config.setter
+    def attn_config(self, attn_config):
         if attn_config is None:
-            self.attn_config = MptAttentionConfig()
+            self._attn_config = MptAttentionConfig()
         elif isinstance(attn_config, dict):
-            self.attn_config = MptAttentionConfig(**attn_config)
+            self._attn_config = MptAttentionConfig(**attn_config)
         elif isinstance(attn_config, MptAttentionConfig):
-            self.attn_config = attn_config
+            self._attn_config = attn_config
         else:
             raise ValueError(
                 f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
             )
-        super().__init__(**kwargs)
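Routing `attn_config` through a property setter means every assignment, including the new `self.attn_config = attn_config` at the top of `__init__`, is validated and normalized in one place. A short sketch of the three accepted forms (field values here are illustrative):

```python
from transformers.models.mpt.configuration_mpt import MptAttentionConfig, MptConfig

config_a = MptConfig()                                   # None -> default MptAttentionConfig
config_b = MptConfig(attn_config={"attn_impl": "flash"})  # dict -> MptAttentionConfig(**dict)
config_c = MptConfig(attn_config=MptAttentionConfig())    # instance -> stored as-is

# In every case the public attribute comes back as a config object.
assert isinstance(config_b.attn_config, MptAttentionConfig)
assert config_b.attn_config.attn_impl == "flash"
```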
     def to_dict(self):
         """
         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
@@ -245,7 +269,8 @@ class MptConfig(PretrainedConfig):
         """
         output = copy.deepcopy(self.__dict__)
         output["attn_config"] = (
-            self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
+            self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
         )
+        del output["_attn_config"]
         output["model_type"] = self.__class__.model_type
         return output
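With `to_dict` reading from `_attn_config` and deleting the private key, serialization should round-trip cleanly through the setter. A quick sketch under the same assumptions as above:

```python
config = MptConfig(attn_config={"attn_impl": "flash"})
serialized = config.to_dict()

# The private backing attribute is stripped; the nested config is a plain dict.
assert "_attn_config" not in serialized
assert isinstance(serialized["attn_config"], dict)

# The dict form is accepted back by the attn_config setter.
restored = MptConfig.from_dict(serialized)
assert restored.attn_config.attn_impl == "flash"
```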
tests/models/mpt/test_modeling_mpt.py

@@ -327,6 +327,20 @@ class MptModelTester:
         return config, inputs_dict
 
 
+class MptConfigTester(ConfigTester):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
+        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
+
+    def test_attn_config_as_dict(self):
+        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
+        self.parent.assertTrue(config.attn_config.attn_impl == "flash")
+        self.parent.assertTrue(config.attn_config.softmax_scale is None)
+
+    def run_common_tests(self):
+        self.test_attn_config_as_dict()
+        return super().run_common_tests()
+
+
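`run_common_tests` prepends the MPT-specific dict check and then defers to the shared `ConfigTester` suite, so a single call covers both. A hedged sketch of standalone usage, assuming `MptConfigTester` and `MptConfig` are in scope as in the test module:

```python
import unittest


class MptConfigSmokeTest(unittest.TestCase):
    def test_config(self):
        # Runs test_attn_config_as_dict, then the generic ConfigTester
        # checks (dict/json round-trips, common properties, ...).
        MptConfigTester(self, config_class=MptConfig, n_embd=37).run_common_tests()


if __name__ == "__main__":
    unittest.main()
```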
 @require_torch
 class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     def setUp(self):
         self.model_tester = MptModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
+        self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
 
     def test_config(self):
         self.config_tester.run_common_tests()