"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "eeb70cdd770a218746342881a68c7b6bdc04690a"
Unverified Commit 923733c2 authored by Raushan Turganbay, committed by GitHub

Flava multimodal add attention mask (#29446)

* flava multimodal add attn mask

* make style

* check mask is not None
parent 9288e759
@@ -1415,8 +1415,18 @@ class FlavaModel(FlavaPreTrainedModel):
         multimodal_embeddings = None
         multimodal_output = None
         if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
+            if attention_mask is not None:
+                batch_size, seq_len, _ = image_mm_projection.shape
+                if self.multimodal_model.use_cls_token:
+                    seq_len += 1
+                attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
+                attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)
+            else:
+                attention_multimodal = None
             multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
-            multimodal_output = self.multimodal_model(multimodal_input, return_dict=return_dict)
+            multimodal_output = self.multimodal_model(
+                multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict
+            )
             multimodal_embeddings = multimodal_output[0]

         if not return_dict:
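For reference, below is a minimal standalone sketch of the mask construction the hunk above introduces. All tensor values and shapes are illustrative stand-ins, not taken from the commit: the image half of the fused sequence is always fully visible, so it gets a mask of ones sized for the image tokens plus the optional CLS token, and that mask is concatenated in front of the caller's text attention mask before the fused sequence enters the multimodal encoder.

import torch

# Illustrative shapes: batch of 2, 5 projected image tokens, 4 text tokens, hidden size 8.
image_mm_projection = torch.randn(2, 5, 8)
text_mm_projection = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # text padding mask from the tokenizer
use_cls_token = True  # stands in for self.multimodal_model.use_cls_token

batch_size, seq_len, _ = image_mm_projection.shape
if use_cls_token:
    seq_len += 1  # the multimodal encoder prepends a CLS token that must also stay unmasked
attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)

multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
print(attention_multimodal.shape)  # torch.Size([2, 10]): 1 CLS + 5 image tokens + 4 text tokens
print(multimodal_input.shape)      # torch.Size([2, 9, 8]); the encoder adds the CLS token internally

The text portion keeps its original zeros, so pad tokens no longer receive attention inside the multimodal encoder, which is why the integration-test expectations below change.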
@@ -1287,9 +1287,9 @@ class FlavaModelIntegrationTest(unittest.TestCase):
         outputs = model(**inputs, return_dict=True)

         # verify the embeddings
-        self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
+        self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.54943, places=4)
         self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
-        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -3988.51367, places=4)
+        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.466552, places=4)


 @require_vision
@@ -1339,9 +1339,9 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
         expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
-        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
-        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
-        self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.025580, places=4)
+        self.assertAlmostEqual(outputs.loss.item(), 11.37761, places=4)

     @slow
     def test_inference_with_itm_labels(self):
@@ -1390,6 +1390,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
         expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
-        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
-        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
-        self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.8962264, places=4)
+        self.assertAlmostEqual(outputs.loss.item(), 9.6090, places=4)
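A usage sketch of the behaviour the updated expectations reflect: when captions in a batch are padded, the processor's attention_mask is now forwarded to the multimodal encoder, so pad tokens no longer contribute to multimodal_embeddings. The checkpoint name and image URL below are illustrative assumptions, not part of the commit.

from PIL import Image
import requests
import torch
from transformers import FlavaModel, FlavaProcessor

# "facebook/flava-full" and the COCO image URL are assumptions for this sketch.
processor = FlavaProcessor.from_pretrained("facebook/flava-full")
model = FlavaModel.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Padding the shorter caption yields attention_mask entries of 0 for its pad tokens,
# which the multimodal encoder now ignores instead of attending to them.
inputs = processor(
    text=["a photo of two cats lying on a couch", "a cat"],
    images=[image, image],
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    outputs = model(**inputs, return_dict=True)

print(outputs.multimodal_embeddings.shape)  # (batch_size, multimodal_sequence_length, hidden_size)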