Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
88a4f68f
Unverified
Commit
88a4f68f
authored
Mar 13, 2024
by
amyeroberts
Committed by
GitHub
Mar 13, 2024
Browse files
[`MaskFormer`, `Mask2Former`] Use einsum where possible (#29544)
* Use einsum where possible * Fix
parent
62478857
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
20 deletions
+46
-20
src/transformers/models/mask2former/modeling_mask2former.py
src/transformers/models/mask2former/modeling_mask2former.py
+17
-6
src/transformers/models/maskformer/modeling_maskformer.py
src/transformers/models/maskformer/modeling_maskformer.py
+29
-14
No files found.
src/transformers/models/mask2former/modeling_mask2former.py
View file @
88a4f68f
...
...
@@ -34,6 +34,7 @@ from ...file_utils import (
)
from
...modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithCrossAttentions
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
is_torch_greater_or_equal_than_2_1
from
...utils
import
is_accelerate_available
,
logging
from
...utils.backbone_utils
import
load_backbone
from
.configuration_mask2former
import
Mask2FormerConfig
...
...
@@ -2004,12 +2005,22 @@ class Mask2FormerMaskPredictor(nn.Module):
def
forward
(
self
,
outputs
:
torch
.
Tensor
,
pixel_embeddings
:
torch
.
Tensor
,
attention_mask_target_size
:
int
=
None
):
mask_embeddings
=
self
.
mask_embedder
(
outputs
.
transpose
(
0
,
1
))
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
outputs_mask
=
torch
.
zeros
((
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
outputs_mask
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
outputs
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
# Sum up over the channels
if
is_tracing
and
not
is_torch_greater_or_equal_than_2_1
:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
outputs_mask
=
torch
.
zeros
((
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
outputs_mask
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
else
:
outputs_mask
=
torch
.
einsum
(
"bqc, bchw -> bqhw"
,
mask_embeddings
,
pixel_embeddings
)
attention_mask
=
nn
.
functional
.
interpolate
(
outputs_mask
,
size
=
attention_mask_target_size
,
mode
=
"bilinear"
,
align_corners
=
False
...
...
src/transformers/models/maskformer/modeling_maskformer.py
View file @
88a4f68f
...
...
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
from
...modeling_attn_mask_utils
import
_prepare_4d_attention_mask
from
...modeling_outputs
import
BaseModelOutputWithCrossAttentions
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
is_torch_greater_or_equal_than_2_1
from
...utils
import
(
ModelOutput
,
add_start_docstrings
,
...
...
@@ -1762,6 +1763,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
pixel_embeddings
=
outputs
.
pixel_decoder_last_hidden_state
# get the auxiliary predictions (one for each decoder's layer)
auxiliary_logits
:
List
[
str
,
Tensor
]
=
[]
is_tracing
=
(
torch
.
jit
.
is_tracing
()
or
isinstance
(
outputs
,
torch
.
fx
.
Proxy
)
or
(
hasattr
(
torch
,
"_dynamo"
)
and
torch
.
_dynamo
.
is_compiling
())
)
# This code is a little bit cumbersome; an improvement would be to return a list of predictions. If we use the auxiliary loss, the list will contain more than one element.
if
self
.
config
.
use_auxiliary_loss
:
stacked_transformer_decoder_outputs
=
torch
.
stack
(
outputs
.
transformer_decoder_hidden_states
)
...
...
@@ -1770,14 +1777,17 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings
=
self
.
mask_embedder
(
stacked_transformer_decoder_outputs
)
# Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
num_embeddings
,
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
binaries_masks
=
torch
.
zeros
(
(
num_embeddings
,
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
binaries_masks
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[
None
,
:,
None
,
c
]
if
is_tracing
and
not
is_torch_greater_or_equal_than_2_1
:
# Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
num_embeddings
,
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
binaries_masks
=
torch
.
zeros
(
(
num_embeddings
,
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
binaries_masks
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[
None
,
:,
None
,
c
]
else
:
binaries_masks
=
torch
.
einsum
(
"lbqc, bchw -> lbqhw"
,
mask_embeddings
,
pixel_embeddings
)
masks_queries_logits
=
binaries_masks
[
-
1
]
# go until [:-1] because the last one is always used
...
...
@@ -1794,12 +1804,17 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
mask_embeddings
=
self
.
mask_embedder
(
transformer_decoder_hidden_states
)
# sum up over the channels
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
masks_queries_logits
=
torch
.
zeros
((
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
masks_queries_logits
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
if
is_tracing
and
not
is_torch_greater_or_equal_than_2_1
:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
masks_queries_logits
=
torch
.
zeros
(
(
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
masks_queries_logits
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
else
:
masks_queries_logits
=
torch
.
einsum
(
"bqc, bchw -> bqhw"
,
mask_embeddings
,
pixel_embeddings
)
return
class_queries_logits
,
masks_queries_logits
,
auxiliary_logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment