"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a573777901e662ec2e565be312ffaeedef6effec"
Unverified Commit 60293bd2 authored by David Xue, committed by GitHub

Add dynamic resolution input/interpolate position embedding to SigLIP (#30719)



* Add interpolate positional encoding to siglip

* Change # of patches for siglip interpolation test

* fix formatting

* Apply nit suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f26e4073
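
For quick reference, here is a minimal usage sketch of the new flag. It is not part of this commit: it assumes the google/siglip-base-patch16-224 checkpoint exercised in the integration test below, and the COCO example image URL and processor settings are illustrative assumptions.

# Illustrative sketch: run a SigLIP checkpoint pre-trained at 224x224 on a 640x480 image
# by interpolating its position embeddings instead of resizing the image.
import requests
import torch
from PIL import Image
from transformers import SiglipProcessor, SiglipVisionModel

model_name = "google/siglip-base-patch16-224"
model = SiglipVisionModel.from_pretrained(model_name)
# do_resize=False keeps the native 640x480 resolution (same settings as the test below)
processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # assumed example image (640x480)
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# (640 / 16) * (480 / 16) = 1200 patch tokens, hidden size 768
print(outputs.last_hidden_state.shape)  # torch.Size([1, 1200, 768])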
@@ -265,11 +265,53 @@ class SiglipVisionEmbeddings(nn.Module):
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method is adapted for SigLIP (which, unlike other ViTs, has no class embedding) and allows the model
        to interpolate the pre-trained position encodings so that it can be used on higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point errors in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
            raise ValueError("Width or height does not match with the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
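
As a side note, here is a standalone sketch of the resize that the method above performs (illustrative numbers only: a 224-pixel checkpoint with patch size 16 gives a 14x14 grid of 768-dim position embeddings, stretched here to the 30x40 grid of a 480x640 input).

# Standalone illustration of the shape arithmetic in interpolate_pos_encoding (assumed sizes).
import math
import torch
from torch import nn

pos_embed = torch.randn(1, 196, 768)              # pre-trained table: 14 * 14 = 196 positions
height, width, patch_size = 480, 640, 16          # new input resolution
new_h, new_w = height // patch_size, width // patch_size  # 30 x 40 patch grid

grid = int(math.sqrt(pos_embed.shape[1]))         # 14
patch_pos = pos_embed.reshape(1, grid, grid, -1).permute(0, 3, 1, 2)  # (1, 768, 14, 14)
patch_pos = nn.functional.interpolate(
    patch_pos,
    # the + 0.1 mirrors the floating-point guard referenced above
    scale_factor=((new_h + 0.1) / grid, (new_w + 0.1) / grid),
    mode="bicubic",
    align_corners=False,
)                                                 # (1, 768, 30, 40)
patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, 768)
print(patch_pos.shape)                            # torch.Size([1, 1200, 768])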
@@ -564,6 +606,8 @@ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -601,6 +645,8 @@ SIGLIP_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -848,6 +894,7 @@ class SiglipVisionTransformer(nn.Module):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
@@ -859,7 +906,7 @@ class SiglipVisionTransformer(nn.Module):
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
@@ -935,6 +982,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
@@ -965,6 +1013,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
@@ -1055,6 +1104,7 @@ class SiglipModel(SiglipPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
@@ -1092,6 +1142,7 @@ class SiglipModel(SiglipPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        pooled_output = vision_outputs[1]
@@ -1110,6 +1161,7 @@ class SiglipModel(SiglipPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, SiglipOutput]:
        r"""
        Returns:
@@ -1152,6 +1204,7 @@ class SiglipModel(SiglipPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs = self.text_model(
@@ -1226,6 +1279,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1271,6 +1325,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]
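
The classification head follows the same pattern. Below is a brief, hypothetical sketch of calling it at a higher resolution; the num_labels value and do_resize setting are illustrative, not from this commit.

# Hypothetical example: SiglipForImageClassification with interpolated position embeddings.
import torch
from PIL import Image
from transformers import SiglipForImageClassification, SiglipImageProcessor

model_name = "google/siglip-base-patch16-224"
model = SiglipForImageClassification.from_pretrained(model_name, num_labels=2)  # head is newly initialized
processor = SiglipImageProcessor.from_pretrained(model_name, do_resize=False)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")  # 640 x 480, path from the test below
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs, interpolate_pos_encoding=True).logits
print(logits.shape)  # torch.Size([1, 2])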
@@ -687,3 +687,25 @@ class SiglipModelIntegrationTest(unittest.TestCase):
        probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
        self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))

    @slow
    def test_inference_interpolate_pos_encoding(self):
        model_name = "google/siglip-base-patch16-224"
        model = SiglipModel.from_pretrained(model_name).to(torch_device)

        # 640 x 480 image
        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})

        inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the shape
        # patch size = 16
        # batch size 1, (640 / 16) * (480 / 16) = 1200 patches, 768 hidden size
        expected_shape = torch.Size((1, 1200, 768))
        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)