chenpangpang / transformers

Commit 923733c2 (Unverified)
Authored Mar 07, 2024 by Raushan Turganbay, committed by GitHub on Mar 07, 2024
Flava multimodal add attention mask (#29446)

* flava multimodal add attn mask
* make style
* check mask is not None

Parent: 9288e759
Changes: 2 changed files with 19 additions and 9 deletions

  src/transformers/models/flava/modeling_flava.py   (+11, -1)
  tests/models/flava/test_modeling_flava.py         (+8, -8)
src/transformers/models/flava/modeling_flava.py (view file @ 923733c2)
@@ -1415,8 +1415,18 @@ class FlavaModel(FlavaPreTrainedModel):
         multimodal_embeddings = None
         multimodal_output = None
         if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
+            if attention_mask is not None:
+                batch_size, seq_len, _ = image_mm_projection.shape
+                if self.multimodal_model.use_cls_token:
+                    seq_len += 1
+                attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
+                attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)
+            else:
+                attention_multimodal = None
             multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
-            multimodal_output = self.multimodal_model(multimodal_input, return_dict=return_dict)
+            multimodal_output = self.multimodal_model(
+                multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict
+            )
             multimodal_embeddings = multimodal_output[0]

         if not return_dict:
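For context, a minimal usage sketch of the code path this change affects (not part of the commit): when the processor pads a text batch, the resulting attention_mask now reaches the multimodal encoder instead of being ignored there. The checkpoint name "facebook/flava-full" and the example image URL are assumptions for illustration only.

    # Hedged usage sketch -- not from the commit.
    import requests
    import torch
    from PIL import Image
    from transformers import FlavaModel, FlavaProcessor

    processor = FlavaProcessor.from_pretrained("facebook/flava-full")
    model = FlavaModel.from_pretrained("facebook/flava-full")

    image = Image.open(
        requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
    )

    # Texts of different lengths so padding yields zeros in attention_mask.
    inputs = processor(
        text=["a photo of two cats sleeping on a couch", "cats"],
        images=[image, image],
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)

    # With this change, padded text positions are masked inside the multimodal
    # encoder rather than being attended to.
    print(outputs.multimodal_embeddings.shape)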
tests/models/flava/test_modeling_flava.py (view file @ 923733c2)
@@ -1287,9 +1287,9 @@ class FlavaModelIntegrationTest(unittest.TestCase):
         outputs = model(**inputs, return_dict=True)

         # verify the embeddings
-        self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
+        self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.54943, places=4)
         self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
-        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -3988.51367, places=4)
+        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.466552, places=4)


 @require_vision
@@ -1339,9 +1339,9 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
         expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
-        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
-        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
-        self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.025580, places=4)
+        self.assertAlmostEqual(outputs.loss.item(), 11.37761, places=4)

     @slow
     def test_inference_with_itm_labels(self):
@@ -1390,6 +1390,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
         expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
-        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
-        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
-        self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.8962264, places=4)
+        self.assertAlmostEqual(outputs.loss.item(), 9.6090, places=4)
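As a side note, here is a small self-contained sketch (not from the commit) of the mask construction the new code performs: all image positions, plus the optional CLS slot, stay visible, and the text mask is appended after them. Tensor shapes and the literal mask values below are made up for illustration.

    # Illustrative sketch of the added mask logic; shapes are assumptions.
    import torch

    batch_size, image_len, text_len, hidden = 2, 5, 7, 16
    image_mm_projection = torch.randn(batch_size, image_len, hidden)
    # Text mask as the tokenizer would produce it; the second sample is padded.
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 0, 0, 0]])
    use_cls_token = True  # mirrors self.multimodal_model.use_cls_token

    bsz, seq_len, _ = image_mm_projection.shape
    if use_cls_token:
        seq_len += 1  # one extra, always-visible slot for the CLS token
    attention_mask_image = torch.ones(bsz, seq_len, device=image_mm_projection.device)
    attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)

    print(attention_multimodal.shape)  # torch.Size([2, 13]): 6 image slots + 7 text slots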