"vscode:/vscode.git/clone" did not exist on "3b5962131093ceab09f0540cb99d84c18c45035f"
Unverified Commit 29f04002 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Deal with nested configs better in base class (#25237)



* Deal better with nested configs

* Fixes

* More fixes

* Fix last test

* Clean up existing configs

* Remove hack in MPT Config

* Update src/transformers/configuration_utils.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Fix setting a nested config via dict in the kwargs

* Adapt common test

* Add test for nested config load with dict

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent aeb5a08a
......@@ -14,7 +14,6 @@
# limitations under the License.
""" InstructBLIP model configuration"""
import copy
import os
from typing import Union
......@@ -305,7 +304,6 @@ class InstructBlipConfig(PretrainedConfig):
```"""
model_type = "instructblip"
is_composition = True
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
super().__init__(**kwargs)
......@@ -358,17 +356,3 @@ class InstructBlipConfig(PretrainedConfig):
text_config=text_config.to_dict(),
**kwargs,
)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so that the nested sub-configurations are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    # Replace the nested config objects with their own dict form.
    for key in ("vision_config", "qformer_config", "text_config"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" Jukebox configuration"""
import copy
import os
from typing import List, Union
......@@ -369,18 +368,6 @@ class JukeboxPriorConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs)
def to_dict(self):
    """
    Serialize this prior configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the optional nested `encoder_config` is serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    # `encoder_config` is optional; only recurse when it is actually set.
    if self.encoder_config is None:
        serialized["encoder_config"] = None
    else:
        serialized["encoder_config"] = self.encoder_config.to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
class JukeboxVQVAEConfig(PretrainedConfig):
"""
......@@ -561,7 +548,6 @@ class JukeboxConfig(PretrainedConfig):
"""
model_type = "jukebox"
is_composition = True
def __init__(
self,
......@@ -620,18 +606,3 @@ class JukeboxConfig(PretrainedConfig):
"""
prior_config_list = [config.to_dict() for config in prior_configs]
return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so that the list of prior configs and the VQ-VAE config
    are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    # Flatten the list of prior configs into numbered `prior_{i}` entries.
    priors = serialized.pop("prior_configs")
    serialized.update({f"prior_{idx}": prior.to_dict() for idx, prior in enumerate(priors)})
    serialized["vqvae_config"] = self.vqvae_config.to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mask2Former model configuration"""
import copy
from typing import Dict, List, Optional
from ...configuration_utils import PretrainedConfig
......@@ -230,15 +229,3 @@ class Mask2FormerConfig(PretrainedConfig):
backbone_config=backbone_config,
**kwargs,
)
def to_dict(self) -> Dict[str, any]:
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested `backbone_config` is serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        backbone_config=self.backbone_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MaskFormer model configuration"""
import copy
from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
......@@ -200,16 +199,3 @@ class MaskFormerConfig(PretrainedConfig):
decoder_config=decoder_config,
**kwargs,
)
def to_dict(self) -> Dict[str, any]:
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested backbone and decoder configs are
    serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("backbone_config", "decoder_config"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mpt configuration"""
import copy
from typing import TYPE_CHECKING, Optional, Union
......@@ -197,7 +196,6 @@ class MptConfig(PretrainedConfig):
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
}
is_composition = True
def __init__(
self,
......@@ -222,7 +220,12 @@ class MptConfig(PretrainedConfig):
initializer_range=0.02,
**kwargs,
):
self.attn_config = attn_config
if attn_config is None:
self.attn_config = MptAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = MptAttentionConfig(**attn_config)
else:
self.attn_config = attn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
......@@ -242,35 +245,3 @@ class MptConfig(PretrainedConfig):
self.use_cache = use_cache
self.initializer_range = initializer_range
super().__init__(**kwargs)
@property
def attn_config(self):
    # Expose the private backing field holding the attention sub-config.
    return self._attn_config

@attn_config.setter
def attn_config(self, attn_config):
    # Normalize whatever the caller assigns into an `MptAttentionConfig`:
    # `None` -> defaults, dict -> keyword construction, instance -> stored as-is.
    if attn_config is None:
        self._attn_config = MptAttentionConfig()
        return
    if isinstance(attn_config, dict):
        self._attn_config = MptAttentionConfig(**attn_config)
        return
    if isinstance(attn_config, MptAttentionConfig):
        self._attn_config = attn_config
        return
    raise ValueError(
        f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
    )
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the private `_attn_config` backing field is
    exposed as a plain `attn_config` entry.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    # Drop the private backing field and publish it under the public key.
    del serialized["_attn_config"]
    if isinstance(self.attn_config, dict):
        serialized["attn_config"] = self.attn_config
    else:
        serialized["attn_config"] = self._attn_config.to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MusicGen model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -227,17 +226,3 @@ class MusicgenConfig(PretrainedConfig):
decoder=decoder_config.to_dict(),
**kwargs,
)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested text-encoder, audio-encoder and decoder
    configs are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("text_encoder", "audio_encoder", "decoder"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""OneFormer model configuration"""
import copy
from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
......@@ -250,13 +249,3 @@ class OneFormerConfig(PretrainedConfig):
self.num_hidden_layers = decoder_layers
super().__init__(**kwargs)
def to_dict(self) -> Dict[str, any]:
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested `backbone_config` is serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        backbone_config=self.backbone_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" OWL-ViT model configuration"""
import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
......@@ -274,7 +273,6 @@ class OwlViTConfig(PretrainedConfig):
"""
model_type = "owlvit"
is_composition = True
def __init__(
self,
......@@ -332,19 +330,6 @@ class OwlViTConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested text and vision configs are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("text_config", "vision_config"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
class OwlViTOnnxConfig(OnnxConfig):
@property
......
......@@ -14,7 +14,6 @@
# limitations under the License.
""" Pix2Struct model configuration"""
import copy
import os
from typing import Union
......@@ -338,7 +337,6 @@ class Pix2StructConfig(PretrainedConfig):
```"""
model_type = "pix2struct"
is_composition = True
def __init__(
self,
......@@ -389,16 +387,3 @@ class Pix2StructConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested text and vision configs are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        text_config=self.text_config.to_dict(),
        vision_config=self.vision_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" RAG model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings
......@@ -179,16 +178,3 @@ class RagConfig(PretrainedConfig):
[`EncoderDecoderConfig`]: An instance of a configuration object
"""
return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested question-encoder and generator configs
    are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("question_encoder", "generator"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" SAM model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -286,7 +285,6 @@ class SamConfig(PretrainedConfig):
```"""
model_type = "sam"
is_composition = True
def __init__(
self,
......@@ -312,17 +310,3 @@ class SamConfig(PretrainedConfig):
self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
self.initializer_range = initializer_range
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested vision, prompt-encoder and mask-decoder
    configs are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("vision_config", "prompt_encoder_config", "mask_decoder_config"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -106,16 +105,3 @@ class SpeechEncoderDecoderConfig(PretrainedConfig):
decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    *to_dict()* from *PretrainedConfig* so the nested encoder and decoder configs
    are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("encoder", "decoder"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Table Transformer model configuration"""
import copy
from collections import OrderedDict
from typing import Dict, Mapping
from typing import Mapping
from packaging import version
......@@ -237,17 +236,6 @@ class TableTransformerConfig(PretrainedConfig):
def hidden_size(self) -> int:
return self.d_model
def to_dict(self) -> Dict[str, any]:
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the optional nested `backbone_config` is
    serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    # `backbone_config` may legitimately be `None`; only recurse when it is set.
    if self.backbone_config is not None:
        serialized["backbone_config"] = self.backbone_config.to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
# Copied from transformers.models.detr.configuration_detr.DetrOnnxConfig
class TableTransformerOnnxConfig(OnnxConfig):
......
......@@ -14,7 +14,6 @@
# limitations under the License.
""" UperNet model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -108,13 +107,3 @@ class UperNetConfig(PretrainedConfig):
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
self.loss_ignore_index = loss_ignore_index
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested `backbone_config` is serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        backbone_config=self.backbone_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
......@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
from packaging import version
......@@ -114,19 +113,6 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    *to_dict()* from *PretrainedConfig* so the nested encoder and decoder configs
    are serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        encoder=self.encoder.to_dict(),
        decoder=self.decoder.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
......
......@@ -14,7 +14,6 @@
# limitations under the License.
""" VisionTextDualEncoder model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -113,16 +112,3 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
"""
return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested vision and text configs are
    serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("vision_config", "text_config"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,8 +14,6 @@
# limitations under the License.
""" ViT Hybrid model configuration"""
import copy
from typing import Dict
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -146,13 +144,3 @@ class ViTHybridConfig(PretrainedConfig):
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias
def to_dict(self) -> Dict[str, any]:
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested `backbone_config` is serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        backbone_config=self.backbone_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" X-CLIP model configuration"""
import copy
import os
from typing import Union
......@@ -299,7 +298,6 @@ class XCLIPConfig(PretrainedConfig):
"""
model_type = "xclip"
is_composition = True
def __init__(
self,
......@@ -417,16 +415,3 @@ class XCLIPConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serialize this configuration to a Python dictionary, overriding the default
    [`~PretrainedConfig.to_dict`] so the nested text and vision configs are
    serialized recursively.

    Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    serialized = copy.deepcopy(self.__dict__)
    for key in ("text_config", "vision_config"):
        serialized[key] = getattr(self, key).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -118,9 +118,11 @@ class ConfigTester(object):
def check_config_can_be_init_without_params(self):
    """Check no-argument construction of the config class.

    The scraped diff overlaid the removed and added versions of this method,
    leaving syntactically invalid code; this is the post-change version:
    composite configs (`is_composition == True`) must *refuse* no-argument
    construction with a `ValueError`, while plain configs must build a valid
    default instance.
    """
    if self.config_class.is_composition:
        # Composite configs require their sub-configs to be passed explicitly.
        with self.parent.assertRaises(ValueError):
            config = self.config_class()
    else:
        config = self.config_class()
        self.parent.assertIsNotNone(config)
def check_config_arguments_init(self):
kwargs = copy.deepcopy(config_common_kwargs)
......
......@@ -210,6 +210,13 @@ class ConfigTestUtils(unittest.TestCase):
f" {', '.join(keys_with_defaults)}."
)
def test_nested_config_load_from_dict(self):
    """A nested config passed as a plain dict must be promoted to its config class."""
    overrides = {"num_hidden_layers": 2}
    config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-CLIPModel", text_config=overrides)
    nested = config.text_config
    # The dict must have been converted into a proper `CLIPTextConfig`.
    self.assertNotIsInstance(nested, dict)
    self.assertEqual(nested.__class__.__name__, "CLIPTextConfig")
def test_from_pretrained_subfolder(self):
with self.assertRaises(OSError):
# config is in subfolder, the following should not work without specifying the subfolder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment