Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
659829b6
Unverified
Commit
659829b6
authored
Jul 26, 2023
by
amyeroberts
Committed by
GitHub
Jul 26, 2023
Browse files
MaskFormer - enable return_dict in order to compile (#25052)
* Enable return_dict in order to compile
* Update tests
parent
b914ec98
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
126 additions
and
41 deletions
+126
-41
src/transformers/models/maskformer/modeling_maskformer.py
src/transformers/models/maskformer/modeling_maskformer.py
+56
-20
tests/models/maskformer/test_modeling_maskformer.py
tests/models/maskformer/test_modeling_maskformer.py
+70
-21
No files found.
src/transformers/models/maskformer/modeling_maskformer.py
View file @
659829b6
...
@@ -1254,11 +1254,16 @@ class MaskFormerPixelDecoder(nn.Module):
...
@@ -1254,11 +1254,16 @@ class MaskFormerPixelDecoder(nn.Module):
self
.
fpn
=
MaskFormerFPNModel
(
*
args
,
feature_size
=
feature_size
,
**
kwargs
)
self
.
fpn
=
MaskFormerFPNModel
(
*
args
,
feature_size
=
feature_size
,
**
kwargs
)
self
.
mask_projection
=
nn
.
Conv2d
(
feature_size
,
mask_feature_size
,
kernel_size
=
3
,
padding
=
1
)
self
.
mask_projection
=
nn
.
Conv2d
(
feature_size
,
mask_feature_size
,
kernel_size
=
3
,
padding
=
1
)
def
forward
(
self
,
features
:
List
[
Tensor
],
output_hidden_states
:
bool
=
False
)
->
MaskFormerPixelDecoderOutput
:
def
forward
(
self
,
features
:
List
[
Tensor
],
output_hidden_states
:
bool
=
False
,
return_dict
:
bool
=
True
)
->
MaskFormerPixelDecoderOutput
:
fpn_features
=
self
.
fpn
(
features
)
fpn_features
=
self
.
fpn
(
features
)
# we use the last feature map
# we use the last feature map
last_feature_projected
=
self
.
mask_projection
(
fpn_features
[
-
1
])
last_feature_projected
=
self
.
mask_projection
(
fpn_features
[
-
1
])
if
not
return_dict
:
return
(
last_feature_projected
,
tuple
(
fpn_features
))
if
output_hidden_states
else
(
last_feature_projected
,)
return
MaskFormerPixelDecoderOutput
(
return
MaskFormerPixelDecoderOutput
(
last_hidden_state
=
last_feature_projected
,
hidden_states
=
tuple
(
fpn_features
)
if
output_hidden_states
else
()
last_hidden_state
=
last_feature_projected
,
hidden_states
=
tuple
(
fpn_features
)
if
output_hidden_states
else
()
)
)
...
@@ -1387,9 +1392,20 @@ class MaskFormerPixelLevelModule(nn.Module):
...
@@ -1387,9 +1392,20 @@ class MaskFormerPixelLevelModule(nn.Module):
lateral_widths
=
feature_channels
[:
-
1
],
lateral_widths
=
feature_channels
[:
-
1
],
)
)
def
forward
(
self
,
pixel_values
:
Tensor
,
output_hidden_states
:
bool
=
False
)
->
MaskFormerPixelLevelModuleOutput
:
def
forward
(
self
,
pixel_values
:
Tensor
,
output_hidden_states
:
bool
=
False
,
return_dict
:
bool
=
True
)
->
MaskFormerPixelLevelModuleOutput
:
features
=
self
.
encoder
(
pixel_values
).
feature_maps
features
=
self
.
encoder
(
pixel_values
).
feature_maps
decoder_output
=
self
.
decoder
(
features
,
output_hidden_states
)
decoder_output
=
self
.
decoder
(
features
,
output_hidden_states
,
return_dict
=
return_dict
)
if
not
return_dict
:
last_hidden_state
=
decoder_output
[
0
]
outputs
=
(
features
[
-
1
],
last_hidden_state
)
if
output_hidden_states
:
hidden_states
=
decoder_output
[
1
]
outputs
=
outputs
+
(
tuple
(
features
),)
+
(
hidden_states
,)
return
outputs
return
MaskFormerPixelLevelModuleOutput
(
return
MaskFormerPixelLevelModuleOutput
(
# the last feature is actually the output from the last layer
# the last feature is actually the output from the last layer
encoder_last_hidden_state
=
features
[
-
1
],
encoder_last_hidden_state
=
features
[
-
1
],
...
@@ -1414,7 +1430,11 @@ class MaskFormerTransformerModule(nn.Module):
...
@@ -1414,7 +1430,11 @@ class MaskFormerTransformerModule(nn.Module):
self
.
decoder
=
DetrDecoder
(
config
=
config
.
decoder_config
)
self
.
decoder
=
DetrDecoder
(
config
=
config
.
decoder_config
)
def
forward
(
def
forward
(
self
,
image_features
:
Tensor
,
output_hidden_states
:
bool
=
False
,
output_attentions
:
bool
=
False
self
,
image_features
:
Tensor
,
output_hidden_states
:
bool
=
False
,
output_attentions
:
bool
=
False
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
DetrDecoderOutput
:
)
->
DetrDecoderOutput
:
if
self
.
input_projection
is
not
None
:
if
self
.
input_projection
is
not
None
:
image_features
=
self
.
input_projection
(
image_features
)
image_features
=
self
.
input_projection
(
image_features
)
...
@@ -1438,7 +1458,7 @@ class MaskFormerTransformerModule(nn.Module):
...
@@ -1438,7 +1458,7 @@ class MaskFormerTransformerModule(nn.Module):
query_position_embeddings
=
queries_embeddings
,
query_position_embeddings
=
queries_embeddings
,
output_attentions
=
output_attentions
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
None
,
return_dict
=
return_dict
,
)
)
return
decoder_output
return
decoder_output
...
@@ -1593,9 +1613,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
...
@@ -1593,9 +1613,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
if
pixel_mask
is
None
:
if
pixel_mask
is
None
:
pixel_mask
=
torch
.
ones
((
batch_size
,
height
,
width
),
device
=
pixel_values
.
device
)
pixel_mask
=
torch
.
ones
((
batch_size
,
height
,
width
),
device
=
pixel_values
.
device
)
pixel_level_module_output
=
self
.
pixel_level_module
(
pixel_values
,
output_hidden_states
)
pixel_level_module_output
=
self
.
pixel_level_module
(
image_features
=
pixel_level_module_output
.
encoder_last_hidden_state
pixel_values
,
output_hidden_states
,
return_dict
=
return_dict
pixel_embeddings
=
pixel_level_module_output
.
decoder_last_hidden_state
)
image_features
=
pixel_level_module_output
[
0
]
pixel_embeddings
=
pixel_level_module_output
[
1
]
transformer_module_output
=
self
.
transformer_module
(
image_features
,
output_hidden_states
,
output_attentions
)
transformer_module_output
=
self
.
transformer_module
(
image_features
,
output_hidden_states
,
output_attentions
)
queries
=
transformer_module_output
.
last_hidden_state
queries
=
transformer_module_output
.
last_hidden_state
...
@@ -1606,9 +1628,9 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
...
@@ -1606,9 +1628,9 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
hidden_states
=
None
hidden_states
=
None
if
output_hidden_states
:
if
output_hidden_states
:
encoder_hidden_states
=
pixel_level_module_output
.
encoder_hidden_states
encoder_hidden_states
=
pixel_level_module_output
[
2
]
pixel_decoder_hidden_states
=
pixel_level_module_output
.
decoder_hidden_states
pixel_decoder_hidden_states
=
pixel_level_module_output
[
3
]
transformer_decoder_hidden_states
=
transformer_module_output
.
hidden_states
transformer_decoder_hidden_states
=
transformer_module_output
[
1
]
hidden_states
=
encoder_hidden_states
+
pixel_decoder_hidden_states
+
transformer_decoder_hidden_states
hidden_states
=
encoder_hidden_states
+
pixel_decoder_hidden_states
+
transformer_decoder_hidden_states
output
=
MaskFormerModelOutput
(
output
=
MaskFormerModelOutput
(
...
@@ -1803,13 +1825,25 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
...
@@ -1803,13 +1825,25 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
)
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
outputs
:
MaskFormerModelOutput
=
self
.
model
(
raw_
outputs
=
self
.
model
(
pixel_values
,
pixel_values
,
pixel_mask
,
pixel_mask
,
output_hidden_states
=
output_hidden_states
or
self
.
config
.
use_auxiliary_loss
,
output_hidden_states
=
output_hidden_states
or
self
.
config
.
use_auxiliary_loss
,
return_dict
=
True
,
return_dict
=
return_dict
,
output_attentions
=
output_attentions
,
output_attentions
=
output_attentions
,
)
)
# raw_outputs may optionally be returned as a plain tuple instead of a dict-like output (return_dict=False) so the
# model can be used with torch.compile. For backwards compatibility we read it by position and rebuild a dataclass
# for the rest of the model logic
outputs
=
MaskFormerModelOutput
(
encoder_last_hidden_state
=
raw_outputs
[
0
],
pixel_decoder_last_hidden_state
=
raw_outputs
[
1
],
transformer_decoder_last_hidden_state
=
raw_outputs
[
2
],
encoder_hidden_states
=
raw_outputs
[
3
]
if
output_hidden_states
else
None
,
pixel_decoder_hidden_states
=
raw_outputs
[
4
]
if
output_hidden_states
else
None
,
transformer_decoder_hidden_states
=
raw_outputs
[
5
]
if
output_hidden_states
else
None
,
hidden_states
=
raw_outputs
[
6
]
if
output_hidden_states
else
None
,
attentions
=
raw_outputs
[
-
1
]
if
output_attentions
else
None
,
)
loss
,
loss_dict
,
auxiliary_logits
=
None
,
None
,
None
loss
,
loss_dict
,
auxiliary_logits
=
None
,
None
,
None
...
@@ -1827,16 +1861,18 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
...
@@ -1827,16 +1861,18 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
if
not
output_auxiliary_logits
:
if
not
output_auxiliary_logits
:
auxiliary_logits
=
None
auxiliary_logits
=
None
output
=
MaskFormerForInstanceSegmentationOutput
(
if
not
return_dict
:
output
=
tuple
(
v
for
v
in
(
loss
,
class_queries_logits
,
masks_queries_logits
,
auxiliary_logits
,
*
outputs
.
values
())
if
v
is
not
None
)
return
output
return
MaskFormerForInstanceSegmentationOutput
(
loss
=
loss
,
loss
=
loss
,
**
outputs
,
**
outputs
,
class_queries_logits
=
class_queries_logits
,
class_queries_logits
=
class_queries_logits
,
masks_queries_logits
=
masks_queries_logits
,
masks_queries_logits
=
masks_queries_logits
,
auxiliary_logits
=
auxiliary_logits
,
auxiliary_logits
=
auxiliary_logits
,
)
)
if
not
return_dict
:
output
=
tuple
(
v
for
v
in
output
.
values
())
if
loss
is
not
None
:
output
=
((
loss
))
+
output
return
output
tests/models/maskformer/test_modeling_maskformer.py
View file @
659829b6
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
""" Testing suite for the PyTorch MaskFormer model. """
""" Testing suite for the PyTorch MaskFormer model. """
import
copy
import
inspect
import
inspect
import
unittest
import
unittest
...
@@ -54,6 +55,8 @@ class MaskFormerModelTester:
...
@@ -54,6 +55,8 @@ class MaskFormerModelTester:
max_size
=
32
*
6
,
max_size
=
32
*
6
,
num_labels
=
4
,
num_labels
=
4
,
mask_feature_size
=
32
,
mask_feature_size
=
32
,
num_hidden_layers
=
2
,
num_attention_heads
=
2
,
):
):
self
.
parent
=
parent
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -65,6 +68,9 @@ class MaskFormerModelTester:
...
@@ -65,6 +68,9 @@ class MaskFormerModelTester:
self
.
max_size
=
max_size
self
.
max_size
=
max_size
self
.
num_labels
=
num_labels
self
.
num_labels
=
num_labels
self
.
mask_feature_size
=
mask_feature_size
self
.
mask_feature_size
=
mask_feature_size
# This is passed to the decoder config. We add it to the model tester here for testing
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
pixel_values
=
floats_tensor
([
self
.
batch_size
,
self
.
num_channels
,
self
.
min_size
,
self
.
max_size
]).
to
(
pixel_values
=
floats_tensor
([
self
.
batch_size
,
self
.
num_channels
,
self
.
min_size
,
self
.
max_size
]).
to
(
...
@@ -91,11 +97,12 @@ class MaskFormerModelTester:
...
@@ -91,11 +97,12 @@ class MaskFormerModelTester:
),
),
decoder_config
=
DetrConfig
(
decoder_config
=
DetrConfig
(
decoder_ffn_dim
=
64
,
decoder_ffn_dim
=
64
,
decoder_layers
=
2
,
decoder_layers
=
self
.
num_hidden_layers
,
decoder_attention_heads
=
self
.
num_attention_heads
,
encoder_ffn_dim
=
64
,
encoder_ffn_dim
=
64
,
encoder_layers
=
2
,
encoder_layers
=
self
.
num_hidden_layers
,
encoder_attention_heads
=
self
.
num_attention_heads
,
num_queries
=
self
.
num_queries
,
num_queries
=
self
.
num_queries
,
decoder_attention_heads
=
2
,
d_model
=
self
.
mask_feature_size
,
d_model
=
self
.
mask_feature_size
,
),
),
mask_feature_size
=
self
.
mask_feature_size
,
mask_feature_size
=
self
.
mask_feature_size
,
...
@@ -196,6 +203,27 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
...
@@ -196,6 +203,27 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
self
.
model_tester
=
MaskFormerModelTester
(
self
)
self
.
model_tester
=
MaskFormerModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
MaskFormerConfig
,
has_text_modality
=
False
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
MaskFormerConfig
,
has_text_modality
=
False
)
def
_prepare_for_class
(
self
,
inputs_dict
,
model_class
,
return_labels
=
False
):
inputs_dict
=
copy
.
deepcopy
(
inputs_dict
)
if
return_labels
:
if
model_class
in
[
MaskFormerForInstanceSegmentation
]:
inputs_dict
[
"mask_labels"
]
=
torch
.
zeros
(
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
,
self
.
model_tester
.
min_size
,
self
.
model_tester
.
max_size
,
),
dtype
=
torch
.
float32
,
device
=
torch_device
,
)
inputs_dict
[
"class_labels"
]
=
torch
.
zeros
(
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
),
dtype
=
torch
.
long
,
device
=
torch_device
)
return
inputs_dict
def
test_config
(
self
):
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
self
.
config_tester
.
run_common_tests
()
...
@@ -265,26 +293,47 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
...
@@ -265,26 +293,47 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
self
.
model_tester
.
create_and_check_maskformer_model
(
config
,
**
inputs
,
output_hidden_states
=
True
)
self
.
model_tester
.
create_and_check_maskformer_model
(
config
,
**
inputs
,
output_hidden_states
=
True
)
def
test_attention_outputs
(
self
):
def
test_attention_outputs
(
self
):
config
,
inputs
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
.
return_dict
=
True
for
model_class
in
self
.
all_model_classes
:
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
).
to
(
torch_device
)
inputs_dict
[
"output_attentions"
]
=
True
out
puts
=
model
(
**
inputs
,
output_attentions
=
True
)
in
puts
_dict
[
"output_hidden_states"
]
=
False
self
.
assertTrue
(
outputs
.
attentions
is
not
None
)
config
.
return_dict
=
True
model
=
model_class
(
config
)
def
test_training
(
self
):
model
.
to
(
torch_device
)
if
not
self
.
model_tester
.
is_training
:
model
.
eval
()
return
with
torch
.
no_grad
():
# only MaskFormerForInstanceSegmentation has the loss
outputs
=
model
(
**
self
.
_prepare_for_class
(
inputs_dict
,
model_class
))
model_class
=
self
.
all_model_classes
[
1
]
attentions
=
outputs
.
attentions
config
,
pixel_values
,
pixel_mask
,
mask_labels
,
class_labels
=
self
.
model_tester
.
prepare_config_and_inputs
(
)
self
.
assertEqual
(
len
(
attentions
),
self
.
model_tester
.
num_hidden_layers
)
# Check that output_attentions also work using config
del
inputs_dict
[
"output_attentions"
]
config
.
output_attentions
=
True
model
=
model_class
(
config
)
model
=
model_class
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
train
()
model
.
eval
()
with
torch
.
no_grad
():
outputs
=
model
(
**
self
.
_prepare_for_class
(
inputs_dict
,
model_class
))
attentions
=
outputs
.
attentions
self
.
assertEqual
(
len
(
attentions
),
self
.
model_tester
.
num_hidden_layers
)
out_len
=
len
(
outputs
)
# Check attention is always last and order is fine
inputs_dict
[
"output_attentions"
]
=
True
inputs_dict
[
"output_hidden_states"
]
=
True
model
=
model_class
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
outputs
=
model
(
**
self
.
_prepare_for_class
(
inputs_dict
,
model_class
))
# encoder_hidden_states, pixel_decoder_hidden_states, transformer_decoder_hidden_states, hidden_states
added_hidden_states
=
4
self
.
assertEqual
(
out_len
+
added_hidden_states
,
len
(
outputs
))
loss
=
model
(
pixel_values
,
mask_labels
=
mask_labels
,
class_labels
=
class_labels
).
los
s
self_attentions
=
outputs
.
attention
s
loss
.
backward
(
)
self
.
assertEqual
(
len
(
self_attentions
),
self
.
model_tester
.
num_hidden_layers
)
def
test_retain_grad_hidden_states_attentions
(
self
):
def
test_retain_grad_hidden_states_attentions
(
self
):
# only MaskFormerForInstanceSegmentation has the loss
# only MaskFormerForInstanceSegmentation has the loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment