"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "2fd2f792d614174f72c9f49edaf97214c4048319"
Unverified commit ee3af60b, authored by Taylor Jackle Spriggs, committed by GitHub
Browse files

Add support for fine-tuning CLIP-like models using contrastive-image-text example (#29070)

* add support for siglip and chinese-clip model training with contrastive-image-text example

* codebase fixups
parent 0996a100
...@@ -54,6 +54,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ...@@ -54,6 +54,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("camembert", "CamembertConfig"), ("camembert", "CamembertConfig"),
("canine", "CanineConfig"), ("canine", "CanineConfig"),
("chinese_clip", "ChineseCLIPConfig"), ("chinese_clip", "ChineseCLIPConfig"),
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
("clap", "ClapConfig"), ("clap", "ClapConfig"),
("clip", "CLIPConfig"), ("clip", "CLIPConfig"),
("clip_vision_model", "CLIPVisionConfig"), ("clip_vision_model", "CLIPVisionConfig"),
...@@ -512,6 +513,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -512,6 +513,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("camembert", "CamemBERT"), ("camembert", "CamemBERT"),
("canine", "CANINE"), ("canine", "CANINE"),
("chinese_clip", "Chinese-CLIP"), ("chinese_clip", "Chinese-CLIP"),
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "CLAP"), ("clap", "CLAP"),
("clip", "CLIP"), ("clip", "CLIP"),
("clip_vision_model", "CLIPVisionModel"), ("clip_vision_model", "CLIPVisionModel"),
...@@ -773,6 +775,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( ...@@ -773,6 +775,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("xclip", "x_clip"), ("xclip", "x_clip"),
("clip_vision_model", "clip"), ("clip_vision_model", "clip"),
("siglip_vision_model", "siglip"), ("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
] ]
) )
......
...@@ -57,6 +57,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ...@@ -57,6 +57,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("camembert", "CamembertModel"), ("camembert", "CamembertModel"),
("canine", "CanineModel"), ("canine", "CanineModel"),
("chinese_clip", "ChineseCLIPModel"), ("chinese_clip", "ChineseCLIPModel"),
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "ClapModel"), ("clap", "ClapModel"),
("clip", "CLIPModel"), ("clip", "CLIPModel"),
("clip_vision_model", "CLIPVisionModel"), ("clip_vision_model", "CLIPVisionModel"),
......
...@@ -171,8 +171,7 @@ class ChineseCLIPVisionConfig(PretrainedConfig): ...@@ -171,8 +171,7 @@ class ChineseCLIPVisionConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the ChineseCLIP configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
[OFA-Sys/chinese-clip-vit-base-patch16](https: [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
//huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
......
...@@ -18,11 +18,19 @@ ...@@ -18,11 +18,19 @@
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ..auto.configuration_auto import AutoConfig from ..auto.configuration_auto import AutoConfig
from ..chinese_clip.configuration_chinese_clip import ChineseCLIPVisionConfig
from ..clip.configuration_clip import CLIPVisionConfig from ..clip.configuration_clip import CLIPVisionConfig
from ..siglip.configuration_siglip import SiglipVisionConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
VISION_MODEL_CONFIGS = {
"clip_vision_model": CLIPVisionConfig,
"chinese_clip_vision_model": ChineseCLIPVisionConfig,
"siglip_vision_model": SiglipVisionConfig,
}
class VisionTextDualEncoderConfig(PretrainedConfig): class VisionTextDualEncoderConfig(PretrainedConfig):
r""" r"""
...@@ -85,12 +93,13 @@ class VisionTextDualEncoderConfig(PretrainedConfig): ...@@ -85,12 +93,13 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
vision_model_type = vision_config.pop("model_type") vision_model_type = vision_config.pop("model_type")
text_model_type = text_config.pop("model_type") text_model_type = text_config.pop("model_type")
if vision_model_type == "clip": vision_config_class = VISION_MODEL_CONFIGS.get(vision_model_type)
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config if vision_config_class is not None:
elif vision_model_type == "clip_vision_model": self.vision_config = vision_config_class(**vision_config)
self.vision_config = CLIPVisionConfig(**vision_config)
else: else:
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config) self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
if hasattr(self.vision_config, "vision_config"):
self.vision_config = self.vision_config.vision_config
self.text_config = AutoConfig.for_model(text_model_type, **text_config) self.text_config = AutoConfig.for_model(text_model_type, **text_config)
......
...@@ -1070,6 +1070,7 @@ MODELS_NOT_IN_README = [ ...@@ -1070,6 +1070,7 @@ MODELS_NOT_IN_README = [
"VisionTextDualEncoder", "VisionTextDualEncoder",
"CLIPVisionModel", "CLIPVisionModel",
"SiglipVisionModel", "SiglipVisionModel",
"ChineseCLIPVisionModel",
] ]
# Template for new entries to add in the main README when we have missing models. # Template for new entries to add in the main README when we have missing models.
......
...@@ -171,7 +171,7 @@ MODEL_NAMES_WITH_SAME_CONFIG = { ...@@ -171,7 +171,7 @@ MODEL_NAMES_WITH_SAME_CONFIG = {
"XLS-R": "Wav2Vec2", "XLS-R": "Wav2Vec2",
"XLSR-Wav2Vec2": "Wav2Vec2", "XLSR-Wav2Vec2": "Wav2Vec2",
} }
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel"] MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel"]
def get_model_table_from_auto_modules() -> str: def get_model_table_from_auto_modules() -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment