Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ce2d4bc6
Unverified
Commit
ce2d4bc6
authored
Aug 29, 2023
by
amyeroberts
Committed by
GitHub
Aug 29, 2023
Browse files
MaskFormer, Mask2Former - reduce memory load (#25741)
Allocate result array ahead of time
parent
0daeeb40
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
20 deletions
+22
-20
src/transformers/models/mask2former/modeling_mask2former.py
src/transformers/models/mask2former/modeling_mask2former.py
+6
-7
src/transformers/models/maskformer/modeling_maskformer.py
src/transformers/models/maskformer/modeling_maskformer.py
+16
-13
No files found.
src/transformers/models/mask2former/modeling_mask2former.py
View file @
ce2d4bc6
...
...
@@ -2011,13 +2011,12 @@ class Mask2FormerMaskPredictor(nn.Module):
def
forward
(
self
,
outputs
:
torch
.
Tensor
,
pixel_embeddings
:
torch
.
Tensor
,
attention_mask_target_size
:
int
=
None
):
mask_embeddings
=
self
.
mask_embedder
(
outputs
.
transpose
(
0
,
1
))
# Sum up over the channels
# (batch_size, num_queries, num_channels, 1, 1)
mask_embeddings
=
mask_embeddings
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# (batch_size, 1, num_channels, height, width)
pixel_embeddings
=
pixel_embeddings
.
unsqueeze
(
1
)
# (batch_size, num_queries, height, width)
outputs_mask
=
(
mask_embeddings
*
pixel_embeddings
).
sum
(
2
)
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
outputs_mask
=
torch
.
zeros
((
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
outputs_mask
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
attention_mask
=
nn
.
functional
.
interpolate
(
outputs_mask
,
size
=
attention_mask_target_size
,
mode
=
"bilinear"
,
align_corners
=
False
...
...
src/transformers/models/maskformer/modeling_maskformer.py
View file @
ce2d4bc6
...
...
@@ -1789,13 +1789,15 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
class_queries_logits
=
classes
[
-
1
]
# get the masks
mask_embeddings
=
self
.
mask_embedder
(
stacked_transformer_decoder_outputs
)
# sum up over the channels for each embedding
# (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
mask_embeddings
=
mask_embeddings
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# (1, batch_size, 1, num_channels, height, width)
pixel_embeddings
=
pixel_embeddings
.
unsqueeze
(
0
).
unsqueeze
(
2
)
# (num_embeddings, batch_size, num_queries, height, width)
binaries_masks
=
(
mask_embeddings
*
pixel_embeddings
).
sum
(
dim
=
3
)
# Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
num_embeddings
,
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
binaries_masks
=
torch
.
zeros
(
(
num_embeddings
,
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
binaries_masks
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[
None
,
:,
None
,
c
]
masks_queries_logits
=
binaries_masks
[
-
1
]
# go til [:-1] because the last one is always used
...
...
@@ -1811,12 +1813,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings
=
self
.
mask_embedder
(
transformer_decoder_hidden_states
)
# sum up over the channels
# (batch_size, num_queries, num_channels, 1, 1)
mask_embeddings
=
mask_embeddings
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# (batch_size, 1, num_channels, height, width)
pixel_embeddings
=
pixel_embeddings
.
unsqueeze
(
1
)
# (batch_size, num_queries, height, width)
masks_queries_logits
=
(
mask_embeddings
*
pixel_embeddings
).
sum
(
dim
=
2
)
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size
,
num_queries
,
num_channels
=
mask_embeddings
.
shape
_
,
_
,
height
,
width
=
pixel_embeddings
.
shape
masks_queries_logits
=
torch
.
zeros
((
batch_size
,
num_queries
,
height
,
width
),
device
=
mask_embeddings
.
device
)
for
c
in
range
(
num_channels
):
masks_queries_logits
+=
mask_embeddings
[...,
c
][...,
None
,
None
]
*
pixel_embeddings
[:,
None
,
c
]
return
class_queries_logits
,
masks_queries_logits
,
auxiliary_logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment