"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f497f564bb76697edab09184a252fc1b1a326d1e"
Unverified Commit 29f04002 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Deal with nested configs better in base class (#25237)



* Deal better with nested configs

* Fixes

* More fixes

* Fix last test

* Clean up existing configs

* Remove hack in MPT Config

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Fix setting a nested config via dict in the kwargs

* Adapt common test

* Add test for nested config load with dict

---------
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent aeb5a08a
...@@ -762,6 +762,10 @@ class PretrainedConfig(PushToHubMixin): ...@@ -762,6 +762,10 @@ class PretrainedConfig(PushToHubMixin):
to_remove = [] to_remove = []
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(config, key): if hasattr(config, key):
current_attr = getattr(config, key)
# To authorize passing a custom subconfig as kwarg in models that have nested configs.
if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
value = current_attr.__class__(**value)
setattr(config, key, value) setattr(config, key, value)
if key != "torch_dtype": if key != "torch_dtype":
to_remove.append(key) to_remove.append(key)
...@@ -823,6 +827,18 @@ class PretrainedConfig(PushToHubMixin): ...@@ -823,6 +827,18 @@ class PretrainedConfig(PushToHubMixin):
# only serialize values that differ from the default config # only serialize values that differ from the default config
for key, value in config_dict.items(): for key, value in config_dict.items():
if ( if (
isinstance(getattr(self, key, None), PretrainedConfig)
and key in class_config_dict
and isinstance(class_config_dict[key], dict)
):
# For nested configs we need to clean the diff recursively
diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
if "model_type" in value:
# Needs to be set even if it's not in the diff
diff["model_type"] = value["model_type"]
if len(diff) > 0:
serializable_config_dict[key] = diff
elif (
key not in default_config_dict key not in default_config_dict
or key == "transformers_version" or key == "transformers_version"
or value != default_config_dict[key] or value != default_config_dict[key]
...@@ -859,6 +875,14 @@ class PretrainedConfig(PushToHubMixin): ...@@ -859,6 +875,14 @@ class PretrainedConfig(PushToHubMixin):
# Transformers version when serializing the model # Transformers version when serializing the model
output["transformers_version"] = __version__ output["transformers_version"] = __version__
for key, value in output.items():
# Deal with nested configs like CLIP
if isinstance(value, PretrainedConfig):
value = value.to_dict()
del value["transformers_version"]
output[key] = value
if hasattr(self, "quantization_config"): if hasattr(self, "quantization_config"):
output["quantization_config"] = ( output["quantization_config"] = (
self.quantization_config.to_dict() self.quantization_config.to_dict()
...@@ -1020,6 +1044,24 @@ def get_configuration_file(configuration_files: List[str]) -> str: ...@@ -1020,6 +1044,24 @@ def get_configuration_file(configuration_files: List[str]) -> str:
return configuration_file return configuration_file
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
"""
Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
values from `dict_a` that are different from values in `dict_b`.
"""
diff = {}
default = config_obj.__class__().to_dict() if config_obj is not None else {}
for key, value in dict_a.items():
obj_value = getattr(config_obj, str(key), None)
if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
if len(diff_value) > 0:
diff[key] = diff_value
elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
diff[key] = value
return diff
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub) PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
if PretrainedConfig.push_to_hub.__doc__ is not None: if PretrainedConfig.push_to_hub.__doc__ is not None:
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format( PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" ALIGN model configuration""" """ ALIGN model configuration"""
import copy
import os import os
from typing import TYPE_CHECKING, List, Union from typing import TYPE_CHECKING, List, Union
...@@ -344,7 +343,6 @@ class AlignConfig(PretrainedConfig): ...@@ -344,7 +343,6 @@ class AlignConfig(PretrainedConfig):
```""" ```"""
model_type = "align" model_type = "align"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -383,16 +381,3 @@ class AlignConfig(PretrainedConfig): ...@@ -383,16 +381,3 @@ class AlignConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" AltCLIP model configuration""" """ AltCLIP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -291,7 +290,6 @@ class AltCLIPConfig(PretrainedConfig): ...@@ -291,7 +290,6 @@ class AltCLIPConfig(PretrainedConfig):
```""" ```"""
model_type = "altclip" model_type = "altclip"
is_composition = True
def __init__( def __init__(
self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs
...@@ -392,16 +390,3 @@ class AltCLIPConfig(PretrainedConfig): ...@@ -392,16 +390,3 @@ class AltCLIPConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" BARK model configuration""" """ BARK model configuration"""
import copy
import os import os
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
...@@ -271,7 +270,6 @@ class BarkConfig(PretrainedConfig): ...@@ -271,7 +270,6 @@ class BarkConfig(PretrainedConfig):
""" """
model_type = "bark" model_type = "bark"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -329,20 +327,3 @@ class BarkConfig(PretrainedConfig): ...@@ -329,20 +327,3 @@ class BarkConfig(PretrainedConfig):
codec_config=codec_config.to_dict(), codec_config=codec_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["semantic_config"] = self.semantic_config.to_dict()
output["coarse_acoustics_config"] = self.coarse_acoustics_config.to_dict()
output["fine_acoustics_config"] = self.fine_acoustics_config.to_dict()
output["codec_config"] = self.codec_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Blip model configuration""" """ Blip model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -325,7 +324,6 @@ class BlipConfig(PretrainedConfig): ...@@ -325,7 +324,6 @@ class BlipConfig(PretrainedConfig):
```""" ```"""
model_type = "blip" model_type = "blip"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -368,16 +366,3 @@ class BlipConfig(PretrainedConfig): ...@@ -368,16 +366,3 @@ class BlipConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" BLIP-2 model configuration""" """ BLIP-2 model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -302,7 +301,6 @@ class Blip2Config(PretrainedConfig): ...@@ -302,7 +301,6 @@ class Blip2Config(PretrainedConfig):
```""" ```"""
model_type = "blip-2" model_type = "blip-2"
is_composition = True
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -355,17 +353,3 @@ class Blip2Config(PretrainedConfig): ...@@ -355,17 +353,3 @@ class Blip2Config(PretrainedConfig):
text_config=text_config.to_dict(), text_config=text_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["qformer_config"] = self.qformer_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" BridgeTower model configuration""" """ BridgeTower model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -349,16 +348,3 @@ class BridgeTowerConfig(PretrainedConfig): ...@@ -349,16 +348,3 @@ class BridgeTowerConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Chinese-CLIP model configuration""" """ Chinese-CLIP model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
...@@ -314,7 +313,6 @@ class ChineseCLIPConfig(PretrainedConfig): ...@@ -314,7 +313,6 @@ class ChineseCLIPConfig(PretrainedConfig):
```""" ```"""
model_type = "chinese_clip" model_type = "chinese_clip"
is_composition = True
def __init__( def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
...@@ -417,19 +415,6 @@ class ChineseCLIPConfig(PretrainedConfig): ...@@ -417,19 +415,6 @@ class ChineseCLIPConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class ChineseCLIPOnnxConfig(OnnxConfig): class ChineseCLIPOnnxConfig(OnnxConfig):
@property @property
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" CLAP model configuration""" """ CLAP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -382,7 +381,6 @@ class ClapConfig(PretrainedConfig): ...@@ -382,7 +381,6 @@ class ClapConfig(PretrainedConfig):
```""" ```"""
model_type = "clap" model_type = "clap"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -431,16 +429,3 @@ class ClapConfig(PretrainedConfig): ...@@ -431,16 +429,3 @@ class ClapConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["audio_config"] = self.audio_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" CLIP model configuration""" """ CLIP model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
...@@ -298,7 +297,6 @@ class CLIPConfig(PretrainedConfig): ...@@ -298,7 +297,6 @@ class CLIPConfig(PretrainedConfig):
```""" ```"""
model_type = "clip" model_type = "clip"
is_composition = True
def __init__( def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
...@@ -400,19 +398,6 @@ class CLIPConfig(PretrainedConfig): ...@@ -400,19 +398,6 @@ class CLIPConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class CLIPOnnxConfig(OnnxConfig): class CLIPOnnxConfig(OnnxConfig):
@property @property
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" CLIPSeg model configuration""" """ CLIPSeg model configuration"""
import copy
import os import os
from typing import Union from typing import Union
...@@ -302,7 +301,6 @@ class CLIPSegConfig(PretrainedConfig): ...@@ -302,7 +301,6 @@ class CLIPSegConfig(PretrainedConfig):
```""" ```"""
model_type = "clipseg" model_type = "clipseg"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -424,16 +422,3 @@ class CLIPSegConfig(PretrainedConfig): ...@@ -424,16 +422,3 @@ class CLIPSegConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Conditional DETR model configuration""" """ Conditional DETR model configuration"""
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping from typing import Mapping
...@@ -238,19 +237,6 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -238,19 +237,6 @@ class ConditionalDetrConfig(PretrainedConfig):
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if self.backbone_config is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class ConditionalDetrOnnxConfig(OnnxConfig): class ConditionalDetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11") torch_onnx_minimum_version = version.parse("1.11")
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Deformable DETR model configuration""" """ Deformable DETR model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
...@@ -261,16 +260,3 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -261,16 +260,3 @@ class DeformableDetrConfig(PretrainedConfig):
@property @property
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if self.backbone_config is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" DETA model configuration""" """ DETA model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
...@@ -230,13 +229,3 @@ class DetaConfig(PretrainedConfig): ...@@ -230,13 +229,3 @@ class DetaConfig(PretrainedConfig):
@property @property
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
""" DETR model configuration""" """ DETR model configuration"""
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, Mapping from typing import Mapping
from packaging import version from packaging import version
...@@ -248,17 +247,6 @@ class DetrConfig(PretrainedConfig): ...@@ -248,17 +247,6 @@ class DetrConfig(PretrainedConfig):
""" """
return cls(backbone_config=backbone_config, **kwargs) return cls(backbone_config=backbone_config, **kwargs)
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if output["backbone_config"] is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class DetrOnnxConfig(OnnxConfig): class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11") torch_onnx_minimum_version = version.parse("1.11")
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
...@@ -104,16 +103,3 @@ class EncoderDecoderConfig(PretrainedConfig): ...@@ -104,16 +103,3 @@ class EncoderDecoderConfig(PretrainedConfig):
decoder_config.add_cross_attention = True decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" FLAVA model configurations""" """ FLAVA model configurations"""
import copy
import os import os
from typing import Any, Dict, Union from typing import Any, Dict, Union
...@@ -536,7 +535,6 @@ class FlavaConfig(PretrainedConfig): ...@@ -536,7 +535,6 @@ class FlavaConfig(PretrainedConfig):
""" """
model_type = "flava" model_type = "flava"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -764,18 +762,3 @@ class FlavaConfig(PretrainedConfig): ...@@ -764,18 +762,3 @@ class FlavaConfig(PretrainedConfig):
image_codebook_config=image_codebook_config.to_dict(), image_codebook_config=image_codebook_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["image_config"] = self.image_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["multimodal_config"] = self.multimodal_config.to_dict()
output["image_codebook_config"] = self.image_codebook_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
""" FSMT configuration""" """ FSMT configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
...@@ -216,15 +214,3 @@ class FSMTConfig(PretrainedConfig): ...@@ -216,15 +214,3 @@ class FSMTConfig(PretrainedConfig):
early_stopping=early_stopping, early_stopping=early_stopping,
**common_kwargs, **common_kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import os import os
from typing import Union from typing import Union
...@@ -239,13 +238,3 @@ class GitConfig(PretrainedConfig): ...@@ -239,13 +238,3 @@ class GitConfig(PretrainedConfig):
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" GroupViT model configuration""" """ GroupViT model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
...@@ -296,7 +295,6 @@ class GroupViTConfig(PretrainedConfig): ...@@ -296,7 +295,6 @@ class GroupViTConfig(PretrainedConfig):
""" """
model_type = "groupvit" model_type = "groupvit"
is_composition = True
def __init__( def __init__(
self, self,
...@@ -407,19 +405,6 @@ class GroupViTConfig(PretrainedConfig): ...@@ -407,19 +405,6 @@ class GroupViTConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class GroupViTOnnxConfig(OnnxConfig): class GroupViTOnnxConfig(OnnxConfig):
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment