Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7dc8a690
Unverified
Commit
7dc8a690
authored
Jun 11, 2021
by
Zhiqiang Wang
Committed by
GitHub
Jun 11, 2021
Browse files
Port test_models_detection_negative_samples.py to pytest (#4045)
parent
e4eded48
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
34 deletions
+38
-34
test/test_models_detection_negative_samples.py
test/test_models_detection_negative_samples.py
+38
-34
No files found.
test/test_models_detection_negative_samples.py
View file @
7dc8a690
...
...
@@ -6,10 +6,11 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro
from
torchvision.models.detection.roi_heads
import
RoIHeads
from
torchvision.models.detection.faster_rcnn
import
FastRCNNPredictor
,
TwoMLPHead
import
unittest
import
pytest
from
_assert_utils
import
assert_equal
class
Test
er
(
unittest
.
TestCase
)
:
class
Test
ModelsDetectionNegativeSamples
:
def
_make_empty_sample
(
self
,
add_masks
=
False
,
add_keypoints
=
False
):
images
=
[
torch
.
rand
((
3
,
100
,
100
),
dtype
=
torch
.
float32
)]
...
...
@@ -48,13 +49,13 @@ class Tester(unittest.TestCase):
labels
,
matched_gt_boxes
=
head
.
assign_targets_to_anchors
(
anchors
,
targets
)
self
.
assert
Equal
(
labels
[
0
].
sum
()
,
0
)
self
.
assert
Equal
(
labels
[
0
].
shape
,
torch
.
Size
([
anchors
[
0
].
shape
[
0
]])
)
self
.
assert
Equal
(
labels
[
0
].
dtype
,
torch
.
float32
)
assert
labels
[
0
].
sum
()
==
0
assert
labels
[
0
].
shape
==
torch
.
Size
([
anchors
[
0
].
shape
[
0
]])
assert
labels
[
0
].
dtype
==
torch
.
float32
self
.
assert
Equal
(
matched_gt_boxes
[
0
].
sum
()
,
0
)
self
.
assert
Equal
(
matched_gt_boxes
[
0
].
shape
,
anchors
[
0
].
shape
)
self
.
assert
Equal
(
matched_gt_boxes
[
0
].
dtype
,
torch
.
float32
)
assert
matched_gt_boxes
[
0
].
sum
()
==
0
assert
matched_gt_boxes
[
0
].
shape
==
anchors
[
0
].
shape
assert
matched_gt_boxes
[
0
].
dtype
==
torch
.
float32
def
test_assign_targets_to_proposals
(
self
):
...
...
@@ -88,25 +89,28 @@ class Tester(unittest.TestCase):
matched_idxs
,
labels
=
roi_heads
.
assign_targets_to_proposals
(
proposals
,
gt_boxes
,
gt_labels
)
self
.
assertEqual
(
matched_idxs
[
0
].
sum
(),
0
)
self
.
assertEqual
(
matched_idxs
[
0
].
shape
,
torch
.
Size
([
proposals
[
0
].
shape
[
0
]]))
self
.
assertEqual
(
matched_idxs
[
0
].
dtype
,
torch
.
int64
)
self
.
assertEqual
(
labels
[
0
].
sum
(),
0
)
self
.
assertEqual
(
labels
[
0
].
shape
,
torch
.
Size
([
proposals
[
0
].
shape
[
0
]]))
self
.
assertEqual
(
labels
[
0
].
dtype
,
torch
.
int64
)
def
test_forward_negative_sample_frcnn
(
self
):
for
name
in
[
"fasterrcnn_resnet50_fpn"
,
"fasterrcnn_mobilenet_v3_large_fpn"
,
"fasterrcnn_mobilenet_v3_large_320_fpn"
]:
model
=
torchvision
.
models
.
detection
.
__dict__
[
name
](
num_classes
=
2
,
min_size
=
100
,
max_size
=
100
)
assert
matched_idxs
[
0
].
sum
()
==
0
assert
matched_idxs
[
0
].
shape
==
torch
.
Size
([
proposals
[
0
].
shape
[
0
]])
assert
matched_idxs
[
0
].
dtype
==
torch
.
int64
assert
labels
[
0
].
sum
()
==
0
assert
labels
[
0
].
shape
==
torch
.
Size
([
proposals
[
0
].
shape
[
0
]])
assert
labels
[
0
].
dtype
==
torch
.
int64
@
pytest
.
mark
.
parametrize
(
'name'
,
[
"fasterrcnn_resnet50_fpn"
,
"fasterrcnn_mobilenet_v3_large_fpn"
,
"fasterrcnn_mobilenet_v3_large_320_fpn"
,
])
def
test_forward_negative_sample_frcnn
(
self
,
name
):
model
=
torchvision
.
models
.
detection
.
__dict__
[
name
](
num_classes
=
2
,
min_size
=
100
,
max_size
=
100
)
images
,
targets
=
self
.
_make_empty_sample
()
loss_dict
=
model
(
images
,
targets
)
images
,
targets
=
self
.
_make_empty_sample
()
loss_dict
=
model
(
images
,
targets
)
self
.
assert
E
qual
(
loss_dict
[
"loss_box_reg"
],
torch
.
tensor
(
0.
))
self
.
assert
E
qual
(
loss_dict
[
"loss_rpn_box_reg"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_box_reg"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_rpn_box_reg"
],
torch
.
tensor
(
0.
))
def
test_forward_negative_sample_mrcnn
(
self
):
model
=
torchvision
.
models
.
detection
.
maskrcnn_resnet50_fpn
(
...
...
@@ -115,9 +119,9 @@ class Tester(unittest.TestCase):
images
,
targets
=
self
.
_make_empty_sample
(
add_masks
=
True
)
loss_dict
=
model
(
images
,
targets
)
self
.
assert
E
qual
(
loss_dict
[
"loss_box_reg"
],
torch
.
tensor
(
0.
))
self
.
assert
E
qual
(
loss_dict
[
"loss_rpn_box_reg"
],
torch
.
tensor
(
0.
))
self
.
assert
E
qual
(
loss_dict
[
"loss_mask"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_box_reg"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_rpn_box_reg"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_mask"
],
torch
.
tensor
(
0.
))
def
test_forward_negative_sample_krcnn
(
self
):
model
=
torchvision
.
models
.
detection
.
keypointrcnn_resnet50_fpn
(
...
...
@@ -126,9 +130,9 @@ class Tester(unittest.TestCase):
images
,
targets
=
self
.
_make_empty_sample
(
add_keypoints
=
True
)
loss_dict
=
model
(
images
,
targets
)
self
.
assert
E
qual
(
loss_dict
[
"loss_box_reg"
],
torch
.
tensor
(
0.
))
self
.
assert
E
qual
(
loss_dict
[
"loss_rpn_box_reg"
],
torch
.
tensor
(
0.
))
self
.
assert
E
qual
(
loss_dict
[
"loss_keypoint"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_box_reg"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_rpn_box_reg"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"loss_keypoint"
],
torch
.
tensor
(
0.
))
def
test_forward_negative_sample_retinanet
(
self
):
model
=
torchvision
.
models
.
detection
.
retinanet_resnet50_fpn
(
...
...
@@ -137,7 +141,7 @@ class Tester(unittest.TestCase):
images
,
targets
=
self
.
_make_empty_sample
()
loss_dict
=
model
(
images
,
targets
)
self
.
assert
E
qual
(
loss_dict
[
"bbox_regression"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"bbox_regression"
],
torch
.
tensor
(
0.
))
def
test_forward_negative_sample_ssd
(
self
):
model
=
torchvision
.
models
.
detection
.
ssd300_vgg16
(
...
...
@@ -146,8 +150,8 @@ class Tester(unittest.TestCase):
images
,
targets
=
self
.
_make_empty_sample
()
loss_dict
=
model
(
images
,
targets
)
self
.
assert
E
qual
(
loss_dict
[
"bbox_regression"
],
torch
.
tensor
(
0.
))
assert
_e
qual
(
loss_dict
[
"bbox_regression"
],
torch
.
tensor
(
0.
))
if
__name__
==
'__main__'
:
unit
test
.
main
()
py
test
.
main
(
[
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment