Unverified Commit 8827e1b2 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

clean up vision/text config dict arguments (#19954)



* clean up

* For backward compatibility

* clean up

* Same changes for more models
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent cb630ffa
...@@ -262,9 +262,9 @@ class CLIPConfig(PretrainedConfig): ...@@ -262,9 +262,9 @@ class CLIPConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
text_config_dict (`dict`, *optional*): text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`CLIPTextConfig`]. Dictionary of configuration options used to initialize [`CLIPTextConfig`].
vision_config_dict (`dict`, *optional*): vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
projection_dim (`int`, *optional*, defaults to 512): projection_dim (`int`, *optional*, defaults to 512):
Dimentionality of text and vision projection layers. Dimentionality of text and vision projection layers.
...@@ -300,25 +300,28 @@ class CLIPConfig(PretrainedConfig): ...@@ -300,25 +300,28 @@ class CLIPConfig(PretrainedConfig):
is_composition = True is_composition = True
def __init__( def __init__(
self, self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
text_config_dict=None,
vision_config_dict=None,
projection_dim=512,
logit_scale_init_value=2.6592,
**kwargs
): ):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs) super().__init__(**kwargs)
# If `_config_dict` exist, we use them for the backward compatibility.
text_config_dict = kwargs.pop("text_config_dict", None)
vision_config_dict = kwargs.pop("vision_config_dict", None)
if text_config_dict is not None:
text_config = text_config_dict
if vision_config_dict is not None:
vision_config = vision_config_dict
if text_config_dict is None: if text_config is None:
text_config_dict = {} text_config = {}
logger.info("text_config_dict is None. Initializing the CLIPTextConfig with default values.") logger.info("text_config is None. Initializing the CLIPTextConfig with default values.")
if vision_config_dict is None: if vision_config is None:
vision_config_dict = {} vision_config = {}
logger.info("vision_config_dict is None. initializing the CLIPVisionConfig with default values.") logger.info("vision_config is None. initializing the CLIPVisionConfig with default values.")
self.text_config = CLIPTextConfig(**text_config_dict) self.text_config = CLIPTextConfig(**text_config)
self.vision_config = CLIPVisionConfig(**vision_config_dict) self.vision_config = CLIPVisionConfig(**vision_config)
self.projection_dim = projection_dim self.projection_dim = projection_dim
self.logit_scale_init_value = logit_scale_init_value self.logit_scale_init_value = logit_scale_init_value
...@@ -334,7 +337,7 @@ class CLIPConfig(PretrainedConfig): ...@@ -334,7 +337,7 @@ class CLIPConfig(PretrainedConfig):
[`CLIPConfig`]: An instance of a configuration object [`CLIPConfig`]: An instance of a configuration object
""" """
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self): def to_dict(self):
""" """
......
...@@ -471,11 +471,11 @@ class FlavaConfig(PretrainedConfig): ...@@ -471,11 +471,11 @@ class FlavaConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
text_config_dict (`dict`, *optional*): text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`FlavaTextConfig`]. Dictionary of configuration options used to initialize [`FlavaTextConfig`].
image_config_dict (`dict`, *optional*): image_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`FlavaImageConfig`]. Dictionary of configuration options used to initialize [`FlavaImageConfig`].
multimodal_config_dict (`dict`, *optional*): multimodal_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`]. Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
hidden_size (`int`, *optional*, defaults to 768): hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer. Dimensionality of the encoder layers and the pooler layer.
...@@ -535,10 +535,10 @@ class FlavaConfig(PretrainedConfig): ...@@ -535,10 +535,10 @@ class FlavaConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
image_config_dict: Dict[str, Any] = None, image_config: Dict[str, Any] = None,
text_config_dict: Dict[str, Any] = None, text_config: Dict[str, Any] = None,
multimodal_config_dict: Dict[str, Any] = None, multimodal_config: Dict[str, Any] = None,
image_codebook_config_dict: Dict[str, Any] = None, image_codebook_config: Dict[str, Any] = None,
hidden_size: int = 768, hidden_size: int = 768,
layer_norm_eps: float = 1e-12, layer_norm_eps: float = 1e-12,
projection_dim: int = 768, projection_dim: int = 768,
...@@ -559,33 +559,42 @@ class FlavaConfig(PretrainedConfig): ...@@ -559,33 +559,42 @@ class FlavaConfig(PretrainedConfig):
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if image_config_dict is None: # If `_config_dict` exist, we use them for the backward compatibility.
image_config_dict = {} text_config_dict = kwargs.pop("text_config_dict", None)
logger.info("image_config_dict is None. initializing the FlavaImageConfig with default values.") image_config_dict = kwargs.pop("vision_config_dict", None)
multimodal_config_dict = kwargs.pop("multimodal_config_dict", None)
if text_config_dict is None: image_codebook_config_dict = kwargs.pop("image_codebook_config_dict", None)
text_config_dict = {} if text_config_dict is not None:
logger.info("text_config_dict is None. Initializing the FlavaTextConfig with default values.") text_config = text_config_dict
if image_config_dict is not None:
if multimodal_config_dict is None: image_config = image_config_dict
multimodal_config_dict = {} if multimodal_config_dict is not None:
logger.info("multimodal_config_dict is None. initializing the FlavaMultimodalConfig with default values.") multimodal_config = multimodal_config_dict
if image_codebook_config_dict is not None:
if image_codebook_config_dict is None: image_codebook_config = image_codebook_config_dict
image_codebook_config_dict = {}
if image_config is None:
image_config = {}
logger.info("image_config is None. initializing the FlavaImageConfig with default values.")
if text_config is None:
text_config = {}
logger.info("text_config is None. Initializing the FlavaTextConfig with default values.")
if multimodal_config is None:
multimodal_config = {}
logger.info("multimodal_config is None. initializing the FlavaMultimodalConfig with default values.")
if image_codebook_config is None:
image_codebook_config = {}
logger.info( logger.info(
"image_codebook_config_dict is None. initializing the FlavaImageCodebookConfig with default values." "image_codebook_config is None. initializing the FlavaImageCodebookConfig with default values."
) )
self.image_config_dict = image_config_dict self.image_config = FlavaImageConfig(**image_config)
self.text_config_dict = text_config_dict self.text_config = FlavaTextConfig(**text_config)
self.multimodal_config_dict = multimodal_config_dict self.multimodal_config = FlavaMultimodalConfig(**multimodal_config)
self.image_codebook_config_dict = image_codebook_config_dict self.image_codebook_config = FlavaImageCodebookConfig(**image_codebook_config)
self.image_config = FlavaImageConfig(**self.image_config_dict)
self.text_config = FlavaTextConfig(**self.text_config_dict)
self.multimodal_config = FlavaMultimodalConfig(**self.multimodal_config_dict)
self.image_codebook_config = FlavaImageCodebookConfig(**self.image_codebook_config_dict)
self.projection_dim = projection_dim self.projection_dim = projection_dim
self.init_codebook = init_codebook self.init_codebook = init_codebook
...@@ -623,10 +632,10 @@ class FlavaConfig(PretrainedConfig): ...@@ -623,10 +632,10 @@ class FlavaConfig(PretrainedConfig):
""" """
return cls( return cls(
image_config_dict=image_config.to_dict(), image_config=image_config.to_dict(),
text_config_dict=text_config.to_dict(), text_config=text_config.to_dict(),
multimodal_config_dict=multimodal_config.to_dict(), multimodal_config=multimodal_config.to_dict(),
image_codebook_config_dict=image_codebook_config.to_dict(), image_codebook_config=image_codebook_config.to_dict(),
**kwargs, **kwargs,
) )
......
...@@ -280,9 +280,9 @@ class GroupViTConfig(PretrainedConfig): ...@@ -280,9 +280,9 @@ class GroupViTConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
text_config_dict (`dict`, *optional*): text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`GroupViTTextConfig`]. Dictionary of configuration options used to initialize [`GroupViTTextConfig`].
vision_config_dict (`dict`, *optional*): vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`GroupViTVisionConfig`]. Dictionary of configuration options used to initialize [`GroupViTVisionConfig`].
projection_dim (`int`, *optional*, defaults to 256): projection_dim (`int`, *optional*, defaults to 256):
Dimentionality of text and vision projection layers. Dimentionality of text and vision projection layers.
...@@ -300,25 +300,33 @@ class GroupViTConfig(PretrainedConfig): ...@@ -300,25 +300,33 @@ class GroupViTConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
text_config_dict=None, text_config=None,
vision_config_dict=None, vision_config=None,
projection_dim=256, projection_dim=256,
projection_intermediate_dim=4096, projection_intermediate_dim=4096,
logit_scale_init_value=2.6592, logit_scale_init_value=2.6592,
**kwargs **kwargs
): ):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs) super().__init__(**kwargs)
# If `_config_dict` exist, we use them for the backward compatibility.
text_config_dict = kwargs.pop("text_config_dict", None)
vision_config_dict = kwargs.pop("vision_config_dict", None)
if text_config_dict is not None:
text_config = text_config_dict
if vision_config_dict is not None:
vision_config = vision_config_dict
if text_config_dict is None: if text_config is None:
text_config_dict = {} text_config = {}
logger.info("text_config_dict is None. Initializing the GroupViTTextConfig with default values.") logger.info("text_config is None. Initializing the GroupViTTextConfig with default values.")
if vision_config_dict is None: if vision_config is None:
vision_config_dict = {} vision_config = {}
logger.info("vision_config_dict is None. initializing the GroupViTVisionConfig with default values.") logger.info("vision_config is None. initializing the GroupViTVisionConfig with default values.")
self.text_config = GroupViTTextConfig(**text_config_dict) self.text_config = GroupViTTextConfig(**text_config)
self.vision_config = GroupViTVisionConfig(**vision_config_dict) self.vision_config = GroupViTVisionConfig(**vision_config)
self.projection_dim = projection_dim self.projection_dim = projection_dim
self.projection_intermediate_dim = projection_intermediate_dim self.projection_intermediate_dim = projection_intermediate_dim
...@@ -337,7 +345,7 @@ class GroupViTConfig(PretrainedConfig): ...@@ -337,7 +345,7 @@ class GroupViTConfig(PretrainedConfig):
[`GroupViTConfig`]: An instance of a configuration object [`GroupViTConfig`]: An instance of a configuration object
""" """
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self): def to_dict(self):
""" """
......
...@@ -260,9 +260,9 @@ class OwlViTConfig(PretrainedConfig): ...@@ -260,9 +260,9 @@ class OwlViTConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
text_config_dict (`dict`, *optional*): text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`OwlViTTextConfig`]. Dictionary of configuration options used to initialize [`OwlViTTextConfig`].
vision_config_dict (`dict`, *optional*): vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`OwlViTVisionConfig`]. Dictionary of configuration options used to initialize [`OwlViTVisionConfig`].
projection_dim (`int`, *optional*, defaults to 512): projection_dim (`int`, *optional*, defaults to 512):
Dimensionality of text and vision projection layers. Dimensionality of text and vision projection layers.
...@@ -285,15 +285,15 @@ class OwlViTConfig(PretrainedConfig): ...@@ -285,15 +285,15 @@ class OwlViTConfig(PretrainedConfig):
return_dict=True, return_dict=True,
**kwargs **kwargs
): ):
super().__init__(text_config=text_config, vision_config=vision_config, **kwargs) super().__init__(**kwargs)
if text_config is None: if text_config is None:
text_config = {} text_config = {}
logger.info("text_config_dict is None. Initializing the OwlViTTextConfig with default values.") logger.info("text_config is None. Initializing the OwlViTTextConfig with default values.")
if vision_config is None: if vision_config is None:
vision_config = {} vision_config = {}
logger.info("vision_config_dict is None. initializing the OwlViTVisionConfig with default values.") logger.info("vision_config is None. initializing the OwlViTVisionConfig with default values.")
self.text_config = OwlViTTextConfig(**text_config) self.text_config = OwlViTTextConfig(**text_config)
self.vision_config = OwlViTVisionConfig(**vision_config) self.vision_config = OwlViTVisionConfig(**vision_config)
......
...@@ -35,9 +35,9 @@ class VisionTextDualEncoderConfig(PretrainedConfig): ...@@ -35,9 +35,9 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
text_config_dict (`dict`): text_config (`dict`):
Dictionary of configuration options that defines text model config. Dictionary of configuration options that defines text model config.
vision_config_dict (`dict`): vision_config (`dict`):
Dictionary of configuration options that defines vison model config. Dictionary of configuration options that defines vison model config.
projection_dim (`int`, *optional*, defaults to 512): projection_dim (`int`, *optional*, defaults to 512):
Dimentionality of text and vision projection layers. Dimentionality of text and vision projection layers.
......
...@@ -279,9 +279,9 @@ class XCLIPConfig(PretrainedConfig): ...@@ -279,9 +279,9 @@ class XCLIPConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
text_config_dict (`dict`, *optional*): text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`XCLIPTextConfig`]. Dictionary of configuration options used to initialize [`XCLIPTextConfig`].
vision_config_dict (`dict`, *optional*): vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`XCLIPVisionConfig`]. Dictionary of configuration options used to initialize [`XCLIPVisionConfig`].
projection_dim (`int`, *optional*, defaults to 512): projection_dim (`int`, *optional*, defaults to 512):
Dimentionality of text and vision projection layers. Dimentionality of text and vision projection layers.
...@@ -309,8 +309,8 @@ class XCLIPConfig(PretrainedConfig): ...@@ -309,8 +309,8 @@ class XCLIPConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
text_config_dict=None, text_config=None,
vision_config_dict=None, vision_config=None,
projection_dim=512, projection_dim=512,
prompt_layers=2, prompt_layers=2,
prompt_alpha=0.1, prompt_alpha=0.1,
...@@ -321,18 +321,26 @@ class XCLIPConfig(PretrainedConfig): ...@@ -321,18 +321,26 @@ class XCLIPConfig(PretrainedConfig):
logit_scale_init_value=2.6592, logit_scale_init_value=2.6592,
**kwargs **kwargs
): ):
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs) super().__init__(**kwargs)
# If `_config_dict` exist, we use them for the backward compatibility.
text_config_dict = kwargs.pop("text_config_dict", None)
vision_config_dict = kwargs.pop("vision_config_dict", None)
if text_config_dict is not None:
text_config = text_config_dict
if vision_config_dict is not None:
vision_config = vision_config_dict
if text_config_dict is None: if text_config is None:
text_config_dict = {} text_config = {}
logger.info("text_config_dict is None. Initializing the XCLIPTextConfig with default values.") logger.info("text_config is None. Initializing the XCLIPTextConfig with default values.")
if vision_config_dict is None: if vision_config is None:
vision_config_dict = {} vision_config = {}
logger.info("vision_config_dict is None. initializing the XCLIPVisionConfig with default values.") logger.info("vision_config is None. initializing the XCLIPVisionConfig with default values.")
self.text_config = XCLIPTextConfig(**text_config_dict) self.text_config = XCLIPTextConfig(**text_config)
self.vision_config = XCLIPVisionConfig(**vision_config_dict) self.vision_config = XCLIPVisionConfig(**vision_config)
self.projection_dim = projection_dim self.projection_dim = projection_dim
self.prompt_layers = prompt_layers self.prompt_layers = prompt_layers
...@@ -354,7 +362,7 @@ class XCLIPConfig(PretrainedConfig): ...@@ -354,7 +362,7 @@ class XCLIPConfig(PretrainedConfig):
[`XCLIPConfig`]: An instance of a configuration object [`XCLIPConfig`]: An instance of a configuration object
""" """
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self): def to_dict(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment