Unverified Commit c651ea98 authored by Eduardo Pacheco, committed by GitHub

[Grounding DINO] Add support for cross-attention in GroundingDinoMultiHeadAttention (#30364)

* Added cross attention support

* Fixed dtypes

* Fixed assumption

* Moved to decoder
parent 408453b4
@@ -818,7 +818,7 @@ class GroundingDinoTextEnhancerLayer(nn.Module):
         attention_masks = attention_masks[:, None, :, :]
         attention_masks = attention_masks.repeat(1, self.num_heads, 1, 1)
-        dtype = torch.float16
+        dtype = hidden_states.dtype
         attention_masks = attention_masks.to(dtype=dtype)  # fp16 compatibility
         attention_masks = (1.0 - attention_masks) * torch.finfo(dtype).min
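
A note on the dtype change above: the additive mask has to be built with the minimum value of the dtype the attention scores are actually computed in, so the hard-coded torch.float16 gave the wrong scale whenever the model ran in fp32 or bf16; using hidden_states.dtype keeps the bias in the same dtype as the scores it is added to. A minimal sketch (not part of the commit) showing how much that minimum differs per dtype:

import torch

# The "large negative" value used to suppress masked positions after softmax depends on the dtype.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, torch.finfo(dtype).min)
# torch.float16  -65504.0
# torch.bfloat16 -3.3895313892515355e+38
# torch.float32  -3.4028234663852886e+38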
@@ -1425,12 +1425,11 @@ class GroundingDinoDecoderLayer(nn.Module):
         # Cross-Attention Text
         queries = self.with_pos_embed(hidden_states, position_embeddings)
         hidden_states, text_cross_attn_weights = self.encoder_attn_text(
             queries=queries,
             keys=text_encoder_hidden_states,
             values=text_encoder_hidden_states,
-            # attention_mask=text_encoder_attention_mask,  # TODO fix cross-attention mask here
+            attention_mask=text_encoder_attention_mask,
             output_attentions=True,
         )
@@ -1893,6 +1892,16 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
         intermediate = ()
         intermediate_reference_points = ()
+        if text_encoder_attention_mask is not None:
+            dtype = text_encoder_hidden_states.dtype
+
+            text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :]
+            text_encoder_attention_mask = text_encoder_attention_mask.repeat(
+                1, self.config.decoder_attention_heads, self.config.num_queries, 1
+            )
+            text_encoder_attention_mask = text_encoder_attention_mask.to(dtype=dtype)
+            text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(dtype).min
+
         for idx, decoder_layer in enumerate(self.layers):
             num_coordinates = reference_points.shape[-1]
             if num_coordinates == 4:
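
The hunk above is where the decoder prepares the text cross-attention mask: the per-token mask is broadcast to the [batch, heads, queries, text_len] shape the text cross-attention expects and converted into an additive bias. Below is a standalone sketch of that expansion, not code from the commit; the head and query counts are made-up example values, and it assumes (as the direct multiplication by torch.finfo(dtype).min implies) that the mask handed to the decoder marks padded positions with 1/True.

import torch

# Standalone sketch of the mask expansion performed in GroundingDinoDecoder above.
# Example sizes only -- not taken from the real config.
num_heads, num_queries = 8, 900
dtype = torch.float16

# Batch of 2 prompts padded to length 6; 1 marks a padded text token.
is_padding = torch.tensor([[0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 1, 1, 1]])

bias = is_padding[:, None, None, :]                   # [batch, 1, 1, text_len]
bias = bias.repeat(1, num_heads, num_queries, 1)      # [batch, heads, queries, text_len]
bias = bias.to(dtype=dtype) * torch.finfo(dtype).min  # 0 for real tokens, ~-65504 for padding

print(bias.shape)  # torch.Size([2, 8, 900, 6])

Each decoder layer then passes this bias to encoder_attn_text as attention_mask (second hunk), so padded prompt tokens receive effectively zero attention weight after the softmax.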
@@ -687,3 +687,29 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))
         self.assertTrue(torch.allclose(results_cpu["boxes"], result_gpu["boxes"].cpu(), atol=1e-3))
+
+    def test_cross_attention_mask(self):
+        model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device)
+        processor = self.default_processor
+
+        image = prepare_img()
+        text1 = "a cat."
+        text2 = "a remote control."
+        text_batched = [text1, text2]
+
+        encoding1 = processor(images=image, text=text1, return_tensors="pt").to(torch_device)
+        encoding2 = processor(images=image, text=text2, return_tensors="pt").to(torch_device)
+        # If we batch the text and cross-attention masking is working, the batched result should be equal to
+        # the single-text result
+        encoding_batched = processor(
+            images=[image] * len(text_batched), text=text_batched, padding="longest", return_tensors="pt"
+        ).to(torch_device)
+
+        with torch.no_grad():
+            outputs1 = model(**encoding1)
+            outputs2 = model(**encoding2)
+            outputs_batched = model(**encoding_batched)
+
+        self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3))
+        # For some reason 12 elements are > 1e-3, but the rest are fine
+        self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))
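
Outside the test suite, the same batched-prompt pattern now works at inference time. The snippet below is a usage sketch rather than code from this commit: it assumes the AutoProcessor / GroundingDinoForObjectDetection pair for the IDEA-Research/grounding-dino-tiny checkpoint used in the test, and substitutes a public COCO image URL for the test-only prepare_img() helper.

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, GroundingDinoForObjectDetection

# Usage sketch mirroring the test: two prompts of different lengths in one batch,
# relying on the cross-attention mask added in this commit to ignore text padding.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # stand-in example image
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a cat.", "a remote control."]
inputs = processor(images=[image] * len(texts), text=texts, padding="longest", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)  # one row of query logits per prompt in the batch

Without the mask, padding tokens from the shorter prompt would leak into the decoder's text cross-attention and the batched outputs would drift away from the single-prompt ones, which is exactly what the tolerances in the test check.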