Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
36d5b8b0
Unverified
Commit
36d5b8b0
authored
Aug 08, 2023
by
amyeroberts
Committed by
GitHub
Aug 08, 2023
Browse files
MaskFormer, Mask2Former - replace einsum for tracing (#25297)
* Replace einsum with ops for tracing * Fix comment
parent
dedd1116
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
13 deletions
+29
-13
src/transformers/models/mask2former/modeling_mask2former.py
src/transformers/models/mask2former/modeling_mask2former.py
+10
-5
src/transformers/models/maskformer/modeling_maskformer.py
src/transformers/models/maskformer/modeling_maskformer.py
+15
-4
src/transformers/models/oneformer/modeling_oneformer.py
src/transformers/models/oneformer/modeling_oneformer.py
+4
-4
No files found.
src/transformers/models/mask2former/modeling_mask2former.py
View file @
36d5b8b0
...
...
@@ -359,7 +359,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
`torch.Tensor`: The computed loss between each pairs.
"""
inputs
=
inputs
.
sigmoid
().
flatten
(
1
)
numerator
=
2
*
torch
.
einsum
(
"nc,mc->nm"
,
inputs
,
labels
)
numerator
=
2
*
torch
.
matmul
(
inputs
,
labels
.
T
)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator
=
inputs
.
sum
(
-
1
)[:,
None
]
+
labels
.
sum
(
-
1
)[
None
,
:]
loss
=
1
-
(
numerator
+
1
)
/
(
denominator
+
1
)
...
...
@@ -387,9 +387,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
cross_entropy_loss_pos
=
criterion
(
inputs
,
torch
.
ones_like
(
inputs
))
cross_entropy_loss_neg
=
criterion
(
inputs
,
torch
.
zeros_like
(
inputs
))
loss
=
torch
.
einsum
(
"nc,mc->nm"
,
cross_entropy_loss_pos
,
labels
)
+
torch
.
einsum
(
"nc,mc->nm"
,
cross_entropy_loss_neg
,
(
1
-
labels
)
)
loss
_pos
=
torch
.
matmul
(
cross_entropy_loss_pos
,
labels
.
T
)
loss_neg
=
torch
.
matmul
(
cross_entropy_loss_neg
,
(
1
-
labels
)
.
T
)
loss
=
loss_pos
+
loss_neg
loss
=
loss
/
height_and_width
return
loss
...
...
@@ -2012,7 +2012,12 @@ class Mask2FormerMaskPredictor(nn.Module):
mask_embeddings
=
self
.
mask_embedder
(
outputs
.
transpose
(
0
,
1
))
# Sum up over the channels
outputs_mask
=
torch
.
einsum
(
"bqc, bchw -> bqhw"
,
mask_embeddings
,
pixel_embeddings
)
# (batch_size, num_queries, num_channels, 1, 1)
mask_embeddings
=
mask_embeddings
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# (batch_size, 1, num_channels, height, width)
pixel_embeddings
=
pixel_embeddings
.
unsqueeze
(
1
)
# (batch_size, num_queries, height, width)
outputs_mask
=
(
mask_embeddings
*
pixel_embeddings
).
sum
(
2
)
attention_mask
=
nn
.
functional
.
interpolate
(
outputs_mask
,
size
=
attention_mask_target_size
,
mode
=
"bilinear"
,
align_corners
=
False
...
...
src/transformers/models/maskformer/modeling_maskformer.py
View file @
36d5b8b0
...
...
@@ -355,7 +355,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
`torch.Tensor`: The computed loss between each pairs.
"""
inputs
=
inputs
.
sigmoid
().
flatten
(
1
)
numerator
=
2
*
torch
.
einsum
(
"nc,mc->nm"
,
inputs
,
labels
)
numerator
=
2
*
torch
.
matmul
(
inputs
,
labels
.
T
)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator
=
inputs
.
sum
(
-
1
)[:,
None
]
+
labels
.
sum
(
-
1
)[
None
,
:]
loss
=
1
-
(
numerator
+
1
)
/
(
denominator
+
1
)
...
...
@@ -397,7 +397,7 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
focal_neg
=
(
prob
**
gamma
)
*
cross_entropy_loss_neg
focal_neg
*=
1
-
alpha
loss
=
torch
.
einsum
(
"nc,mc->nm"
,
focal_pos
,
labels
)
+
torch
.
einsum
(
"nc,mc->nm"
,
focal_neg
,
(
1
-
labels
))
loss
=
torch
.
matmul
(
focal_pos
,
labels
.
T
)
+
torch
.
matmul
(
focal_neg
,
(
1
-
labels
)
.
T
)
return
loss
/
height_and_width
...
...
@@ -1712,7 +1712,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings
=
self
.
mask_embedder
(
stacked_transformer_decoder_outputs
)
# sum up over the channels for each embedding
binaries_masks
=
torch
.
einsum
(
"lbqc, bchw -> lbqhw"
,
mask_embeddings
,
pixel_embeddings
)
# (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
mask_embeddings
=
mask_embeddings
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# (1, batch_size, 1, num_channels, height, width)
pixel_embeddings
=
pixel_embeddings
.
unsqueeze
(
0
).
unsqueeze
(
2
)
# (num_embeddings, batch_size, num_queries, height, width)
binaries_masks
=
(
mask_embeddings
*
pixel_embeddings
).
sum
(
dim
=
3
)
masks_queries_logits
=
binaries_masks
[
-
1
]
# go til [:-1] because the last one is always used
for
aux_binary_masks
,
aux_classes
in
zip
(
binaries_masks
[:
-
1
],
classes
[:
-
1
]):
...
...
@@ -1727,7 +1733,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings
=
self
.
mask_embedder
(
transformer_decoder_hidden_states
)
# sum up over the channels
masks_queries_logits
=
torch
.
einsum
(
"bqc, bchw -> bqhw"
,
mask_embeddings
,
pixel_embeddings
)
# (batch_size, num_queries, num_channels, 1, 1)
mask_embeddings
=
mask_embeddings
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# (batch_size, 1, num_channels, height, width)
pixel_embeddings
=
pixel_embeddings
.
unsqueeze
(
1
)
# (batch_size, num_queries, height, width)
masks_queries_logits
=
(
mask_embeddings
*
pixel_embeddings
).
sum
(
dim
=
2
)
return
class_queries_logits
,
masks_queries_logits
,
auxiliary_logits
...
...
src/transformers/models/oneformer/modeling_oneformer.py
View file @
36d5b8b0
...
...
@@ -167,7 +167,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
`torch.Tensor`: The computed loss between each pairs.
"""
inputs
=
inputs
.
sigmoid
().
flatten
(
1
)
numerator
=
2
*
torch
.
einsum
(
"nc,mc->nm"
,
inputs
,
labels
)
numerator
=
2
*
torch
.
matmul
(
inputs
,
labels
.
T
)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator
=
inputs
.
sum
(
-
1
)[:,
None
]
+
labels
.
sum
(
-
1
)[
None
,
:]
loss
=
1
-
(
numerator
+
1
)
/
(
denominator
+
1
)
...
...
@@ -196,9 +196,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
cross_entropy_loss_pos
=
criterion
(
inputs
,
torch
.
ones_like
(
inputs
))
cross_entropy_loss_neg
=
criterion
(
inputs
,
torch
.
zeros_like
(
inputs
))
loss
=
torch
.
einsum
(
"nc,mc->nm"
,
cross_entropy_loss_pos
,
labels
)
+
torch
.
einsum
(
"nc,mc->nm"
,
cross_entropy_loss_neg
,
(
1
-
labels
)
)
loss
_pos
=
torch
.
matmul
(
cross_entropy_loss_pos
,
labels
.
T
)
loss_neg
=
torch
.
matmul
(
cross_entropy_loss_neg
,
(
1
-
labels
)
.
T
)
loss
=
loss_pos
+
loss_neg
loss
=
loss
/
height_and_width
return
loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment