Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c8064cdb
Unverified
Commit
c8064cdb
authored
May 29, 2020
by
NVS Abhilash
Committed by
GitHub
May 29, 2020
Browse files
check for degenerate boxes (fixes #2240) (#2258)
parent
5ba57eae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
1 deletion
+21
-1
test/test_models.py
test/test_models.py
+6
-1
torchvision/models/detection/generalized_rcnn.py
torchvision/models/detection/generalized_rcnn.py
+15
-0
No files found.
test/test_models.py
View file @
c8064cdb
...
@@ -158,7 +158,7 @@ class ModelTester(TestCase):
...
@@ -158,7 +158,7 @@ class ModelTester(TestCase):
def
_test_detection_model_validation
(
self
,
name
):
def
_test_detection_model_validation
(
self
,
name
):
set_rng_seed
(
0
)
set_rng_seed
(
0
)
model
=
models
.
detection
.
__dict__
[
name
](
num_classes
=
50
,
pretrained_backbone
=
False
)
model
=
models
.
detection
.
__dict__
[
name
](
num_classes
=
50
,
pretrained_backbone
=
False
)
input_shape
=
(
1
,
3
,
300
,
300
)
input_shape
=
(
3
,
300
,
300
)
x
=
[
torch
.
rand
(
input_shape
)]
x
=
[
torch
.
rand
(
input_shape
)]
# validate that targets are present in training
# validate that targets are present in training
...
@@ -173,6 +173,11 @@ class ModelTester(TestCase):
...
@@ -173,6 +173,11 @@ class ModelTester(TestCase):
targets
=
[{
'boxes'
:
boxes
}]
targets
=
[{
'boxes'
:
boxes
}]
self
.
assertRaises
(
ValueError
,
model
,
x
,
targets
=
targets
)
self
.
assertRaises
(
ValueError
,
model
,
x
,
targets
=
targets
)
# validate that no degenerate boxes are present
boxes
=
torch
.
tensor
([[
1
,
3
,
1
,
4
],
[
2
,
4
,
3
,
4
]])
targets
=
[{
'boxes'
:
boxes
}]
self
.
assertRaises
(
ValueError
,
model
,
x
,
targets
=
targets
)
def
_test_video_model
(
self
,
name
):
def
_test_video_model
(
self
,
name
):
# the default input shape is
# the default input shape is
# bs * num_channels * clip_len * h *w
# bs * num_channels * clip_len * h *w
...
...
torchvision/models/detection/generalized_rcnn.py
View file @
c8064cdb
...
@@ -77,6 +77,21 @@ class GeneralizedRCNN(nn.Module):
...
@@ -77,6 +77,21 @@ class GeneralizedRCNN(nn.Module):
original_image_sizes
.
append
((
val
[
0
],
val
[
1
]))
original_image_sizes
.
append
((
val
[
0
],
val
[
1
]))
images
,
targets
=
self
.
transform
(
images
,
targets
)
images
,
targets
=
self
.
transform
(
images
,
targets
)
# Check for degenerate boxes
# TODO: Move this to a function
if
targets
is
not
None
:
for
target_idx
,
target
in
enumerate
(
targets
):
boxes
=
target
[
"boxes"
]
degenerate_boxes
=
boxes
[:,
2
:]
<=
boxes
[:,
:
2
]
if
degenerate_boxes
.
any
():
# print the first degenrate box
bb_idx
=
degenerate_boxes
.
any
(
dim
=
1
).
nonzero
().
view
(
-
1
)[
0
]
degen_bb
:
List
[
float
]
=
boxes
[
bb_idx
].
tolist
()
raise
ValueError
(
"All bounding boxes should have positive height and width."
" Found invaid box {} for target at index {}."
.
format
(
degen_bb
,
target_idx
))
features
=
self
.
backbone
(
images
.
tensors
)
features
=
self
.
backbone
(
images
.
tensors
)
if
isinstance
(
features
,
torch
.
Tensor
):
if
isinstance
(
features
,
torch
.
Tensor
):
features
=
OrderedDict
([(
'0'
,
features
)])
features
=
OrderedDict
([(
'0'
,
features
)])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment