Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c8064cdb
"...text-generation-inference.git" did not exist on "323546df1da2929e433ce197499ab71621dec51d"
Unverified
Commit
c8064cdb
authored
May 29, 2020
by
NVS Abhilash
Committed by
GitHub
May 29, 2020
Browse files
check for degenerate boxes (fixes #2240) (#2258)
parent
5ba57eae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
1 deletion
+21
-1
test/test_models.py
test/test_models.py
+6
-1
torchvision/models/detection/generalized_rcnn.py
torchvision/models/detection/generalized_rcnn.py
+15
-0
No files found.
test/test_models.py
View file @
c8064cdb
...
...
@@ -158,7 +158,7 @@ class ModelTester(TestCase):
def
_test_detection_model_validation
(
self
,
name
):
set_rng_seed
(
0
)
model
=
models
.
detection
.
__dict__
[
name
](
num_classes
=
50
,
pretrained_backbone
=
False
)
input_shape
=
(
1
,
3
,
300
,
300
)
input_shape
=
(
3
,
300
,
300
)
x
=
[
torch
.
rand
(
input_shape
)]
# validate that targets are present in training
...
...
@@ -173,6 +173,11 @@ class ModelTester(TestCase):
targets
=
[{
'boxes'
:
boxes
}]
self
.
assertRaises
(
ValueError
,
model
,
x
,
targets
=
targets
)
# validate that no degenerate boxes are present
boxes
=
torch
.
tensor
([[
1
,
3
,
1
,
4
],
[
2
,
4
,
3
,
4
]])
targets
=
[{
'boxes'
:
boxes
}]
self
.
assertRaises
(
ValueError
,
model
,
x
,
targets
=
targets
)
def
_test_video_model
(
self
,
name
):
# the default input shape is
# bs * num_channels * clip_len * h *w
...
...
torchvision/models/detection/generalized_rcnn.py
View file @
c8064cdb
...
...
@@ -77,6 +77,21 @@ class GeneralizedRCNN(nn.Module):
original_image_sizes
.
append
((
val
[
0
],
val
[
1
]))
images
,
targets
=
self
.
transform
(
images
,
targets
)
# Check for degenerate boxes
# TODO: Move this to a function
if
targets
is
not
None
:
for
target_idx
,
target
in
enumerate
(
targets
):
boxes
=
target
[
"boxes"
]
degenerate_boxes
=
boxes
[:,
2
:]
<=
boxes
[:,
:
2
]
if
degenerate_boxes
.
any
():
# print the first degenrate box
bb_idx
=
degenerate_boxes
.
any
(
dim
=
1
).
nonzero
().
view
(
-
1
)[
0
]
degen_bb
:
List
[
float
]
=
boxes
[
bb_idx
].
tolist
()
raise
ValueError
(
"All bounding boxes should have positive height and width."
" Found invaid box {} for target at index {}."
.
format
(
degen_bb
,
target_idx
))
features
=
self
.
backbone
(
images
.
tensors
)
if
isinstance
(
features
,
torch
.
Tensor
):
features
=
OrderedDict
([(
'0'
,
features
)])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment