Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3428a7de
Unverified
Commit
3428a7de
authored
Mar 10, 2021
by
Nicolas Hug
Committed by
GitHub
Mar 10, 2021
Browse files
Added test for aligned=True (#3540)
parent
01398088
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
3 deletions
+7
-3
test/test_ops.py
test/test_ops.py
+7
-3
No files found.
test/test_ops.py
View file @
3428a7de
...
...
@@ -54,7 +54,7 @@ class OpTester(object):
class
RoIOpTester
(
OpTester
):
def
_test_forward
(
self
,
device
,
contiguous
,
x_dtype
=
None
,
rois_dtype
=
None
):
def
_test_forward
(
self
,
device
,
contiguous
,
x_dtype
=
None
,
rois_dtype
=
None
,
**
kwargs
):
x_dtype
=
self
.
dtype
if
x_dtype
is
None
else
x_dtype
rois_dtype
=
self
.
dtype
if
rois_dtype
is
None
else
rois_dtype
pool_size
=
5
...
...
@@ -70,11 +70,11 @@ class RoIOpTester(OpTester):
dtype
=
rois_dtype
,
device
=
device
)
pool_h
,
pool_w
=
pool_size
,
pool_size
y
=
self
.
fn
(
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
)
y
=
self
.
fn
(
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
)
# the following should be true whether we're running an autocast test or not.
self
.
assertTrue
(
y
.
dtype
==
x
.
dtype
)
gt_y
=
self
.
expected_fn
(
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
device
=
device
,
dtype
=
self
.
dtype
)
sampling_ratio
=-
1
,
device
=
device
,
dtype
=
self
.
dtype
,
**
kwargs
)
tol
=
1e-3
if
(
x_dtype
is
torch
.
half
or
rois_dtype
is
torch
.
half
)
else
1e-5
self
.
assertTrue
(
torch
.
allclose
(
gt_y
.
to
(
y
.
dtype
),
y
,
rtol
=
tol
,
atol
=
tol
))
...
...
@@ -304,6 +304,10 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
def
_test_boxes_shape
(
self
):
self
.
_helper_boxes_shape
(
ops
.
roi_align
)
def
_test_forward
(
self
,
device
,
contiguous
,
x_dtype
=
None
,
rois_dtype
=
None
,
**
kwargs
):
for
aligned
in
(
True
,
False
):
super
().
_test_forward
(
device
,
contiguous
,
x_dtype
,
rois_dtype
,
aligned
=
aligned
)
class
PSRoIAlignTester
(
RoIOpTester
,
unittest
.
TestCase
):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment