Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
88a4f68f
Unverified
Commit
88a4f68f
authored
Mar 13, 2024
by
amyeroberts
Committed by
GitHub
Mar 13, 2024
Browse files
[`MaskFormer`, `Mask2Former`] Use einsum where possible (#29544)
* Use einsum where possible * Fix
parent
62478857
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
20 deletions
+46
-20
src/transformers/models/mask2former/modeling_mask2former.py
src/transformers/models/mask2former/modeling_mask2former.py
+17
-6
src/transformers/models/maskformer/modeling_maskformer.py
src/transformers/models/maskformer/modeling_maskformer.py
+29
-14
No files found.
src/transformers/models/mask2former/modeling_mask2former.py
View file @
88a4f68f
...
...
@@ -34,6 +34,7 @@ from ...file_utils import (
)
from
...modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithCrossAttentions
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
is_torch_greater_or_equal_than_2_1
from
...utils
import
is_accelerate_available
,
logging
from
...utils.backbone_utils
import
load_backbone
from
.configuration_mask2former
import
Mask2FormerConfig
...
...
@@ -2004,6 +2005,13 @@ class Mask2FormerMaskPredictor(nn.Module):
def
forward
(
self
,
outputs
:
torch
.
Tensor
,
pixel_embeddings
:
torch
.
Tensor
,
attention_mask_target_size
:
int
=
None
):
mask_embeddings
=
self
.
mask_embedder
(
outputs
.
transpose
(
0
,
1
))
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
outputs
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
# Sum up over the channels
if
is_tracing
and
not
is_torch_greater_or_equal_than_2_1
:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
...
...
@@ -2011,6 +2019,9 @@ class Mask2FormerMaskPredictor(nn.Module):
for
c
in
range
(
num_channels
):
outputs_mask
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
else
:
outputs_mask
=
torch
.
einsum
(
"bqc, bchw -> bqhw"
,
mask_embeddings
,
pixel_embeddings
)
attention_mask
=
nn
.
functional
.
interpolate
(
outputs_mask
,
size
=
attention_mask_target_size
,
mode
=
"bilinear"
,
align_corners
=
False
)
...
...
src/transformers/models/maskformer/modeling_maskformer.py
View file @
88a4f68f
...
...
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
from
...modeling_attn_mask_utils
import
_prepare_4d_attention_mask
from
...modeling_outputs
import
BaseModelOutputWithCrossAttentions
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
is_torch_greater_or_equal_than_2_1
from
...utils
import
(
ModelOutput
,
add_start_docstrings
,
...
...
@@ -1762,6 +1763,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
pixel_embeddings
=
outputs
.
pixel_decoder_last_hidden_state
# get the auxiliary predictions (one for each decoder's layer)
auxiliary_logits
:
List
[
str
,
Tensor
]
=
[]
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
outputs
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
# This code is a little cumbersome; an improvement would be to return a list of predictions. If we have an auxiliary loss, we are going to return more than one element in the list
if
self
.
config
.
use_auxiliary_loss
:
stacked_transformer_decoder_outputs
=
torch
.
stack
(
outputs
.
transformer_decoder_hidden_states
)
...
...
@@ -1770,6 +1777,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings
=
self
.
mask_embedder
(
stacked_transformer_decoder_outputs
)
if
is_tracing
and
not
is_torch_greater_or_equal_than_2_1
:
# Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
num_embeddings
,
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
...
...
@@ -1778,6 +1786,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
)
for
c
in
range
(
num_channels
):
binaries_masks
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[
None
,
:,
None
,
c
]
else
:
binaries_masks
=
torch
.
einsum
(
"lbqc, bchw -> lbqhw"
,
mask_embeddings
,
pixel_embeddings
)
masks_queries_logits
=
binaries_masks
[
-
1
]
# go until [:-1] because the last one is always used
...
...
@@ -1794,12 +1804,17 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
mask_embeddings
=
self
.
mask_embedder
(
transformer_decoder_hidden_states
)
# sum up over the channels
if
is_tracing
and
not
is_torch_greater_or_equal_than_2_1
:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
masks_queries_logits
=
torch
.
zeros
((
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
masks_queries_logits
=
torch
.
zeros
(
(
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
masks_queries_logits
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
else
:
masks_queries_logits
=
torch
.
einsum
(
"bqc, bchw -> bqhw"
,
mask_embeddings
,
pixel_embeddings
)
return
class_queries_logits
,
masks_queries_logits
,
auxiliary_logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment