"examples/vscode:/vscode.git/clone" did not exist on "948b730f9777174335812cf76de2a9dd9e4cf20e"
Unverified commit 7f744338 authored by Suraj Patil, committed by GitHub

[CLIP] allow loading projection layer in vision and text model (#18962)



* allow loading projection in text and vision model

* begin tests

* finish test for CLIPTextModelTest

* style

* add slow tests

* add new classes for projection heads

* remove with_projection

* add in init

* add in doc

* fix tests

* fix some more tests

* fix copies

* fix docs

* remove leftover from fix-copies

* add the head models in IGNORE_NON_AUTO_CONFIGURED

* fix docstr

* fix tests

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add docstr for models
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 9643ecf8
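For quick reference, a minimal usage sketch of the two new head classes introduced by this PR, adapted from the docstring examples in the diff below (the checkpoint name follows the existing CLIP examples):

```python
from transformers import (
    CLIPProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

# Load only the text tower plus its projection head from the joint CLIP checkpoint.
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
text_embeds = text_model(**inputs).text_embeds  # shape: (batch_size, projection_dim)

# The vision counterpart works the same way via CLIPProcessor and pixel_values.
vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
```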
@@ -125,6 +125,17 @@ This model was contributed by [valhalla](https://huggingface.co/valhalla). The o
[[autodoc]] CLIPTextModel
- forward
## CLIPTextModelWithProjection
[[autodoc]] CLIPTextModelWithProjection
- forward
## CLIPVisionModelWithProjection
[[autodoc]] CLIPVisionModelWithProjection
- forward
## CLIPVisionModel
[[autodoc]] CLIPVisionModel
...
@@ -1096,7 +1096,9 @@ else:
"CLIPModel",
"CLIPPreTrainedModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
"CLIPVisionModel",
"CLIPVisionModelWithProjection",
]
)
_import_structure["models.clipseg"].extend(
@@ -4086,7 +4088,9 @@ if TYPE_CHECKING:
CLIPModel,
CLIPPreTrainedModel,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPVisionModel,
CLIPVisionModelWithProjection,
)
from .models.clipseg import (
CLIPSEG_PRETRAINED_MODEL_ARCHIVE_LIST,
...
@@ -68,7 +68,9 @@ else:
"CLIPModel",
"CLIPPreTrainedModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
"CLIPVisionModel",
"CLIPVisionModelWithProjection",
]
try:
@@ -140,7 +142,9 @@ if TYPE_CHECKING:
CLIPModel,
CLIPPreTrainedModel,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPVisionModel,
CLIPVisionModelWithProjection,
)
try:
...
@@ -98,6 +98,7 @@ class CLIPTextConfig(PretrainedConfig):
vocab_size=49408,
hidden_size=512,
intermediate_size=2048,
projection_dim=512,
num_hidden_layers=12,
num_attention_heads=8,
max_position_embeddings=77,
@@ -117,6 +118,7 @@ class CLIPTextConfig(PretrainedConfig):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.projection_dim = projection_dim
self.dropout = dropout
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
@@ -204,6 +206,7 @@ class CLIPVisionConfig(PretrainedConfig):
self,
hidden_size=768,
intermediate_size=3072,
projection_dim=512,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
@@ -221,6 +224,7 @@ class CLIPVisionConfig(PretrainedConfig):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.projection_dim = projection_dim
self.dropout = dropout
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
...
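Both configs gain a `projection_dim` argument (default 512) that sizes the shared embedding space used by the projection heads. A small sketch using the defaults shown above:

```python
from transformers import CLIPTextConfig, CLIPVisionConfig

# projection_dim sets the output width of text_projection / visual_projection.
text_config = CLIPTextConfig()      # hidden_size=512, projection_dim=512
vision_config = CLIPVisionConfig()  # hidden_size=768, projection_dim=512

assert text_config.projection_dim == vision_config.projection_dim == 512
```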
@@ -72,6 +72,64 @@ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
return (caption_loss + image_loss) / 2.0
@dataclass
class CLIPVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CLIPTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
text_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CLIPOutput(ModelOutput):
"""
@@ -386,6 +444,16 @@ class CLIPPreTrainedModel(PreTrainedModel):
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
elif isinstance(module, CLIPVisionModelWithProjection):
nn.init.normal_(
module.visual_projection.weight,
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
)
elif isinstance(module, CLIPTextModelWithProjection):
nn.init.normal_(
module.text_projection.weight,
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
)
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
@@ -399,9 +467,13 @@ class CLIPPreTrainedModel(PreTrainedModel):
CLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
@@ -685,6 +757,10 @@ class CLIPTextTransformer(nn.Module):
return mask
@add_start_docstrings(
"""The text model from CLIP without any head or projection on top.""",
CLIP_START_DOCSTRING,
)
class CLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
@@ -730,6 +806,8 @@ class CLIPTextModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -798,6 +876,10 @@ class CLIPVisionTransformer(nn.Module):
)
@add_start_docstrings(
"""The vision model from CLIP without any head or projection on top.""",
CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
@@ -842,6 +924,8 @@ class CLIPVisionModel(CLIPPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output  # pooled CLS states
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
@@ -1074,3 +1158,162 @@ class CLIPModel(CLIPPreTrainedModel):
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
@add_start_docstrings(
"""
CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
""",
CLIP_START_DOCSTRING,
)
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPTextModelOutput]:
r"""
Returns:
Examples:
```python
>>> from transformers import CLIPTokenizer, CLIPTextModelWithProjection
>>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> text_embeds = outputs.text_embeds
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = text_outputs[1]
text_embeds = self.text_projection(pooled_output)
if not return_dict:
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
return tuple(output for output in outputs if output is not None)
return CLIPTextModelOutput(
text_embeds=text_embeds,
last_hidden_state=text_outputs.last_hidden_state,
hidden_states=text_outputs.hidden_states,
attentions=text_outputs.attentions,
)
@add_start_docstrings(
"""
CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
""",
CLIP_START_DOCSTRING,
)
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
self.vision_model = CLIPVisionTransformer(config)
self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPVisionModelOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import CLIPProcessor, CLIPVisionModelWithProjection
>>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> image_embeds = outputs.image_embeds
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = vision_outputs[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
if not return_dict:
outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
return tuple(output for output in outputs if output is not None)
return CLIPVisionModelOutput(
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
attentions=vision_outputs.attentions,
)
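As the `forward` bodies above show, when `return_dict=False` the projected embedding is the first element of the returned tuple, followed by `last_hidden_state`. A hedged sketch of consuming both output styles:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(["a photo of a cat"], return_tensors="pt")

with torch.no_grad():
    # Dataclass output (default): named fields.
    out = model(**inputs)
    # Tuple output: (text_embeds, last_hidden_state, ...optional hidden_states/attentions).
    text_embeds, last_hidden_state = model(**inputs, return_dict=False)[:2]

assert torch.allclose(out.text_embeds, text_embeds)
```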
@@ -420,7 +420,6 @@ class CLIPSegEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.clip.modeling_clip.CLIPPreTrainedModel with CLIP->CLIPSeg
class CLIPSegPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
...
@@ -1207,6 +1207,13 @@ class CLIPTextModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class CLIPTextModelWithProjection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class CLIPVisionModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1214,6 +1221,13 @@ class CLIPVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class CLIPVisionModelWithProjection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
CLIPSEG_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
@@ -151,7 +151,9 @@ for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
_SPECIAL_SUPPORTED_MODELS = [
"CLIPTextModel",
"CLIPTextModelWithProjection",
"CLIPVisionModel",
"CLIPVisionModelWithProjection",
"GPT2DoubleHeadsModel",
"Speech2Text2Decoder",
"TrOCRDecoder",
...
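Listing the new classes in `_SPECIAL_SUPPORTED_MODELS` keeps them traceable with `transformers.utils.fx`; a sketch, assuming the usual `symbolic_trace` entry point:

```python
from transformers import CLIPTextModelWithProjection
from transformers.utils.fx import symbolic_trace

model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
# Produces a torch.fx.GraphModule whose forward mirrors the eager model.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
```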
@@ -49,7 +49,13 @@ if is_torch_available():
import torch
from torch import nn
from transformers import (
CLIPModel,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPVisionModel,
CLIPVisionModelWithProjection,
)
from transformers.models.clip.modeling_clip import CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -77,6 +83,7 @@ class CLIPVisionModelTester:
num_channels=3,
is_training=True,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
@@ -92,6 +99,7 @@ class CLIPVisionModelTester:
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -116,6 +124,7 @@ class CLIPVisionModelTester:
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
@@ -137,6 +146,19 @@ class CLIPVisionModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_with_projection(self, config, pixel_values):
model = CLIPVisionModelWithProjection(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
@@ -151,7 +173,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
all_model_classes = (CLIPVisionModel, CLIPVisionModelWithProjection) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
test_resize_embeddings = False
@@ -193,6 +215,10 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_projection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_projection(*config_and_inputs)
def test_training(self):
pass
@@ -213,6 +239,13 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
model = CLIPVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
def test_model_with_projection_from_pretrained(self):
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = CLIPVisionModelWithProjection.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "visual_projection"))
class CLIPTextModelTester:
def __init__(
@@ -225,6 +258,7 @@ class CLIPTextModelTester:
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
@@ -242,6 +276,7 @@ class CLIPTextModelTester:
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -273,6 +308,7 @@ class CLIPTextModelTester:
return CLIPTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
@@ -292,6 +328,16 @@ class CLIPTextModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_with_projection(self, config, input_ids, input_mask):
model = CLIPTextModelWithProjection(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.text_embeds.shape, (self.batch_size, self.projection_dim))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
@@ -302,7 +348,7 @@ class CLIPTextModelTester:
@require_torch
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel, CLIPTextModelWithProjection) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
test_head_masking = False
@@ -318,6 +364,10 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_projection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_projection(*config_and_inputs)
def test_training(self):
pass
@@ -342,6 +392,13 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
model = CLIPTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
def test_model_with_projection_from_pretrained(self):
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = CLIPTextModelWithProjection.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "text_projection"))
class CLIPModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
...
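A standalone smoke test mirroring what `create_and_check_model_with_projection` verifies, using tiny, hypothetical config values in the spirit of the testers above:

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection

# Tiny config so the check runs quickly on CPU; values mirror the tester defaults.
config = CLIPTextConfig(
    vocab_size=99,
    hidden_size=32,
    projection_dim=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
)
model = CLIPTextModelWithProjection(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 7))
with torch.no_grad():
    result = model(input_ids)

assert result.text_embeds.shape == (2, config.projection_dim)
assert result.last_hidden_state.shape == (2, 7, config.hidden_size)
```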
@@ -177,7 +177,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"PLBartDecoderWrapper",
"BeitForMaskedImageModeling",
"CLIPTextModel",
"CLIPTextModelWithProjection",
"CLIPVisionModel",
"CLIPVisionModelWithProjection",
"GroupViTTextModel",
"GroupViTVisionModel",
"TFCLIPTextModel",
...