Unverified Commit 29f04002 authored by Sylvain Gugger, committed by GitHub

Deal with nested configs better in base class (#25237)



* Deal better with nested configs

* Fixes

* More fixes

* Fix last test

* Clean up existing configs

* Remove hack in MPT Config

* Update src/transformers/configuration_utils.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Fix setting a nested config via dict in the kwargs

* Adapt common test

* Add test for nested config load with dict

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent aeb5a08a
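The gist of the change, as a rough sketch rather than the exact implementation (CLIP is used here only as an example of a composite config; the behaviour is what the commit message and the new test describe): the base `PretrainedConfig` now serializes and deserializes nested sub-configs itself, so the per-model `to_dict` overrides removed in the diff below are no longer needed, and a nested config can be supplied as a plain dict.

```python
from transformers import CLIPConfig

# A nested sub-config passed as a plain dict is converted to the proper
# config class (CLIPTextConfig here) by the config machinery.
config = CLIPConfig(text_config={"num_hidden_layers": 2})
print(type(config.text_config).__name__)  # CLIPTextConfig

# Serialization no longer relies on a model-specific to_dict override:
# nested configs come back out as dictionaries.
d = config.to_dict()
assert isinstance(d["text_config"], dict)
assert d["text_config"]["num_hidden_layers"] == 2
```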
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" InstructBLIP model configuration""" """ InstructBLIP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -305,7 +304,6 @@ class InstructBlipConfig(PretrainedConfig): ...@@ -305,7 +304,6 @@ class InstructBlipConfig(PretrainedConfig):
```""" ```"""
model_type = "instructblip" model_type = "instructblip"
is_composition = True
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -358,17 +356,3 @@ class InstructBlipConfig(PretrainedConfig): ...@@ -358,17 +356,3 @@ class InstructBlipConfig(PretrainedConfig):
text_config=text_config.to_dict(), text_config=text_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["qformer_config"] = self.qformer_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
@@ -14,7 +14,6 @@
# limitations under the License.
""" Jukebox configuration"""
-import copy
import os
from typing import List, Union
@@ -369,18 +368,6 @@ class JukeboxPriorConfig(PretrainedConfig):
        return cls.from_dict(config_dict, **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["encoder_config"] = self.encoder_config.to_dict() if self.encoder_config is not None else None
-        output["model_type"] = self.__class__.model_type
-        return output


class JukeboxVQVAEConfig(PretrainedConfig):
    """
@@ -561,7 +548,6 @@ class JukeboxConfig(PretrainedConfig):
    """

    model_type = "jukebox"
-    is_composition = True

    def __init__(
        self,
@@ -620,18 +606,3 @@ class JukeboxConfig(PretrainedConfig):
        """
        prior_config_list = [config.to_dict() for config in prior_configs]
        return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        for i, config in enumerate(output.pop("prior_configs")):
-            output[f"prior_{i}"] = config.to_dict()
-        output["vqvae_config"] = self.vqvae_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mask2Former model configuration"""
-import copy
from typing import Dict, List, Optional

from ...configuration_utils import PretrainedConfig
@@ -230,15 +229,3 @@ class Mask2FormerConfig(PretrainedConfig):
            backbone_config=backbone_config,
            **kwargs,
        )
-
-    def to_dict(self) -> Dict[str, any]:
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["backbone_config"] = self.backbone_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MaskFormer model configuration"""
-import copy
from typing import Dict, Optional

from ...configuration_utils import PretrainedConfig
@@ -200,16 +199,3 @@ class MaskFormerConfig(PretrainedConfig):
            decoder_config=decoder_config,
            **kwargs,
        )
-
-    def to_dict(self) -> Dict[str, any]:
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["backbone_config"] = self.backbone_config.to_dict()
-        output["decoder_config"] = self.decoder_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mpt configuration"""
-import copy
from typing import TYPE_CHECKING, Optional, Union
@@ -197,7 +196,6 @@ class MptConfig(PretrainedConfig):
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
    }
-    is_composition = True

    def __init__(
        self,
@@ -222,6 +220,11 @@ class MptConfig(PretrainedConfig):
        initializer_range=0.02,
        **kwargs,
    ):
-        self.attn_config = attn_config
+        if attn_config is None:
+            self.attn_config = MptAttentionConfig()
+        elif isinstance(attn_config, dict):
+            self.attn_config = MptAttentionConfig(**attn_config)
+        else:
+            self.attn_config = attn_config
        self.d_model = d_model
        self.n_heads = n_heads
@@ -242,35 +245,3 @@ class MptConfig(PretrainedConfig):
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        super().__init__(**kwargs)
-
-    @property
-    def attn_config(self):
-        return self._attn_config
-
-    @attn_config.setter
-    def attn_config(self, attn_config):
-        if attn_config is None:
-            self._attn_config = MptAttentionConfig()
-        elif isinstance(attn_config, dict):
-            self._attn_config = MptAttentionConfig(**attn_config)
-        elif isinstance(attn_config, MptAttentionConfig):
-            self._attn_config = attn_config
-        else:
-            raise ValueError(
-                f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
-            )
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["attn_config"] = (
-            self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
-        )
-        del output["_attn_config"]
-        output["model_type"] = self.__class__.model_type
-        return output
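To illustrate the MPT change above (a sketch only; the `attn_pdrop` value is arbitrary and the import path for `MptAttentionConfig` is the usual module layout): `attn_config` is now normalized directly in `__init__`, so there is no private `_attn_config` attribute and no custom `to_dict` needed to hide it.

```python
from transformers import MptConfig
from transformers.models.mpt.configuration_mpt import MptAttentionConfig

# A dict is converted to an MptAttentionConfig right in __init__ ...
config = MptConfig(attn_config={"attn_pdrop": 0.1})
assert isinstance(config.attn_config, MptAttentionConfig)

# ... so the serialized form keeps a plain "attn_config" entry and there is
# no "_attn_config" key that a custom to_dict would have to clean up.
serialized = config.to_dict()
assert "attn_config" in serialized and "_attn_config" not in serialized
```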
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MusicGen model configuration"""
-import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -227,17 +226,3 @@ class MusicgenConfig(PretrainedConfig):
            decoder=decoder_config.to_dict(),
            **kwargs,
        )
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["text_encoder"] = self.text_encoder.to_dict()
-        output["audio_encoder"] = self.audio_encoder.to_dict()
-        output["decoder"] = self.decoder.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""OneFormer model configuration"""
-import copy
from typing import Dict, Optional

from ...configuration_utils import PretrainedConfig
@@ -250,13 +249,3 @@ class OneFormerConfig(PretrainedConfig):
        self.num_hidden_layers = decoder_layers
        super().__init__(**kwargs)
-
-    def to_dict(self) -> Dict[str, any]:
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["backbone_config"] = self.backbone_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,7 +14,6 @@
# limitations under the License.
""" OWL-ViT model configuration"""
-import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
@@ -274,7 +273,6 @@ class OwlViTConfig(PretrainedConfig):
    """

    model_type = "owlvit"
-    is_composition = True

    def __init__(
        self,
@@ -332,19 +330,6 @@ class OwlViTConfig(PretrainedConfig):
        return cls.from_dict(config_dict, **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["text_config"] = self.text_config.to_dict()
-        output["vision_config"] = self.vision_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output


class OwlViTOnnxConfig(OnnxConfig):
    @property
@@ -14,7 +14,6 @@
# limitations under the License.
""" Pix2Struct model configuration"""
-import copy
import os
from typing import Union
@@ -338,7 +337,6 @@ class Pix2StructConfig(PretrainedConfig):
    ```"""

    model_type = "pix2struct"
-    is_composition = True

    def __init__(
        self,
@@ -389,16 +387,3 @@ class Pix2StructConfig(PretrainedConfig):
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["text_config"] = self.text_config.to_dict()
-        output["vision_config"] = self.vision_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,7 +14,6 @@
# limitations under the License.
""" RAG model configuration"""
-import copy

from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings
@@ -179,16 +178,3 @@ class RagConfig(PretrainedConfig):
            [`EncoderDecoderConfig`]: An instance of a configuration object
        """
        return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["question_encoder"] = self.question_encoder.to_dict()
-        output["generator"] = self.generator.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,7 +14,6 @@
# limitations under the License.
""" SAM model configuration"""
-import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -286,7 +285,6 @@ class SamConfig(PretrainedConfig):
    ```"""

    model_type = "sam"
-    is_composition = True

    def __init__(
        self,
@@ -312,17 +310,3 @@ class SamConfig(PretrainedConfig):
        self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
        self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
        self.initializer_range = initializer_range
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["vision_config"] = self.vision_config.to_dict()
-        output["prompt_encoder_config"] = self.prompt_encoder_config.to_dict()
-        output["mask_decoder_config"] = self.mask_decoder_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -106,16 +105,3 @@ class SpeechEncoderDecoderConfig(PretrainedConfig):
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["encoder"] = self.encoder.to_dict()
-        output["decoder"] = self.decoder.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Table Transformer model configuration"""
-import copy
from collections import OrderedDict
-from typing import Dict, Mapping
+from typing import Mapping

from packaging import version
@@ -237,17 +236,6 @@ class TableTransformerConfig(PretrainedConfig):
    def hidden_size(self) -> int:
        return self.d_model
-
-    def to_dict(self) -> Dict[str, any]:
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        if output["backbone_config"] is not None:
-            output["backbone_config"] = self.backbone_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output


# Copied from transformers.models.detr.configuration_detr.DetrOnnxConfig
class TableTransformerOnnxConfig(OnnxConfig):
@@ -14,7 +14,6 @@
# limitations under the License.
""" UperNet model configuration"""
-import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -108,13 +107,3 @@ class UperNetConfig(PretrainedConfig):
        self.auxiliary_num_convs = auxiliary_num_convs
        self.auxiliary_concat_input = auxiliary_concat_input
        self.loss_ignore_index = loss_ignore_index
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["backbone_config"] = self.backbone_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict

from packaging import version
@@ -114,19 +113,6 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["encoder"] = self.encoder.to_dict()
-        output["decoder"] = self.decoder.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output


class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
    torch_onnx_minimum_version = version.parse("1.11")
@@ -14,7 +14,6 @@
# limitations under the License.
""" VisionTextDualEncoder model configuration"""
-import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -113,16 +112,3 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
        """
        return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["vision_config"] = self.vision_config.to_dict()
-        output["text_config"] = self.text_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,8 +14,6 @@
# limitations under the License.
""" ViT Hybrid model configuration"""
-import copy
-from typing import Dict

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -146,13 +144,3 @@ class ViTHybridConfig(PretrainedConfig):
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
-
-    def to_dict(self) -> Dict[str, any]:
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["backbone_config"] = self.backbone_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -14,7 +14,6 @@
# limitations under the License.
""" X-CLIP model configuration"""
-import copy
import os
from typing import Union
@@ -299,7 +298,6 @@ class XCLIPConfig(PretrainedConfig):
    """

    model_type = "xclip"
-    is_composition = True

    def __init__(
        self,
@@ -417,16 +415,3 @@ class XCLIPConfig(PretrainedConfig):
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["text_config"] = self.text_config.to_dict()
-        output["vision_config"] = self.vision_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        return output
@@ -118,7 +118,9 @@ class ConfigTester(object):
    def check_config_can_be_init_without_params(self):
        if self.config_class.is_composition:
-            return
-        config = self.config_class()
-        self.parent.assertIsNotNone(config)
+            with self.parent.assertRaises(ValueError):
+                config = self.config_class()
+        else:
+            config = self.config_class()
+            self.parent.assertIsNotNone(config)
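For context on the updated test above (a hedged reading of the diff, assuming `MusicgenConfig` keeps `is_composition = True`, which this commit does not remove): config classes that remain composition configs require their sub-configs, so instantiating them without arguments is now expected to raise a `ValueError`, for example:

```python
from transformers import MusicgenConfig

# MusicgenConfig needs its text_encoder, audio_encoder and decoder
# sub-configs, so a bare call is expected to fail with a ValueError.
try:
    MusicgenConfig()
except ValueError as err:
    print(f"Raised as expected: {err}")
```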
@@ -210,6 +210,13 @@ class ConfigTestUtils(unittest.TestCase):
                f" {', '.join(keys_with_defaults)}."
            )

+    def test_nested_config_load_from_dict(self):
+        config = AutoConfig.from_pretrained(
+            "hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
+        )
+        self.assertNotIsInstance(config.text_config, dict)
+        self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")
+
    def test_from_pretrained_subfolder(self):
        with self.assertRaises(OSError):
            # config is in subfolder, the following should not work without specifying the subfolder
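Finally, a usage sketch of what the new test exercises (the tiny checkpoint name is taken from the test; network access to the Hub is assumed): a nested config overridden through a dict kwarg comes back as a proper config object and survives a save/reload round trip, since the base class now handles the nesting.

```python
import tempfile

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
)
# The dict kwarg was turned into a real CLIPTextConfig, not left as a dict.
print(type(config.text_config).__name__)  # CLIPTextConfig

# The override also survives serialization and reloading.
with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)
    reloaded = AutoConfig.from_pretrained(tmp_dir)

assert reloaded.text_config.num_hidden_layers == 2
```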