Unverified commit 29f04002, authored by Sylvain Gugger, committed by GitHub
Browse files

Deal with nested configs better in base class (#25237)



* Deal better with nested configs

* Fixes

* More fixes

* Fix last test

* Clean up existing configs

* Remove hack in MPT Config

* Update src/transformers/configuration_utils.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Fix setting a nested config via dict in the kwargs

* Adapt common test

* Add test for nested config load with dict

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent aeb5a08a
......@@ -762,6 +762,10 @@ class PretrainedConfig(PushToHubMixin):
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
current_attr = getattr(config, key)
# To authorize passing a custom subconfig as kwarg in models that have nested configs.
if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
value = current_attr.__class__(**value)
setattr(config, key, value)
if key != "torch_dtype":
to_remove.append(key)
......@@ -823,6 +827,18 @@ class PretrainedConfig(PushToHubMixin):
# only serialize values that differ from the default config
for key, value in config_dict.items():
if (
isinstance(getattr(self, key, None), PretrainedConfig)
and key in class_config_dict
and isinstance(class_config_dict[key], dict)
):
# For nested configs we need to clean the diff recursively
diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
if "model_type" in value:
# Needs to be set even if it's not in the diff
diff["model_type"] = value["model_type"]
if len(diff) > 0:
serializable_config_dict[key] = diff
elif (
key not in default_config_dict
or key == "transformers_version"
or value != default_config_dict[key]
......@@ -859,6 +875,14 @@ class PretrainedConfig(PushToHubMixin):
# Transformers version when serializing the model
output["transformers_version"] = __version__
for key, value in output.items():
# Deal with nested configs like CLIP
if isinstance(value, PretrainedConfig):
value = value.to_dict()
del value["transformers_version"]
output[key] = value
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
......@@ -1020,6 +1044,24 @@ def get_configuration_file(configuration_files: List[str]) -> str:
return configuration_file
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
    values from `dict_a` that are different from values in `dict_b`.

    Args:
        dict_a (`dict`):
            Dictionary whose values are reported in the diff.
        dict_b (`dict`):
            Reference dictionary `dict_a` is compared against.
        config_obj (*optional*):
            Object `dict_a` was serialized from. When given, a value is also kept if it is missing from, or differs
            from, the `to_dict()` output of a fresh `config_obj.__class__()`, and attributes of `config_obj` that are
            `PretrainedConfig` instances are diffed recursively. When `None`, only `dict_b` is used as reference.

    Returns:
        `dict`: The entries of `dict_a` that differ from the reference(s). Nested diffs that come out empty are
        dropped.
    """
    diff = {}
    # `None` (rather than an empty dict) marks "no class defaults available": an empty dict would make every key
    # look like it is missing from the defaults and the function would return all of `dict_a`.
    default = config_obj.__class__().to_dict() if config_obj is not None else None
    for key, value in dict_a.items():
        nested_config = getattr(config_obj, str(key), None) if config_obj is not None else None
        if (
            nested_config is not None
            and isinstance(nested_config, PretrainedConfig)
            and key in dict_b
            and isinstance(dict_b[key], dict)
        ):
            # Nested config: recurse, and keep the entry only if something actually differs.
            nested_diff = recursive_diff_dict(value, dict_b[key], config_obj=nested_config)
            if nested_diff:
                diff[key] = nested_diff
        elif (
            key not in dict_b
            or value != dict_b[key]
            or (default is not None and (key not in default or value != default[key]))
        ):
            diff[key] = value
    return diff
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
if PretrainedConfig.push_to_hub.__doc__ is not None:
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
......
......@@ -14,7 +14,6 @@
# limitations under the License.
""" ALIGN model configuration"""
import copy
import os
from typing import TYPE_CHECKING, List, Union
......@@ -344,7 +343,6 @@ class AlignConfig(PretrainedConfig):
```"""
model_type = "align"
is_composition = True
def __init__(
self,
......@@ -383,16 +381,3 @@ class AlignConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `text_config` and `vision_config` are
        replaced by their own `to_dict()` output and `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    # Swap each nested config object for its plain-dict form.
    for sub_name in ("text_config", "vision_config"):
        serialized[sub_name] = getattr(self, sub_name).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" AltCLIP model configuration"""
import copy
import os
from typing import Union
......@@ -291,7 +290,6 @@ class AltCLIPConfig(PretrainedConfig):
```"""
model_type = "altclip"
is_composition = True
def __init__(
self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs
......@@ -392,16 +390,3 @@ class AltCLIPConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes, with the `text_config` and `vision_config` entries
        replaced by the dictionaries their `to_dict()` methods return, plus the class-level `model_type`.
    """
    result = copy.deepcopy(self.__dict__)
    result.update(
        text_config=self.text_config.to_dict(),
        vision_config=self.vision_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return result
......@@ -14,7 +14,6 @@
# limitations under the License.
""" BARK model configuration"""
import copy
import os
from typing import Dict, Optional, Union
......@@ -271,7 +270,6 @@ class BarkConfig(PretrainedConfig):
"""
model_type = "bark"
is_composition = True
def __init__(
self,
......@@ -329,20 +327,3 @@ class BarkConfig(PretrainedConfig):
codec_config=codec_config.to_dict(),
**kwargs,
)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which the four sub-configurations
        (`semantic_config`, `coarse_acoustics_config`, `fine_acoustics_config`, `codec_config`) are replaced by
        their own `to_dict()` output and `model_type` is set from the class attribute.
    """
    sub_config_names = (
        "semantic_config",
        "coarse_acoustics_config",
        "fine_acoustics_config",
        "codec_config",
    )
    serialized = copy.deepcopy(self.__dict__)
    for sub_name in sub_config_names:
        serialized[sub_name] = getattr(self, sub_name).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" Blip model configuration"""
import copy
import os
from typing import Union
......@@ -325,7 +324,6 @@ class BlipConfig(PretrainedConfig):
```"""
model_type = "blip"
is_composition = True
def __init__(
self,
......@@ -368,16 +366,3 @@ class BlipConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes where `text_config` and `vision_config` hold the
        result of their own `to_dict()` call and `model_type` comes from the class.
    """
    text_dict = self.text_config.to_dict()
    vision_dict = self.vision_config.to_dict()
    snapshot = copy.deepcopy(self.__dict__)
    snapshot["text_config"] = text_dict
    snapshot["vision_config"] = vision_dict
    snapshot["model_type"] = self.__class__.model_type
    return snapshot
......@@ -14,7 +14,6 @@
# limitations under the License.
""" BLIP-2 model configuration"""
import copy
import os
from typing import Union
......@@ -302,7 +301,6 @@ class Blip2Config(PretrainedConfig):
```"""
model_type = "blip-2"
is_composition = True
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
super().__init__(**kwargs)
......@@ -355,17 +353,3 @@ class Blip2Config(PretrainedConfig):
text_config=text_config.to_dict(),
**kwargs,
)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `vision_config`, `qformer_config` and
        `text_config` are replaced by their own `to_dict()` output and `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    for sub_name in ("vision_config", "qformer_config", "text_config"):
        serialized[sub_name] = getattr(self, sub_name).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" BridgeTower model configuration"""
import copy
import os
from typing import Union
......@@ -349,16 +348,3 @@ class BridgeTowerConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes, with `text_config` and `vision_config` converted
        through their own `to_dict()` and `model_type` read from the class.
    """
    as_dict = copy.deepcopy(self.__dict__)
    as_dict.update(
        text_config=self.text_config.to_dict(),
        vision_config=self.vision_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return as_dict
......@@ -14,7 +14,6 @@
# limitations under the License.
""" Chinese-CLIP model configuration"""
import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
......@@ -314,7 +313,6 @@ class ChineseCLIPConfig(PretrainedConfig):
```"""
model_type = "chinese_clip"
is_composition = True
def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
......@@ -417,19 +415,6 @@ class ChineseCLIPConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `text_config` and `vision_config` are
        replaced by their own `to_dict()` output and `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    # Nested configs are stored as objects; emit them as plain dictionaries.
    for sub_name in ("text_config", "vision_config"):
        serialized[sub_name] = getattr(self, sub_name).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
class ChineseCLIPOnnxConfig(OnnxConfig):
@property
......
......@@ -14,7 +14,6 @@
# limitations under the License.
""" CLAP model configuration"""
import copy
import os
from typing import Union
......@@ -382,7 +381,6 @@ class ClapConfig(PretrainedConfig):
```"""
model_type = "clap"
is_composition = True
def __init__(
self,
......@@ -431,16 +429,3 @@ class ClapConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes, with `text_config` and `audio_config` replaced by
        the dictionaries returned by their `to_dict()` methods, plus the class-level `model_type`.
    """
    result = copy.deepcopy(self.__dict__)
    result.update(
        text_config=self.text_config.to_dict(),
        audio_config=self.audio_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return result
......@@ -14,7 +14,6 @@
# limitations under the License.
""" CLIP model configuration"""
import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
......@@ -298,7 +297,6 @@ class CLIPConfig(PretrainedConfig):
```"""
model_type = "clip"
is_composition = True
def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
......@@ -400,19 +398,6 @@ class CLIPConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes where `text_config` and `vision_config` hold the
        result of their own `to_dict()` call and `model_type` comes from the class.
    """
    text_dict = self.text_config.to_dict()
    vision_dict = self.vision_config.to_dict()
    snapshot = copy.deepcopy(self.__dict__)
    snapshot["text_config"] = text_dict
    snapshot["vision_config"] = vision_dict
    snapshot["model_type"] = self.__class__.model_type
    return snapshot
class CLIPOnnxConfig(OnnxConfig):
@property
......
......@@ -14,7 +14,6 @@
# limitations under the License.
""" CLIPSeg model configuration"""
import copy
import os
from typing import Union
......@@ -302,7 +301,6 @@ class CLIPSegConfig(PretrainedConfig):
```"""
model_type = "clipseg"
is_composition = True
def __init__(
self,
......@@ -424,16 +422,3 @@ class CLIPSegConfig(PretrainedConfig):
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `text_config` and `vision_config` are
        replaced by their own `to_dict()` output and `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    for sub_name in ("text_config", "vision_config"):
        serialized[sub_name] = getattr(self, sub_name).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional DETR model configuration"""
import copy
from collections import OrderedDict
from typing import Mapping
......@@ -238,19 +237,6 @@ class ConditionalDetrConfig(PretrainedConfig):
def hidden_size(self) -> int:
return self.d_model
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes. When `backbone_config` is not `None` it is
        replaced by its own `to_dict()` output; `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    has_backbone_config = self.backbone_config is not None
    if has_backbone_config:
        serialized["backbone_config"] = self.backbone_config.to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
class ConditionalDetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
......
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Deformable DETR model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -261,16 +260,3 @@ class DeformableDetrConfig(PretrainedConfig):
@property
def hidden_size(self) -> int:
return self.d_model
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes. `backbone_config` is converted through its own
        `to_dict()` unless it is `None`; `model_type` is read from the class.
    """
    result = copy.deepcopy(self.__dict__)
    has_backbone_config = self.backbone_config is not None
    if has_backbone_config:
        result["backbone_config"] = self.backbone_config.to_dict()
    result["model_type"] = self.__class__.model_type
    return result
......@@ -14,7 +14,6 @@
# limitations under the License.
""" DETA model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -230,13 +229,3 @@ class DetaConfig(PretrainedConfig):
@property
def hidden_size(self) -> int:
return self.d_model
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `backbone_config` is replaced by its own
        `to_dict()` output and `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    serialized.update(
        backbone_config=self.backbone_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return serialized
......@@ -14,9 +14,8 @@
# limitations under the License.
""" DETR model configuration"""
import copy
from collections import OrderedDict
from typing import Dict, Mapping
from typing import Mapping
from packaging import version
......@@ -248,17 +247,6 @@ class DetrConfig(PretrainedConfig):
"""
return cls(backbone_config=backbone_config, **kwargs)
def to_dict(self) -> Dict[str, any]:
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes. If the copied `backbone_config` entry is not
        `None`, it is replaced by the `to_dict()` output of `self.backbone_config`; `model_type` is set from the
        class attribute.
    """
    snapshot = copy.deepcopy(self.__dict__)
    # The check is made on the copied entry, so a missing `backbone_config` attribute raises `KeyError` here.
    has_backbone_config = snapshot["backbone_config"] is not None
    if has_backbone_config:
        snapshot["backbone_config"] = self.backbone_config.to_dict()
    snapshot["model_type"] = self.__class__.model_type
    return snapshot
class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
......
......@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -104,16 +103,3 @@ class EncoderDecoderConfig(PretrainedConfig):
decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default *to_dict()* from *PretrainedConfig*.

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `encoder` and `decoder` are replaced by
        their own `to_dict()` output and `model_type` is set from the class attribute.
    """
    serialized = copy.deepcopy(self.__dict__)
    for side in ("encoder", "decoder"):
        serialized[side] = getattr(self, side).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -14,7 +14,6 @@
# limitations under the License.
""" FLAVA model configurations"""
import copy
import os
from typing import Any, Dict, Union
......@@ -536,7 +535,6 @@ class FlavaConfig(PretrainedConfig):
"""
model_type = "flava"
is_composition = True
def __init__(
self,
......@@ -764,18 +762,3 @@ class FlavaConfig(PretrainedConfig):
image_codebook_config=image_codebook_config.to_dict(),
**kwargs,
)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `image_config`, `text_config`,
        `multimodal_config` and `image_codebook_config` are replaced by their own `to_dict()` output and
        `model_type` is set from the class attribute.
    """
    sub_config_names = (
        "image_config",
        "text_config",
        "multimodal_config",
        "image_codebook_config",
    )
    serialized = copy.deepcopy(self.__dict__)
    for sub_name in sub_config_names:
        serialized[sub_name] = getattr(self, sub_name).to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
......@@ -15,8 +15,6 @@
""" FSMT configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -216,15 +214,3 @@ class FSMTConfig(PretrainedConfig):
early_stopping=early_stopping,
**common_kwargs,
)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default *to_dict()* from *PretrainedConfig*.

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `decoder` is replaced by its own
        `to_dict()` output and `model_type` is set from the class attribute.
    """
    result = copy.deepcopy(self.__dict__)
    result.update(
        decoder=self.decoder.to_dict(),
        model_type=self.__class__.model_type,
    )
    return result
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from typing import Union
......@@ -239,13 +238,3 @@ class GitConfig(PretrainedConfig):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes in which `vision_config` is replaced by its own
        `to_dict()` output and `model_type` is set from the class attribute.
    """
    vision_dict = self.vision_config.to_dict()
    snapshot = copy.deepcopy(self.__dict__)
    snapshot["vision_config"] = vision_dict
    snapshot["model_type"] = self.__class__.model_type
    return snapshot
......@@ -14,7 +14,6 @@
# limitations under the License.
""" GroupViT model configuration"""
import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
......@@ -296,7 +295,6 @@ class GroupViTConfig(PretrainedConfig):
"""
model_type = "groupvit"
is_composition = True
def __init__(
self,
......@@ -407,19 +405,6 @@ class GroupViTConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: A deep copy of the instance attributes, with `text_config` and `vision_config` converted
        through their own `to_dict()` and `model_type` read from the class.
    """
    as_dict = copy.deepcopy(self.__dict__)
    as_dict.update(
        text_config=self.text_config.to_dict(),
        vision_config=self.vision_config.to_dict(),
        model_type=self.__class__.model_type,
    )
    return as_dict
class GroupViTOnnxConfig(OnnxConfig):
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment