chenpangpang / transformers / Commits / 36d5b8b0

Commit 36d5b8b0 (unverified), authored Aug 08, 2023 by amyeroberts, committed by GitHub on Aug 08, 2023
MaskFormer, Mask2Former - replace einsum for tracing (#25297)
* Replace einsum with ops for tracing
* Fix comment
Parent: dedd1116
Showing 3 changed files with 29 additions and 13 deletions (+29 −13):

src/transformers/models/mask2former/modeling_mask2former.py  +10 −5
src/transformers/models/maskformer/modeling_maskformer.py    +15 −4
src/transformers/models/oneformer/modeling_oneformer.py      +4 −4
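Every hunk below applies the same recipe: a torch.einsum contraction becomes either a plain torch.matmul or an unsqueeze-broadcast-multiply-and-sum, since concrete tensor ops are handled more uniformly by tracing and export backends than einsum's equation strings. As a minimal illustrative sketch (the module and shapes are invented here, not taken from the diff), tracing the post-commit pattern yields a graph of ordinary matmul/arithmetic nodes:

import torch

class PairwiseNumerator(torch.nn.Module):
    # Toy module using the post-commit formulation.
    def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return 2 * torch.matmul(inputs, labels.T)

# torch.jit.trace records the concrete ops that were executed; there is no
# einsum equation string left for an export backend to interpret.
traced = torch.jit.trace(PairwiseNumerator(), (torch.randn(4, 16), torch.randn(7, 16)))
print(traced.graph)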
src/transformers/models/mask2former/modeling_mask2former.py
@@ -359,7 +359,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
         `torch.Tensor`: The computed loss between each pairs.
     """
     inputs = inputs.sigmoid().flatten(1)
-    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
+    numerator = 2 * torch.matmul(inputs, labels.T)
     # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
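For 2D operands the rewrite is an exact identity: torch.einsum("nc,mc->nm", a, b) is torch.matmul(a, b.T). A quick sanity check on invented shapes (3 queries, 5 labels, 64 flattened pixels):

import torch

inputs = torch.randn(3, 64).sigmoid()
labels = torch.randint(0, 2, (5, 64)).float()

old_numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
new_numerator = 2 * torch.matmul(inputs, labels.T)

# Both produce the same (3, 5) pairwise matrix.
assert torch.allclose(old_numerator, new_numerator)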
@@ -387,9 +387,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
     cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
     cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))

-    loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
-        "nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
-    )
+    loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
+    loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
+    loss = loss_pos + loss_neg
     loss = loss / height_and_width
     return loss
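For context, here is the rewritten function as a self-contained sketch; the lines the diff does not show (the criterion and height_and_width definitions) are filled in from the unchanged parts of the file and may differ slightly:

import torch
from torch import nn

def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # inputs: (num_queries, height * width) mask logits
    # labels: (num_labels, height * width) binary target masks
    height_and_width = inputs.shape[1]

    criterion = nn.BCEWithLogitsLoss(reduction="none")
    cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
    cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))

    # Pairwise cost via matmul instead of einsum("nc,mc->nm", ...)
    loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
    loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
    loss = loss_pos + loss_neg
    loss = loss / height_and_width
    return loss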
@@ -2012,7 +2012,12 @@ class Mask2FormerMaskPredictor(nn.Module):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))

         # Sum up over the channels
-        outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+        # (batch_size, num_queries, num_channels, 1, 1)
+        mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
+        # (batch_size, 1, num_channels, height, width)
+        pixel_embeddings = pixel_embeddings.unsqueeze(1)
+        # (batch_size, num_queries, height, width)
+        outputs_mask = (mask_embeddings * pixel_embeddings).sum(2)

         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
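The 4D contraction "bqc, bchw -> bqhw" has no single matmul equivalent, so it becomes an explicit broadcast: unsqueeze both operands to a common (batch, queries, channels, height, width) shape, multiply, and sum out the channel axis. An equivalence check on invented shapes:

import torch

mask_embeddings = torch.randn(2, 5, 8)      # (batch, queries, channels)
pixel_embeddings = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)

old = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
new = (mask_embeddings.unsqueeze(-1).unsqueeze(-1) * pixel_embeddings.unsqueeze(1)).sum(2)

# Both produce the same (2, 5, 4, 4) mask logits.
assert torch.allclose(old, new)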
src/transformers/models/maskformer/modeling_maskformer.py
@@ -355,7 +355,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
         `torch.Tensor`: The computed loss between each pairs.
     """
     inputs = inputs.sigmoid().flatten(1)
-    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
+    numerator = 2 * torch.matmul(inputs, labels.T)
     # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
@@ -397,7 +397,7 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
     focal_neg = (prob**gamma) * cross_entropy_loss_neg
     focal_neg *= 1 - alpha

-    loss = torch.einsum("nc,mc->nm", focal_pos, labels) + torch.einsum("nc,mc->nm", focal_neg, (1 - labels))
+    loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)

     return loss / height_and_width
@@ -1712,7 +1712,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)

             # sum up over the channels for each embedding
-            binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
+            # (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
+            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
+            # (1, batch_size, 1, num_channels, height, width)
+            pixel_embeddings = pixel_embeddings.unsqueeze(0).unsqueeze(2)
+            # (num_embeddings, batch_size, num_queries, height, width)
+            binaries_masks = (mask_embeddings * pixel_embeddings).sum(dim=3)

             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
             for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]):
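Same trick one dimension up for the stacked decoder outputs, summing out the channel axis at dim 3. An equivalence check on invented shapes:

import torch

mask_embeddings = torch.randn(6, 2, 5, 8)   # (num_embeddings, batch, queries, channels)
pixel_embeddings = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)

old = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
new = (
    mask_embeddings.unsqueeze(-1).unsqueeze(-1)   # (l, b, q, c, 1, 1)
    * pixel_embeddings.unsqueeze(0).unsqueeze(2)  # (1, b, 1, c, h, w)
).sum(dim=3)                                      # -> (l, b, q, h, w)

assert torch.allclose(old, new)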
@@ -1727,7 +1733,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
             # sum up over the channels
-            masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+            # (batch_size, num_queries, num_channels, 1, 1)
+            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
+            # (batch_size, 1, num_channels, height, width)
+            pixel_embeddings = pixel_embeddings.unsqueeze(1)
+            # (batch_size, num_queries, height, width)
+            masks_queries_logits = (mask_embeddings * pixel_embeddings).sum(dim=2)

         return class_queries_logits, masks_queries_logits, auxiliary_logits
src/transformers/models/oneformer/modeling_oneformer.py
@@ -167,7 +167,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
         `torch.Tensor`: The computed loss between each pairs.
     """
     inputs = inputs.sigmoid().flatten(1)
-    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
+    numerator = 2 * torch.matmul(inputs, labels.T)
     # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
@@ -196,9 +196,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
     cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
     cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))

-    loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
-        "nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
-    )
+    loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
+    loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
+    loss = loss_pos + loss_neg
     loss = loss / height_and_width
     return loss
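The oneformer changes mirror the mask2former loss rewrites verbatim. Putting the pieces together, the einsum-free dice loss traces end to end; a self-contained sketch with invented shapes:

import torch

def pair_wise_dice_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Post-commit formulation shared by all three models.
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * torch.matmul(inputs, labels.T)
    # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)

traced = torch.jit.trace(
    pair_wise_dice_loss,
    (torch.randn(3, 4, 4), torch.randint(0, 2, (5, 16)).float()),
)
print(traced.code)  # only matmul / sum / arithmetic, no einsum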