Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
55150bfb
Unverified
Commit
55150bfb
authored
May 21, 2021
by
Nicolas Hug
Committed by
GitHub
May 21, 2021
Browse files
Use torch.testing.assert_close in test_ops.py (#3883)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
ea34cd1e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
53 deletions
+35
-53
test/test_ops.py
test/test_ops.py
+35
-53
No files found.
test/test_ops.py
View file @
55150bfb
from
common_utils
import
needs_cuda
,
cpu_only
from
common_utils
import
needs_cuda
,
cpu_only
from
_assert_utils
import
assert_equal
import
math
import
math
import
unittest
import
unittest
import
pytest
import
pytest
...
@@ -78,7 +79,7 @@ class RoIOpTester(OpTester):
...
@@ -78,7 +79,7 @@ class RoIOpTester(OpTester):
sampling_ratio
=-
1
,
device
=
device
,
dtype
=
self
.
dtype
,
**
kwargs
)
sampling_ratio
=-
1
,
device
=
device
,
dtype
=
self
.
dtype
,
**
kwargs
)
tol
=
1e-3
if
(
x_dtype
is
torch
.
half
or
rois_dtype
is
torch
.
half
)
else
1e-5
tol
=
1e-3
if
(
x_dtype
is
torch
.
half
or
rois_dtype
is
torch
.
half
)
else
1e-5
self
.
assertTrue
(
torch
.
all
close
(
gt_y
.
to
(
y
.
dtype
),
y
,
rtol
=
tol
,
atol
=
tol
)
)
torch
.
testing
.
assert_
close
(
gt_y
.
to
(
y
),
y
,
rtol
=
tol
,
atol
=
tol
)
def
_test_backward
(
self
,
device
,
contiguous
):
def
_test_backward
(
self
,
device
,
contiguous
):
pool_size
=
2
pool_size
=
2
...
@@ -363,7 +364,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
...
@@ -363,7 +364,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
abs_diff
=
torch
.
abs
(
qy
[
diff_idx
].
dequantize
()
-
quantized_float_y
[
diff_idx
].
dequantize
())
abs_diff
=
torch
.
abs
(
qy
[
diff_idx
].
dequantize
()
-
quantized_float_y
[
diff_idx
].
dequantize
())
t_scale
=
torch
.
full_like
(
abs_diff
,
fill_value
=
scale
)
t_scale
=
torch
.
full_like
(
abs_diff
,
fill_value
=
scale
)
self
.
assertTrue
(
torch
.
all
close
(
abs_diff
,
t_scale
,
atol
=
1e-5
)
)
torch
.
testing
.
assert_
close
(
abs_diff
,
t_scale
,
rtol
=
1e-5
,
atol
=
1e-5
)
x
=
torch
.
randint
(
50
,
100
,
size
=
(
2
,
3
,
10
,
10
)).
to
(
dtype
)
x
=
torch
.
randint
(
50
,
100
,
size
=
(
2
,
3
,
10
,
10
)).
to
(
dtype
)
qx
=
torch
.
quantize_per_tensor
(
x
,
scale
=
1
,
zero_point
=
0
,
dtype
=
torch
.
qint8
)
qx
=
torch
.
quantize_per_tensor
(
x
,
scale
=
1
,
zero_point
=
0
,
dtype
=
torch
.
qint8
)
...
@@ -555,7 +556,7 @@ class TestNMS:
...
@@ -555,7 +556,7 @@ class TestNMS:
iou_thres
=
0.2
iou_thres
=
0.2
keep32
=
ops
.
nms
(
boxes
,
scores
,
iou_thres
)
keep32
=
ops
.
nms
(
boxes
,
scores
,
iou_thres
)
keep16
=
ops
.
nms
(
boxes
.
to
(
torch
.
float16
),
scores
.
to
(
torch
.
float16
),
iou_thres
)
keep16
=
ops
.
nms
(
boxes
.
to
(
torch
.
float16
),
scores
.
to
(
torch
.
float16
),
iou_thres
)
assert
torch
.
all
(
torch
.
eq
(
keep32
,
keep16
)
)
assert
_equal
(
keep32
,
keep16
)
@
cpu_only
@
cpu_only
def
test_batched_nms_implementations
(
self
):
def
test_batched_nms_implementations
(
self
):
...
@@ -573,12 +574,13 @@ class TestNMS:
...
@@ -573,12 +574,13 @@ class TestNMS:
keep_vanilla
=
ops
.
boxes
.
_batched_nms_vanilla
(
boxes
,
scores
,
idxs
,
iou_threshold
)
keep_vanilla
=
ops
.
boxes
.
_batched_nms_vanilla
(
boxes
,
scores
,
idxs
,
iou_threshold
)
keep_trick
=
ops
.
boxes
.
_batched_nms_coordinate_trick
(
boxes
,
scores
,
idxs
,
iou_threshold
)
keep_trick
=
ops
.
boxes
.
_batched_nms_coordinate_trick
(
boxes
,
scores
,
idxs
,
iou_threshold
)
err_msg
=
"The vanilla and the trick implementation yield different nms outputs."
torch
.
testing
.
assert_close
(
assert
torch
.
allclose
(
keep_vanilla
,
keep_trick
),
err_msg
keep_vanilla
,
keep_trick
,
msg
=
"The vanilla and the trick implementation yield different nms outputs."
)
# Also make sure an empty tensor is returned if boxes is empty
# Also make sure an empty tensor is returned if boxes is empty
empty
=
torch
.
empty
((
0
,),
dtype
=
torch
.
int64
)
empty
=
torch
.
empty
((
0
,),
dtype
=
torch
.
int64
)
assert
torch
.
all
close
(
empty
,
ops
.
batched_nms
(
empty
,
None
,
None
,
None
))
torch
.
testing
.
assert_
close
(
empty
,
ops
.
batched_nms
(
empty
,
None
,
None
,
None
))
class
DeformConvTester
(
OpTester
,
unittest
.
TestCase
):
class
DeformConvTester
(
OpTester
,
unittest
.
TestCase
):
...
@@ -690,15 +692,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
...
@@ -690,15 +692,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
bias
=
layer
.
bias
.
data
bias
=
layer
.
bias
.
data
expected
=
self
.
expected_fn
(
x
,
weight
,
offset
,
mask
,
bias
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
expected
=
self
.
expected_fn
(
x
,
weight
,
offset
,
mask
,
bias
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
self
.
assertTrue
(
torch
.
allclose
(
res
.
to
(
expected
.
dtype
),
expected
,
rtol
=
tol
,
atol
=
tol
),
torch
.
testing
.
assert_close
(
'
\n
res:
\n
{}
\n
expected:
\n
{}'
.
format
(
res
,
expected
))
res
.
to
(
expected
),
expected
,
rtol
=
tol
,
atol
=
tol
,
msg
=
'
\n
res:
\n
{}
\n
expected:
\n
{}'
.
format
(
res
,
expected
)
)
# no modulation test
# no modulation test
res
=
layer
(
x
,
offset
)
res
=
layer
(
x
,
offset
)
expected
=
self
.
expected_fn
(
x
,
weight
,
offset
,
None
,
bias
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
expected
=
self
.
expected_fn
(
x
,
weight
,
offset
,
None
,
bias
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
self
.
assertTrue
(
torch
.
allclose
(
res
.
to
(
expected
.
dtype
),
expected
,
rtol
=
tol
,
atol
=
tol
),
torch
.
testing
.
assert_close
(
'
\n
res:
\n
{}
\n
expected:
\n
{}'
.
format
(
res
,
expected
))
res
.
to
(
expected
),
expected
,
rtol
=
tol
,
atol
=
tol
,
msg
=
'
\n
res:
\n
{}
\n
expected:
\n
{}'
.
format
(
res
,
expected
)
)
# test for wrong sizes
# test for wrong sizes
with
self
.
assertRaises
(
RuntimeError
):
with
self
.
assertRaises
(
RuntimeError
):
...
@@ -778,7 +782,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
...
@@ -778,7 +782,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
else
:
else
:
self
.
assertTrue
(
init_weight
.
grad
is
not
None
)
self
.
assertTrue
(
init_weight
.
grad
is
not
None
)
res_grads
=
init_weight
.
grad
.
to
(
"cpu"
)
res_grads
=
init_weight
.
grad
.
to
(
"cpu"
)
self
.
assert
Tru
e
(
true_cpu_grads
.
allclose
(
res_grads
)
)
torch
.
testing
.
assert
_clos
e
(
true_cpu_grads
,
res_grads
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
"CUDA unavailable"
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
"CUDA unavailable"
)
def
test_autocast
(
self
):
def
test_autocast
(
self
):
...
@@ -812,14 +816,14 @@ class FrozenBNTester(unittest.TestCase):
...
@@ -812,14 +816,14 @@ class FrozenBNTester(unittest.TestCase):
bn
=
torch
.
nn
.
BatchNorm2d
(
sample_size
[
1
]).
eval
()
bn
=
torch
.
nn
.
BatchNorm2d
(
sample_size
[
1
]).
eval
()
bn
.
load_state_dict
(
state_dict
)
bn
.
load_state_dict
(
state_dict
)
# Difference is expected to fall in an acceptable range
# Difference is expected to fall in an acceptable range
self
.
assertTrue
(
torch
.
all
close
(
fbn
(
x
),
bn
(
x
),
atol
=
1e-6
)
)
torch
.
testing
.
assert_
close
(
fbn
(
x
),
bn
(
x
),
rtol
=
1e-5
,
atol
=
1e-6
)
# Check computation for eps > 0
# Check computation for eps > 0
fbn
=
ops
.
misc
.
FrozenBatchNorm2d
(
sample_size
[
1
],
eps
=
1e-5
)
fbn
=
ops
.
misc
.
FrozenBatchNorm2d
(
sample_size
[
1
],
eps
=
1e-5
)
fbn
.
load_state_dict
(
state_dict
,
strict
=
False
)
fbn
.
load_state_dict
(
state_dict
,
strict
=
False
)
bn
=
torch
.
nn
.
BatchNorm2d
(
sample_size
[
1
],
eps
=
1e-5
).
eval
()
bn
=
torch
.
nn
.
BatchNorm2d
(
sample_size
[
1
],
eps
=
1e-5
).
eval
()
bn
.
load_state_dict
(
state_dict
)
bn
.
load_state_dict
(
state_dict
)
self
.
assertTrue
(
torch
.
all
close
(
fbn
(
x
),
bn
(
x
),
atol
=
1e-6
)
)
torch
.
testing
.
assert_
close
(
fbn
(
x
),
bn
(
x
),
rtol
=
1e-5
,
atol
=
1e-6
)
def
test_frozenbatchnorm2d_n_arg
(
self
):
def
test_frozenbatchnorm2d_n_arg
(
self
):
"""Ensure a warning is thrown when passing `n` kwarg
"""Ensure a warning is thrown when passing `n` kwarg
...
@@ -860,20 +864,10 @@ class BoxTester(unittest.TestCase):
...
@@ -860,20 +864,10 @@ class BoxTester(unittest.TestCase):
exp_xyxy
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
exp_xyxy
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
[
10
,
15
,
30
,
35
],
[
23
,
35
,
93
,
95
]],
dtype
=
torch
.
float
)
[
10
,
15
,
30
,
35
],
[
23
,
35
,
93
,
95
]],
dtype
=
torch
.
float
)
box_same
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"xyxy"
)
assert
exp_xyxy
.
size
()
==
torch
.
Size
([
4
,
4
])
self
.
assertEqual
(
exp_xyxy
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"xyxy"
),
exp_xyxy
)
self
.
assertEqual
(
exp_xyxy
.
dtype
,
box_tensor
.
dtype
)
assert_equal
(
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xywh"
,
out_fmt
=
"xywh"
),
exp_xyxy
)
assert
torch
.
all
(
torch
.
eq
(
box_same
,
exp_xyxy
)).
item
()
assert_equal
(
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"cxcywh"
,
out_fmt
=
"cxcywh"
),
exp_xyxy
)
box_same
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xywh"
,
out_fmt
=
"xywh"
)
self
.
assertEqual
(
exp_xyxy
.
size
(),
torch
.
Size
([
4
,
4
]))
self
.
assertEqual
(
exp_xyxy
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_same
,
exp_xyxy
)).
item
()
box_same
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"cxcywh"
,
out_fmt
=
"cxcywh"
)
self
.
assertEqual
(
exp_xyxy
.
size
(),
torch
.
Size
([
4
,
4
]))
self
.
assertEqual
(
exp_xyxy
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_same
,
exp_xyxy
)).
item
()
def
test_bbox_xyxy_xywh
(
self
):
def
test_bbox_xyxy_xywh
(
self
):
# Simple test convert boxes to xywh and back. Make sure they are same.
# Simple test convert boxes to xywh and back. Make sure they are same.
...
@@ -883,16 +877,13 @@ class BoxTester(unittest.TestCase):
...
@@ -883,16 +877,13 @@ class BoxTester(unittest.TestCase):
exp_xywh
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
exp_xywh
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
[
10
,
15
,
20
,
20
],
[
23
,
35
,
70
,
60
]],
dtype
=
torch
.
float
)
[
10
,
15
,
20
,
20
],
[
23
,
35
,
70
,
60
]],
dtype
=
torch
.
float
)
assert
exp_xywh
.
size
()
==
torch
.
Size
([
4
,
4
])
box_xywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"xywh"
)
box_xywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"xywh"
)
self
.
assertEqual
(
exp_xywh
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
box_xywh
,
exp_xywh
)
self
.
assertEqual
(
exp_xywh
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_xywh
,
exp_xywh
)).
item
()
# Reverse conversion
# Reverse conversion
box_xyxy
=
ops
.
box_convert
(
box_xywh
,
in_fmt
=
"xywh"
,
out_fmt
=
"xyxy"
)
box_xyxy
=
ops
.
box_convert
(
box_xywh
,
in_fmt
=
"xywh"
,
out_fmt
=
"xyxy"
)
self
.
assertEqual
(
box_xyxy
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
box_xyxy
,
box_tensor
)
self
.
assertEqual
(
box_xyxy
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_xyxy
,
box_tensor
)).
item
()
def
test_bbox_xyxy_cxcywh
(
self
):
def
test_bbox_xyxy_cxcywh
(
self
):
# Simple test convert boxes to xywh and back. Make sure they are same.
# Simple test convert boxes to xywh and back. Make sure they are same.
...
@@ -902,16 +893,13 @@ class BoxTester(unittest.TestCase):
...
@@ -902,16 +893,13 @@ class BoxTester(unittest.TestCase):
exp_cxcywh
=
torch
.
tensor
([[
50
,
50
,
100
,
100
],
[
0
,
0
,
0
,
0
],
exp_cxcywh
=
torch
.
tensor
([[
50
,
50
,
100
,
100
],
[
0
,
0
,
0
,
0
],
[
20
,
25
,
20
,
20
],
[
58
,
65
,
70
,
60
]],
dtype
=
torch
.
float
)
[
20
,
25
,
20
,
20
],
[
58
,
65
,
70
,
60
]],
dtype
=
torch
.
float
)
assert
exp_cxcywh
.
size
()
==
torch
.
Size
([
4
,
4
])
box_cxcywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"cxcywh"
)
box_cxcywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"cxcywh"
)
self
.
assertEqual
(
exp_cxcywh
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
box_cxcywh
,
exp_cxcywh
)
self
.
assertEqual
(
exp_cxcywh
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_cxcywh
,
exp_cxcywh
)).
item
()
# Reverse conversion
# Reverse conversion
box_xyxy
=
ops
.
box_convert
(
box_cxcywh
,
in_fmt
=
"cxcywh"
,
out_fmt
=
"xyxy"
)
box_xyxy
=
ops
.
box_convert
(
box_cxcywh
,
in_fmt
=
"cxcywh"
,
out_fmt
=
"xyxy"
)
self
.
assertEqual
(
box_xyxy
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
box_xyxy
,
box_tensor
)
self
.
assertEqual
(
box_xyxy
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_xyxy
,
box_tensor
)).
item
()
def
test_bbox_xywh_cxcywh
(
self
):
def
test_bbox_xywh_cxcywh
(
self
):
box_tensor
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
box_tensor
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
...
@@ -921,16 +909,13 @@ class BoxTester(unittest.TestCase):
...
@@ -921,16 +909,13 @@ class BoxTester(unittest.TestCase):
exp_cxcywh
=
torch
.
tensor
([[
50
,
50
,
100
,
100
],
[
0
,
0
,
0
,
0
],
exp_cxcywh
=
torch
.
tensor
([[
50
,
50
,
100
,
100
],
[
0
,
0
,
0
,
0
],
[
20
,
25
,
20
,
20
],
[
58
,
65
,
70
,
60
]],
dtype
=
torch
.
float
)
[
20
,
25
,
20
,
20
],
[
58
,
65
,
70
,
60
]],
dtype
=
torch
.
float
)
assert
exp_cxcywh
.
size
()
==
torch
.
Size
([
4
,
4
])
box_cxcywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xywh"
,
out_fmt
=
"cxcywh"
)
box_cxcywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xywh"
,
out_fmt
=
"cxcywh"
)
self
.
assertEqual
(
exp_cxcywh
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
box_cxcywh
,
exp_cxcywh
)
self
.
assertEqual
(
exp_cxcywh
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_cxcywh
,
exp_cxcywh
)).
item
()
# Reverse conversion
# Reverse conversion
box_xywh
=
ops
.
box_convert
(
box_cxcywh
,
in_fmt
=
"cxcywh"
,
out_fmt
=
"xywh"
)
box_xywh
=
ops
.
box_convert
(
box_cxcywh
,
in_fmt
=
"cxcywh"
,
out_fmt
=
"xywh"
)
self
.
assertEqual
(
box_xywh
.
size
(),
torch
.
Size
([
4
,
4
]))
assert_equal
(
box_xywh
,
box_tensor
)
self
.
assertEqual
(
box_xywh
.
dtype
,
box_tensor
.
dtype
)
assert
torch
.
all
(
torch
.
eq
(
box_xywh
,
box_tensor
)).
item
()
def
test_bbox_invalid
(
self
):
def
test_bbox_invalid
(
self
):
box_tensor
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
box_tensor
=
torch
.
tensor
([[
0
,
0
,
100
,
100
],
[
0
,
0
,
0
,
0
],
...
@@ -951,19 +936,18 @@ class BoxTester(unittest.TestCase):
...
@@ -951,19 +936,18 @@ class BoxTester(unittest.TestCase):
box_xywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"xywh"
)
box_xywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"xywh"
)
scripted_xywh
=
scripted_fn
(
box_tensor
,
'xyxy'
,
'xywh'
)
scripted_xywh
=
scripted_fn
(
box_tensor
,
'xyxy'
,
'xywh'
)
self
.
assert
True
(
(
scripted_xywh
-
box_xywh
).
abs
().
max
()
<
TOLERANCE
)
torch
.
testing
.
assert
_close
(
scripted_xywh
,
box_xywh
,
rtol
=
0.0
,
atol
=
TOLERANCE
)
box_cxcywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"cxcywh"
)
box_cxcywh
=
ops
.
box_convert
(
box_tensor
,
in_fmt
=
"xyxy"
,
out_fmt
=
"cxcywh"
)
scripted_cxcywh
=
scripted_fn
(
box_tensor
,
'xyxy'
,
'cxcywh'
)
scripted_cxcywh
=
scripted_fn
(
box_tensor
,
'xyxy'
,
'cxcywh'
)
self
.
assert
True
(
(
scripted_cxcywh
-
box_cxcywh
).
abs
().
max
()
<
TOLERANCE
)
torch
.
testing
.
assert
_close
(
scripted_cxcywh
,
box_cxcywh
,
rtol
=
0.0
,
atol
=
TOLERANCE
)
class
BoxAreaTester
(
unittest
.
TestCase
):
class
BoxAreaTester
(
unittest
.
TestCase
):
def
test_box_area
(
self
):
def
test_box_area
(
self
):
def
area_check
(
box
,
expected
,
tolerance
=
1e-4
):
def
area_check
(
box
,
expected
,
tolerance
=
1e-4
):
out
=
ops
.
box_area
(
box
)
out
=
ops
.
box_area
(
box
)
assert
out
.
size
()
==
expected
.
size
()
torch
.
testing
.
assert_close
(
out
,
expected
,
rtol
=
0.0
,
check_dtype
=
False
,
atol
=
tolerance
)
assert
((
out
-
expected
).
abs
().
max
()
<
tolerance
).
item
()
# Check for int boxes
# Check for int boxes
for
dtype
in
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
]:
for
dtype
in
[
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
]:
...
@@ -991,8 +975,7 @@ class BoxIouTester(unittest.TestCase):
...
@@ -991,8 +975,7 @@ class BoxIouTester(unittest.TestCase):
def
test_iou
(
self
):
def
test_iou
(
self
):
def
iou_check
(
box
,
expected
,
tolerance
=
1e-4
):
def
iou_check
(
box
,
expected
,
tolerance
=
1e-4
):
out
=
ops
.
box_iou
(
box
,
box
)
out
=
ops
.
box_iou
(
box
,
box
)
assert
out
.
size
()
==
expected
.
size
()
torch
.
testing
.
assert_close
(
out
,
expected
,
rtol
=
0.0
,
check_dtype
=
False
,
atol
=
tolerance
)
assert
((
out
-
expected
).
abs
().
max
()
<
tolerance
).
item
()
# Check for int boxes
# Check for int boxes
for
dtype
in
[
torch
.
int16
,
torch
.
int32
,
torch
.
int64
]:
for
dtype
in
[
torch
.
int16
,
torch
.
int32
,
torch
.
int64
]:
...
@@ -1013,8 +996,7 @@ class GenBoxIouTester(unittest.TestCase):
...
@@ -1013,8 +996,7 @@ class GenBoxIouTester(unittest.TestCase):
def
test_gen_iou
(
self
):
def
test_gen_iou
(
self
):
def
gen_iou_check
(
box
,
expected
,
tolerance
=
1e-4
):
def
gen_iou_check
(
box
,
expected
,
tolerance
=
1e-4
):
out
=
ops
.
generalized_box_iou
(
box
,
box
)
out
=
ops
.
generalized_box_iou
(
box
,
box
)
assert
out
.
size
()
==
expected
.
size
()
torch
.
testing
.
assert_close
(
out
,
expected
,
rtol
=
0.0
,
check_dtype
=
False
,
atol
=
tolerance
)
assert
((
out
-
expected
).
abs
().
max
()
<
tolerance
).
item
()
# Check for int boxes
# Check for int boxes
for
dtype
in
[
torch
.
int16
,
torch
.
int32
,
torch
.
int64
]:
for
dtype
in
[
torch
.
int16
,
torch
.
int32
,
torch
.
int64
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment