Unverified commit 8827e1b2, authored by Yih-Dar and committed by GitHub
Browse files

clean up vision/text config dict arguments (#19954)



* clean up

* For backward compatibility

* clean up

* Same changes for more models
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent cb630ffa
......@@ -262,9 +262,9 @@ class CLIPConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
text_config_dict (`dict`, *optional*):
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`CLIPTextConfig`].
vision_config_dict (`dict`, *optional*):
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers.
......@@ -300,25 +300,28 @@ class CLIPConfig(PretrainedConfig):
is_composition = True
def __init__(
self,
text_config_dict=None,
vision_config_dict=None,
projection_dim=512,
logit_scale_init_value=2.6592,
**kwargs
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
# Backward compatibility: older callers and previously saved configs may still pass the
# legacy `text_config_dict` / `vision_config_dict` keys through `**kwargs`.
# Pop them *before* delegating to the base class: `**kwargs` unpacking hands the parent its
# own copy of the mapping, so popping afterwards would not stop the legacy keys from being
# forwarded to `super().__init__`.
text_config_dict = kwargs.pop("text_config_dict", None)
vision_config_dict = kwargs.pop("vision_config_dict", None)
super().__init__(**kwargs)
# A legacy value, when supplied, takes precedence over the new-style argument.
if text_config_dict is not None:
    text_config = text_config_dict
if vision_config_dict is not None:
    vision_config = vision_config_dict
if text_config_dict is None:
text_config_dict = {}
logger.info("text_config_dict is None. Initializing the CLIPTextConfig with default values.")
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the CLIPTextConfig with default values.")
if vision_config_dict is None:
vision_config_dict = {}
logger.info("vision_config_dict is None. initializing the CLIPVisionConfig with default values.")
if vision_config is None:
vision_config = {}
logger.info("vision_config is None. initializing the CLIPVisionConfig with default values.")
self.text_config = CLIPTextConfig(**text_config_dict)
self.vision_config = CLIPVisionConfig(**vision_config_dict)
self.text_config = CLIPTextConfig(**text_config)
self.vision_config = CLIPVisionConfig(**vision_config)
self.projection_dim = projection_dim
self.logit_scale_init_value = logit_scale_init_value
......@@ -334,7 +337,7 @@ class CLIPConfig(PretrainedConfig):
[`CLIPConfig`]: An instance of a configuration object
"""
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
......
......@@ -471,11 +471,11 @@ class FlavaConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
text_config_dict (`dict`, *optional*):
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`FlavaTextConfig`].
image_config_dict (`dict`, *optional*):
image_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`FlavaImageConfig`].
multimodal_config_dict (`dict`, *optional*):
multimodal_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
......@@ -535,10 +535,10 @@ class FlavaConfig(PretrainedConfig):
def __init__(
self,
image_config_dict: Dict[str, Any] = None,
text_config_dict: Dict[str, Any] = None,
multimodal_config_dict: Dict[str, Any] = None,
image_codebook_config_dict: Dict[str, Any] = None,
image_config: Dict[str, Any] = None,
text_config: Dict[str, Any] = None,
multimodal_config: Dict[str, Any] = None,
image_codebook_config: Dict[str, Any] = None,
hidden_size: int = 768,
layer_norm_eps: float = 1e-12,
projection_dim: int = 768,
......@@ -559,33 +559,42 @@ class FlavaConfig(PretrainedConfig):
):
super().__init__(**kwargs)
if image_config_dict is None:
image_config_dict = {}
logger.info("image_config_dict is None. initializing the FlavaImageConfig with default values.")
if text_config_dict is None:
text_config_dict = {}
logger.info("text_config_dict is None. Initializing the FlavaTextConfig with default values.")
if multimodal_config_dict is None:
multimodal_config_dict = {}
logger.info("multimodal_config_dict is None. initializing the FlavaMultimodalConfig with default values.")
if image_codebook_config_dict is None:
image_codebook_config_dict = {}
# Backward compatibility: the signature now uses the shorter `*_config` names, so the legacy
# `*_config_dict` keyword arguments arrive through `**kwargs` and are picked up here.
text_config_dict = kwargs.pop("text_config_dict", None)
# The legacy image argument was named `image_config_dict`; reading it under the key
# "vision_config_dict" would silently ignore callers that still pass the old name.
image_config_dict = kwargs.pop("image_config_dict", None)
multimodal_config_dict = kwargs.pop("multimodal_config_dict", None)
image_codebook_config_dict = kwargs.pop("image_codebook_config_dict", None)
# A legacy value, when supplied, takes precedence over the new-style argument.
if text_config_dict is not None:
    text_config = text_config_dict
if image_config_dict is not None:
    image_config = image_config_dict
if multimodal_config_dict is not None:
    multimodal_config = multimodal_config_dict
if image_codebook_config_dict is not None:
    image_codebook_config = image_codebook_config_dict
if image_config is None:
image_config = {}
logger.info("image_config is None. initializing the FlavaImageConfig with default values.")
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the FlavaTextConfig with default values.")
if multimodal_config is None:
multimodal_config = {}
logger.info("multimodal_config is None. initializing the FlavaMultimodalConfig with default values.")
if image_codebook_config is None:
image_codebook_config = {}
logger.info(
"image_codebook_config_dict is None. initializing the FlavaImageCodebookConfig with default values."
"image_codebook_config is None. initializing the FlavaImageCodebookConfig with default values."
)
self.image_config_dict = image_config_dict
self.text_config_dict = text_config_dict
self.multimodal_config_dict = multimodal_config_dict
self.image_codebook_config_dict = image_codebook_config_dict
self.image_config = FlavaImageConfig(**self.image_config_dict)
self.text_config = FlavaTextConfig(**self.text_config_dict)
self.multimodal_config = FlavaMultimodalConfig(**self.multimodal_config_dict)
self.image_codebook_config = FlavaImageCodebookConfig(**self.image_codebook_config_dict)
self.image_config = FlavaImageConfig(**image_config)
self.text_config = FlavaTextConfig(**text_config)
self.multimodal_config = FlavaMultimodalConfig(**multimodal_config)
self.image_codebook_config = FlavaImageCodebookConfig(**image_codebook_config)
self.projection_dim = projection_dim
self.init_codebook = init_codebook
......@@ -623,10 +632,10 @@ class FlavaConfig(PretrainedConfig):
"""
return cls(
image_config_dict=image_config.to_dict(),
text_config_dict=text_config.to_dict(),
multimodal_config_dict=multimodal_config.to_dict(),
image_codebook_config_dict=image_codebook_config.to_dict(),
image_config=image_config.to_dict(),
text_config=text_config.to_dict(),
multimodal_config=multimodal_config.to_dict(),
image_codebook_config=image_codebook_config.to_dict(),
**kwargs,
)
......
......@@ -280,9 +280,9 @@ class GroupViTConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
text_config_dict (`dict`, *optional*):
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`GroupViTTextConfig`].
vision_config_dict (`dict`, *optional*):
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`GroupViTVisionConfig`].
projection_dim (`int`, *optional*, defaults to 256):
Dimensionality of text and vision projection layers.
......@@ -300,25 +300,33 @@ class GroupViTConfig(PretrainedConfig):
def __init__(
self,
text_config_dict=None,
vision_config_dict=None,
text_config=None,
vision_config=None,
projection_dim=256,
projection_intermediate_dim=4096,
logit_scale_init_value=2.6592,
**kwargs
):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
# Backward compatibility: older callers and previously saved configs may still pass the
# legacy `text_config_dict` / `vision_config_dict` keys through `**kwargs`.
# Pop them *before* delegating to the base class: `**kwargs` unpacking hands the parent its
# own copy of the mapping, so popping afterwards would not stop the legacy keys from being
# forwarded to `super().__init__`.
text_config_dict = kwargs.pop("text_config_dict", None)
vision_config_dict = kwargs.pop("vision_config_dict", None)
super().__init__(**kwargs)
# A legacy value, when supplied, takes precedence over the new-style argument.
if text_config_dict is not None:
    text_config = text_config_dict
if vision_config_dict is not None:
    vision_config = vision_config_dict
if text_config_dict is None:
text_config_dict = {}
logger.info("text_config_dict is None. Initializing the GroupViTTextConfig with default values.")
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the GroupViTTextConfig with default values.")
if vision_config_dict is None:
vision_config_dict = {}
logger.info("vision_config_dict is None. initializing the GroupViTVisionConfig with default values.")
if vision_config is None:
vision_config = {}
logger.info("vision_config is None. initializing the GroupViTVisionConfig with default values.")
self.text_config = GroupViTTextConfig(**text_config_dict)
self.vision_config = GroupViTVisionConfig(**vision_config_dict)
self.text_config = GroupViTTextConfig(**text_config)
self.vision_config = GroupViTVisionConfig(**vision_config)
self.projection_dim = projection_dim
self.projection_intermediate_dim = projection_intermediate_dim
......@@ -337,7 +345,7 @@ class GroupViTConfig(PretrainedConfig):
[`GroupViTConfig`]: An instance of a configuration object
"""
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
......
......@@ -260,9 +260,9 @@ class OwlViTConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
text_config_dict (`dict`, *optional*):
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`OwlViTTextConfig`].
vision_config_dict (`dict`, *optional*):
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`OwlViTVisionConfig`].
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers.
......@@ -285,15 +285,15 @@ class OwlViTConfig(PretrainedConfig):
return_dict=True,
**kwargs
):
super().__init__(text_config=text_config, vision_config=vision_config, **kwargs)
super().__init__(**kwargs)
if text_config is None:
text_config = {}
logger.info("text_config_dict is None. Initializing the OwlViTTextConfig with default values.")
logger.info("text_config is None. Initializing the OwlViTTextConfig with default values.")
if vision_config is None:
vision_config = {}
logger.info("vision_config_dict is None. initializing the OwlViTVisionConfig with default values.")
logger.info("vision_config is None. initializing the OwlViTVisionConfig with default values.")
self.text_config = OwlViTTextConfig(**text_config)
self.vision_config = OwlViTVisionConfig(**vision_config)
......
......@@ -35,9 +35,9 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
text_config_dict (`dict`):
text_config (`dict`):
Dictionary of configuration options that defines text model config.
vision_config_dict (`dict`):
vision_config (`dict`):
Dictionary of configuration options that defines vision model config.
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers.
......
......@@ -279,9 +279,9 @@ class XCLIPConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
text_config_dict (`dict`, *optional*):
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`XCLIPTextConfig`].
vision_config_dict (`dict`, *optional*):
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`XCLIPVisionConfig`].
projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers.
......@@ -309,8 +309,8 @@ class XCLIPConfig(PretrainedConfig):
def __init__(
self,
text_config_dict=None,
vision_config_dict=None,
text_config=None,
vision_config=None,
projection_dim=512,
prompt_layers=2,
prompt_alpha=0.1,
......@@ -321,18 +321,26 @@ class XCLIPConfig(PretrainedConfig):
logit_scale_init_value=2.6592,
**kwargs
):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
# Backward compatibility: older callers and previously saved configs may still pass the
# legacy `text_config_dict` / `vision_config_dict` keys through `**kwargs`.
# Pop them *before* delegating to the base class: `**kwargs` unpacking hands the parent its
# own copy of the mapping, so popping afterwards would not stop the legacy keys from being
# forwarded to `super().__init__`.
text_config_dict = kwargs.pop("text_config_dict", None)
vision_config_dict = kwargs.pop("vision_config_dict", None)
super().__init__(**kwargs)
# A legacy value, when supplied, takes precedence over the new-style argument.
if text_config_dict is not None:
    text_config = text_config_dict
if vision_config_dict is not None:
    vision_config = vision_config_dict
if text_config_dict is None:
text_config_dict = {}
logger.info("text_config_dict is None. Initializing the XCLIPTextConfig with default values.")
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the XCLIPTextConfig with default values.")
if vision_config_dict is None:
vision_config_dict = {}
logger.info("vision_config_dict is None. initializing the XCLIPVisionConfig with default values.")
if vision_config is None:
vision_config = {}
logger.info("vision_config is None. initializing the XCLIPVisionConfig with default values.")
self.text_config = XCLIPTextConfig(**text_config_dict)
self.vision_config = XCLIPVisionConfig(**vision_config_dict)
self.text_config = XCLIPTextConfig(**text_config)
self.vision_config = XCLIPVisionConfig(**vision_config)
self.projection_dim = projection_dim
self.prompt_layers = prompt_layers
......@@ -354,7 +362,7 @@ class XCLIPConfig(PretrainedConfig):
[`XCLIPConfig`]: An instance of a configuration object
"""
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment