"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "754202de4f7b40869626e9a4ee8990e7a3067d68"
Unverified commit a1d4563f, authored by Younes Belkada, committed by GitHub

`accelerate` support for `OwlViT` (#20411)

* `accelerate` support for `OwlViT`

- added `accelerate` support
- added slow `fp16` tests

* apply suggestions
parent afce73bd
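What the change enables, in user terms: OwlViT checkpoints can now be loaded through `accelerate` (e.g. with `device_map="auto"`) and run in half precision. Below is a minimal sketch, assuming `accelerate` is installed; the `google/owlvit-base-patch32` checkpoint and the COCO image URL are used purely as examples.

# Sketch only: zero-shot detection with an fp16 OwlViT placed by accelerate.
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

model = OwlViTForObjectDetection.from_pretrained(
    "google/owlvit-base-patch32",
    torch_dtype=torch.float16,
    device_map="auto",  # dispatched by accelerate; a GPU is the realistic target
)
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # example image
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape, outputs.pred_boxes.shape)

The `device_map="auto"` path is what relies on the `_no_split_modules` attribute added below, and the fp16 path is what the dtype casts in the modeling code and the new slow test exercise.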
@@ -434,6 +434,9 @@ class OwlViTAttention(nn.Module):
         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
+        # For int8 compatibility, sometimes the `attn_probs` are in `fp32`
+        attn_probs = attn_probs.to(value_states.dtype)
+
         attn_output = torch.bmm(attn_probs, value_states)
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
@@ -528,6 +531,7 @@ class OwlViTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "owlvit"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _no_split_modules = ["OwlViTEncoderLayer"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -836,7 +840,8 @@ class OwlViTTextTransformer(nn.Module):
         # take features from the end of tokens embedding (end of token is the highest number in each sequence)
         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
         pooled_output = last_hidden_state[
-            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+            input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
         ]
 
         if not return_dict:
@@ -939,8 +944,13 @@ class OwlViTVisionTransformer(nn.Module):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Cast the input to the expected `dtype`
+        expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
+        pixel_values = pixel_values.to(expected_input_dtype)
+
         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.pre_layernorm(hidden_states)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             output_attentions=output_attentions,
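Note on the dtype cast: `OwlViTProcessor` returns `float32` `pixel_values`, so without the cast a half-precision vision tower would typically fail with a dtype mismatch on the patch embedding. With it, fp16 checkpoints accept processor output directly. A small sketch (the checkpoint name is just an example; the blank image only exercises the dtype path, and fp16 on CPU is slow):

# Sketch only: fp32 processor output meets fp16 weights; the model casts internally.
import torch
from PIL import Image
from transformers import OwlViTModel, OwlViTProcessor

model = OwlViTModel.from_pretrained("google/owlvit-base-patch32", torch_dtype=torch.float16)
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

inputs = processor(images=Image.new("RGB", (768, 768)), return_tensors="pt")
print(inputs.pixel_values.dtype)  # torch.float32

with torch.no_grad():
    image_features = model.get_image_features(**inputs)
print(image_features.dtype)  # torch.float16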
@@ -1193,8 +1203,9 @@ class OwlViTModel(OwlViTPreTrainedModel):
         image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
         text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
 
-        # cosine similarity as logits
-        logit_scale = self.logit_scale.exp()
+        # cosine similarity as logits and set it on the correct device
+        logit_scale = self.logit_scale.exp().to(image_embeds.device)
+
         logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
         logits_per_image = logits_per_text.t()
(The remaining hunks touch the OwlViT test module.)

@@ -24,7 +24,7 @@
 import numpy as np
 import requests
 
 from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
 from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
@@ -778,3 +778,28 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
             [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
+
+    @slow
+    @require_torch_gpu
+    def test_inference_one_shot_object_detection_fp16(self):
+        model_name = "google/owlvit-base-patch32"
+        model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
+
+        processor = OwlViTProcessor.from_pretrained(model_name)
+
+        image = prepare_img()
+        query_image = prepare_img()
+        inputs = processor(
+            images=image,
+            query_images=query_image,
+            max_length=16,
+            padding="max_length",
+            return_tensors="pt",
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs = model.image_guided_detection(**inputs)
+
+        # No need to check the logits, we just check inference runs fine.
+        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
+        self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
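Because it is decorated with `@slow` and `@require_torch_gpu`, the new test is skipped in the default CI run; locally it only executes on a CUDA machine with `RUN_SLOW=1` set, for example via `RUN_SLOW=1 python -m pytest -k fp16` against the OwlViT test module (assuming the usual `tests/models/owlvit/test_modeling_owlvit.py` location in the repo).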