Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
3bb0611a
Commit
3bb0611a
authored
Nov 08, 2022
by
q.yao
Committed by
Zaida Zhou
Nov 23, 2022
Browse files
[Fix] Create Tensor with new_* method to support amp (#2389)
parent
ac470881
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
116 additions
and
184 deletions
+116
-184
mmcv/ops/diff_iou_rotated.py
mmcv/ops/diff_iou_rotated.py
+2
-2
mmcv/ops/group_points.py
mmcv/ops/group_points.py
+3
-3
mmcv/ops/three_interpolate.py
mmcv/ops/three_interpolate.py
+2
-2
tests/test_ops/test_group_points.py
tests/test_ops/test_group_points.py
+90
-161
tests/test_ops/test_three_interpolate.py
tests/test_ops/test_three_interpolate.py
+19
-16
No files found.
mmcv/ops/diff_iou_rotated.py
View file @
3bb0611a
...
...
@@ -235,9 +235,9 @@ def box2corners(box: Tensor) -> Tensor:
"""
B
=
box
.
size
()[
0
]
x
,
y
,
w
,
h
,
alpha
=
box
.
split
([
1
,
1
,
1
,
1
,
1
],
dim
=-
1
)
x4
=
torch
.
FloatT
ensor
([
0.5
,
-
0.5
,
-
0.5
,
0.5
]).
to
(
box
.
device
)
x4
=
box
.
new_t
ensor
([
0.5
,
-
0.5
,
-
0.5
,
0.5
]).
to
(
box
.
device
)
x4
=
x4
*
w
# (B, N, 4)
y4
=
torch
.
FloatT
ensor
([
0.5
,
0.5
,
-
0.5
,
-
0.5
]).
to
(
box
.
device
)
y4
=
box
.
new_t
ensor
([
0.5
,
0.5
,
-
0.5
,
-
0.5
]).
to
(
box
.
device
)
y4
=
y4
*
h
# (B, N, 4)
corners
=
torch
.
stack
([
x4
,
y4
],
dim
=-
1
)
# (B, N, 4, 2)
sin
=
torch
.
sin
(
alpha
)
...
...
mmcv/ops/group_points.py
View file @
3bb0611a
...
...
@@ -233,7 +233,7 @@ class GroupingOperation(Function):
else
:
B
,
nfeatures
,
nsample
=
indices
.
size
()
_
,
C
,
N
=
features
.
size
()
output
=
torch
.
cuda
.
FloatTensor
(
B
,
C
,
nfeatures
,
nsample
)
output
=
features
.
new_zeros
(
B
,
C
,
nfeatures
,
nsample
)
ext_module
.
group_points_forward
(
features
,
...
...
@@ -262,7 +262,7 @@ class GroupingOperation(Function):
idx
,
N
=
ctx
.
for_backwards
B
,
C
,
npoint
,
nsample
=
grad_out
.
size
()
grad_features
=
torch
.
cuda
.
FloatTensor
(
B
,
C
,
N
)
.
zero_
()
grad_features
=
grad_out
.
new_zeros
(
B
,
C
,
N
)
grad_out_data
=
grad_out
.
data
.
contiguous
()
ext_module
.
group_points_backward
(
...
...
@@ -279,7 +279,7 @@ class GroupingOperation(Function):
B
,
N
,
idx
,
features_batch_cnt
,
idx_batch_cnt
=
ctx
.
for_backwards
M
,
C
,
nsample
=
grad_out
.
size
()
grad_features
=
torch
.
cuda
.
FloatTensor
(
N
,
C
).
zero_
(
)
grad_features
=
grad_out
.
new_zeros
(
N
,
C
)
grad_out_data
=
grad_out
.
data
.
contiguous
()
ext_module
.
stack_group_points_backward
(
...
...
mmcv/ops/three_interpolate.py
View file @
3bb0611a
...
...
@@ -38,7 +38,7 @@ class ThreeInterpolate(Function):
B
,
c
,
m
=
features
.
size
()
n
=
indices
.
size
(
1
)
ctx
.
three_interpolate_for_backward
=
(
indices
,
weight
,
m
)
output
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
n
)
output
=
features
.
new_empty
(
B
,
c
,
n
)
ext_module
.
three_interpolate_forward
(
features
,
indices
,
weight
,
output
,
b
=
B
,
c
=
c
,
m
=
m
,
n
=
n
)
...
...
@@ -58,7 +58,7 @@ class ThreeInterpolate(Function):
idx
,
weight
,
m
=
ctx
.
three_interpolate_for_backward
B
,
c
,
n
=
grad_out
.
size
()
grad_features
=
torch
.
cuda
.
FloatTensor
(
B
,
c
,
m
)
.
zero_
()
grad_features
=
grad_out
.
new_zeros
(
B
,
c
,
m
)
grad_out_data
=
grad_out
.
data
.
contiguous
()
ext_module
.
three_interpolate_backward
(
...
...
tests/test_ops/test_group_points.py
View file @
3bb0611a
...
...
@@ -7,7 +7,8 @@ from mmcv.ops import grouping_operation
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
def
test_grouping_points
():
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
torch
.
half
,
torch
.
float
,
torch
.
double
])
def
test_grouping_points
(
dtype
):
idx
=
torch
.
tensor
([[[
0
,
0
,
0
],
[
3
,
3
,
3
],
[
8
,
8
,
8
],
[
0
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
0
]],
[[
0
,
0
,
0
],
[
6
,
6
,
6
],
[
9
,
9
,
9
],
[
0
,
0
,
0
],
[
0
,
0
,
0
],
...
...
@@ -35,51 +36,37 @@ def test_grouping_points():
[
-
0.6646
,
-
0.6870
,
-
0.1125
,
-
0.2224
,
-
0.3445
,
-
1.4049
,
0.4990
,
-
0.7037
,
-
0.9924
,
0.0386
]]]).
cuda
()
]]],
dtype
=
dtype
).
cuda
()
output
=
grouping_operation
(
features
,
idx
)
expected_output
=
torch
.
tensor
([[[[
0.5798
,
0.5798
,
0.5798
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.5798
,
0.5798
,
0.5798
],
[
0.5798
,
0.5798
,
0.5798
],
[
0.5798
,
0.5798
,
0.5798
]],
[[
5.4247
,
5.4247
,
5.4247
],
[
1.4740
,
1.4740
,
1.4740
],
[
2.1581
,
2.1581
,
2.1581
],
[
5.4247
,
5.4247
,
5.4247
],
[
5.4247
,
5.4247
,
5.4247
],
[
5.4247
,
5.4247
,
5.4247
]],
[[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6931
,
-
1.6931
,
-
1.6931
],
[
-
1.6786
,
-
1.6786
,
-
1.6786
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
]]],
[[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
]],
[[
1.1773
,
1.1773
,
1.1773
],
[
6.0865
,
6.0865
,
6.0865
],
[
2.8229
,
2.8229
,
2.8229
],
[
1.1773
,
1.1773
,
1.1773
],
[
1.1773
,
1.1773
,
1.1773
],
[
1.1773
,
1.1773
,
1.1773
]],
[[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
0.4990
,
0.4990
,
0.4990
],
[
0.0386
,
0.0386
,
0.0386
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
]]]]).
cuda
()
expected_output
=
torch
.
tensor
(
[[[[
0.5798
,
0.5798
,
0.5798
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.5798
,
0.5798
,
0.5798
],
[
0.5798
,
0.5798
,
0.5798
],
[
0.5798
,
0.5798
,
0.5798
]],
[[
5.4247
,
5.4247
,
5.4247
],
[
1.4740
,
1.4740
,
1.4740
],
[
2.1581
,
2.1581
,
2.1581
],
[
5.4247
,
5.4247
,
5.4247
],
[
5.4247
,
5.4247
,
5.4247
],
[
5.4247
,
5.4247
,
5.4247
]],
[[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6931
,
-
1.6931
,
-
1.6931
],
[
-
1.6786
,
-
1.6786
,
-
1.6786
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.6266
,
-
1.6266
,
-
1.6266
]]],
[[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.0380
,
-
0.0380
,
-
0.0380
]],
[[
1.1773
,
1.1773
,
1.1773
],
[
6.0865
,
6.0865
,
6.0865
],
[
2.8229
,
2.8229
,
2.8229
],
[
1.1773
,
1.1773
,
1.1773
],
[
1.1773
,
1.1773
,
1.1773
],
[
1.1773
,
1.1773
,
1.1773
]],
[[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
0.4990
,
0.4990
,
0.4990
],
[
0.0386
,
0.0386
,
0.0386
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
],
[
-
0.6646
,
-
0.6646
,
-
0.6646
]]]],
dtype
=
dtype
).
cuda
()
assert
torch
.
allclose
(
output
,
expected_output
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
def
test_stack_grouping_points
():
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
torch
.
half
,
torch
.
float
,
torch
.
double
])
def
test_stack_grouping_points
(
dtype
):
idx
=
torch
.
tensor
([[
0
,
0
,
0
],
[
3
,
3
,
3
],
[
8
,
8
,
8
],
[
1
,
1
,
1
],
[
0
,
0
,
0
],
[
2
,
2
,
2
],
[
0
,
0
,
0
],
[
6
,
6
,
6
],
[
9
,
9
,
9
],
[
0
,
0
,
0
],
[
1
,
1
,
1
],
[
0
,
0
,
0
]]).
int
().
cuda
()
...
...
@@ -106,130 +93,72 @@ def test_stack_grouping_points():
[
-
0.6646
,
-
0.6870
,
-
0.1125
,
-
0.2224
,
-
0.3445
,
-
1.4049
,
0.4990
,
-
0.7037
,
-
0.9924
,
0.0386
]]).
float
().
cuda
()
]],
dtype
=
dtype
).
cuda
()
features_batch_cnt
=
torch
.
tensor
([
3
,
3
]).
int
().
cuda
()
indices_batch_cnt
=
torch
.
tensor
([
6
,
6
]).
int
().
cuda
()
output
=
grouping_operation
(
features
,
idx
,
features_batch_cnt
,
indices_batch_cnt
)
expected_output
=
torch
.
Tensor
([[[
0.5798
,
0.5798
,
0.5798
],
[
-
0.7981
,
-
0.7981
,
-
0.7981
],
[
-
0.9280
,
-
0.9280
,
-
0.9280
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
1.3687
,
1.3687
,
1.3687
],
[
0.9277
,
0.9277
,
0.9277
],
[
-
0.4164
,
-
0.4164
,
-
0.4164
],
[
-
1.8274
,
-
1.8274
,
-
1.8274
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.8414
,
0.8414
,
0.8414
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
5.4247
,
5.4247
,
5.4247
],
[
1.5113
,
1.5113
,
1.5113
],
[
2.3944
,
2.3944
,
2.3944
],
[
1.4740
,
1.4740
,
1.4740
],
[
5.0300
,
5.0300
,
5.0300
],
[
5.1030
,
5.1030
,
5.1030
],
[
1.9360
,
1.9360
,
1.9360
],
[
2.1939
,
2.1939
,
2.1939
],
[
2.1581
,
2.1581
,
2.1581
],
[
3.4666
,
3.4666
,
3.4666
]],
[[
0.5798
,
0.5798
,
0.5798
],
[
-
0.7981
,
-
0.7981
,
-
0.7981
],
[
-
0.9280
,
-
0.9280
,
-
0.9280
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
1.3687
,
1.3687
,
1.3687
],
[
0.9277
,
0.9277
,
0.9277
],
[
-
0.4164
,
-
0.4164
,
-
0.4164
],
[
-
1.8274
,
-
1.8274
,
-
1.8274
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.8414
,
0.8414
,
0.8414
]],
[[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.0281
,
-
1.0281
,
-
1.0281
],
[
-
1.0393
,
-
1.0393
,
-
1.0393
],
[
-
1.6931
,
-
1.6931
,
-
1.6931
],
[
-
1.3982
,
-
1.3982
,
-
1.3982
],
[
-
0.5732
,
-
0.5732
,
-
0.5732
],
[
-
1.0830
,
-
1.0830
,
-
1.0830
],
[
-
1.7561
,
-
1.7561
,
-
1.7561
],
[
-
1.6786
,
-
1.6786
,
-
1.6786
],
[
-
1.6967
,
-
1.6967
,
-
1.6967
]],
[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.1880
,
-
0.1880
,
-
0.1880
],
[
-
1.5724
,
-
1.5724
,
-
1.5724
],
[
0.6905
,
0.6905
,
0.6905
],
[
-
0.3190
,
-
0.3190
,
-
0.3190
],
[
0.7798
,
0.7798
,
0.7798
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
0.9457
,
-
0.9457
,
-
0.9457
],
[
-
0.2942
,
-
0.2942
,
-
0.2942
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.1880
,
-
0.1880
,
-
0.1880
],
[
-
1.5724
,
-
1.5724
,
-
1.5724
],
[
0.6905
,
0.6905
,
0.6905
],
[
-
0.3190
,
-
0.3190
,
-
0.3190
],
[
0.7798
,
0.7798
,
0.7798
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
0.9457
,
-
0.9457
,
-
0.9457
],
[
-
0.2942
,
-
0.2942
,
-
0.2942
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
]],
[[
1.1773
,
1.1773
,
1.1773
],
[
1.5009
,
1.5009
,
1.5009
],
[
2.6399
,
2.6399
,
2.6399
],
[
5.9242
,
5.9242
,
5.9242
],
[
1.0962
,
1.0962
,
1.0962
],
[
2.7346
,
2.7346
,
2.7346
],
[
6.0865
,
6.0865
,
6.0865
],
[
1.5555
,
1.5555
,
1.5555
],
[
4.3303
,
4.3303
,
4.3303
],
[
2.8229
,
2.8229
,
2.8229
]],
[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.1880
,
-
0.1880
,
-
0.1880
],
[
-
1.5724
,
-
1.5724
,
-
1.5724
],
[
0.6905
,
0.6905
,
0.6905
],
[
-
0.3190
,
-
0.3190
,
-
0.3190
],
[
0.7798
,
0.7798
,
0.7798
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
0.9457
,
-
0.9457
,
-
0.9457
],
[
-
0.2942
,
-
0.2942
,
-
0.2942
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
]]]).
cuda
().
float
()
expected_output
=
torch
.
tensor
(
[[[
0.5798
,
0.5798
,
0.5798
],
[
-
0.7981
,
-
0.7981
,
-
0.7981
],
[
-
0.9280
,
-
0.9280
,
-
0.9280
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
1.3687
,
1.3687
,
1.3687
],
[
0.9277
,
0.9277
,
0.9277
],
[
-
0.4164
,
-
0.4164
,
-
0.4164
],
[
-
1.8274
,
-
1.8274
,
-
1.8274
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.8414
,
0.8414
,
0.8414
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
5.4247
,
5.4247
,
5.4247
],
[
1.5113
,
1.5113
,
1.5113
],
[
2.3944
,
2.3944
,
2.3944
],
[
1.4740
,
1.4740
,
1.4740
],
[
5.0300
,
5.0300
,
5.0300
],
[
5.1030
,
5.1030
,
5.1030
],
[
1.9360
,
1.9360
,
1.9360
],
[
2.1939
,
2.1939
,
2.1939
],
[
2.1581
,
2.1581
,
2.1581
],
[
3.4666
,
3.4666
,
3.4666
]],
[[
0.5798
,
0.5798
,
0.5798
],
[
-
0.7981
,
-
0.7981
,
-
0.7981
],
[
-
0.9280
,
-
0.9280
,
-
0.9280
],
[
-
1.3311
,
-
1.3311
,
-
1.3311
],
[
1.3687
,
1.3687
,
1.3687
],
[
0.9277
,
0.9277
,
0.9277
],
[
-
0.4164
,
-
0.4164
,
-
0.4164
],
[
-
1.8274
,
-
1.8274
,
-
1.8274
],
[
0.9268
,
0.9268
,
0.9268
],
[
0.8414
,
0.8414
,
0.8414
]],
[[
-
1.6266
,
-
1.6266
,
-
1.6266
],
[
-
1.0281
,
-
1.0281
,
-
1.0281
],
[
-
1.0393
,
-
1.0393
,
-
1.0393
],
[
-
1.6931
,
-
1.6931
,
-
1.6931
],
[
-
1.3982
,
-
1.3982
,
-
1.3982
],
[
-
0.5732
,
-
0.5732
,
-
0.5732
],
[
-
1.0830
,
-
1.0830
,
-
1.0830
],
[
-
1.7561
,
-
1.7561
,
-
1.7561
],
[
-
1.6786
,
-
1.6786
,
-
1.6786
],
[
-
1.6967
,
-
1.6967
,
-
1.6967
]],
[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.1880
,
-
0.1880
,
-
0.1880
],
[
-
1.5724
,
-
1.5724
,
-
1.5724
],
[
0.6905
,
0.6905
,
0.6905
],
[
-
0.3190
,
-
0.3190
,
-
0.3190
],
[
0.7798
,
0.7798
,
0.7798
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
0.9457
,
-
0.9457
,
-
0.9457
],
[
-
0.2942
,
-
0.2942
,
-
0.2942
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
]],
[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.1880
,
-
0.1880
,
-
0.1880
],
[
-
1.5724
,
-
1.5724
,
-
1.5724
],
[
0.6905
,
0.6905
,
0.6905
],
[
-
0.3190
,
-
0.3190
,
-
0.3190
],
[
0.7798
,
0.7798
,
0.7798
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
0.9457
,
-
0.9457
,
-
0.9457
],
[
-
0.2942
,
-
0.2942
,
-
0.2942
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
]],
[[
1.1773
,
1.1773
,
1.1773
],
[
1.5009
,
1.5009
,
1.5009
],
[
2.6399
,
2.6399
,
2.6399
],
[
5.9242
,
5.9242
,
5.9242
],
[
1.0962
,
1.0962
,
1.0962
],
[
2.7346
,
2.7346
,
2.7346
],
[
6.0865
,
6.0865
,
6.0865
],
[
1.5555
,
1.5555
,
1.5555
],
[
4.3303
,
4.3303
,
4.3303
],
[
2.8229
,
2.8229
,
2.8229
]],
[[
-
0.0380
,
-
0.0380
,
-
0.0380
],
[
-
0.1880
,
-
0.1880
,
-
0.1880
],
[
-
1.5724
,
-
1.5724
,
-
1.5724
],
[
0.6905
,
0.6905
,
0.6905
],
[
-
0.3190
,
-
0.3190
,
-
0.3190
],
[
0.7798
,
0.7798
,
0.7798
],
[
-
0.3693
,
-
0.3693
,
-
0.3693
],
[
-
0.9457
,
-
0.9457
,
-
0.9457
],
[
-
0.2942
,
-
0.2942
,
-
0.2942
],
[
-
1.8527
,
-
1.8527
,
-
1.8527
]]],
dtype
=
dtype
).
cuda
()
assert
torch
.
allclose
(
output
,
expected_output
)
tests/test_ops/test_three_interpolate.py
View file @
3bb0611a
...
...
@@ -7,19 +7,20 @@ from mmcv.ops import three_interpolate
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
def
test_three_interpolate
():
features
=
torch
.
tensor
([[[
2.4350
,
4.7516
,
4.4995
,
2.4350
,
2.4350
,
2.4350
],
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
torch
.
half
,
torch
.
float
,
torch
.
double
])
def
test_three_interpolate
(
dtype
):
features
=
torch
.
tensor
(
[[[
2.4350
,
4.7516
,
4.4995
,
2.4350
,
2.4350
,
2.4350
],
[
3.1236
,
2.6278
,
3.0447
,
3.1236
,
3.1236
,
3.1236
],
[
2.6732
,
2.8677
,
2.6436
,
2.6732
,
2.6732
,
2.6732
],
[
0.0124
,
7.0150
,
7.0199
,
0.0124
,
0.0124
,
0.0124
],
[
0.3207
,
0.0000
,
0.3411
,
0.3207
,
0.3207
,
0.3207
]],
[
0.3207
,
0.0000
,
0.3411
,
0.3207
,
0.3207
,
0.3207
]],
[[
0.0000
,
0.9544
,
2.4532
,
0.0000
,
0.0000
,
0.0000
],
[
0.5346
,
1.9176
,
1.4715
,
0.5346
,
0.5346
,
0.5346
],
[
0.0000
,
0.2744
,
2.0842
,
0.0000
,
0.0000
,
0.0000
],
[
0.3414
,
1.5063
,
1.6209
,
0.3414
,
0.3414
,
0.3414
],
[
0.5814
,
0.0103
,
0.0000
,
0.5814
,
0.5814
,
0.5814
]]]
).
cuda
()
[
0.5814
,
0.0103
,
0.0000
,
0.5814
,
0.5814
,
0.5814
]]],
dtype
=
dtype
).
cuda
()
idx
=
torch
.
tensor
([[[
0
,
1
,
2
],
[
2
,
3
,
4
],
[
2
,
3
,
4
],
[
0
,
1
,
2
],
[
0
,
1
,
2
],
[
0
,
1
,
3
]],
...
...
@@ -37,7 +38,8 @@ def test_three_interpolate():
[
1.0000e+00
,
1.7148e-08
,
1.4070e-08
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
],
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
]]]).
cuda
()
[
3.3333e-01
,
3.3333e-01
,
3.3333e-01
]]],
dtype
=
dtype
).
cuda
()
output
=
three_interpolate
(
features
,
idx
,
weight
)
expected_output
=
torch
.
tensor
([[[
...
...
@@ -70,6 +72,7 @@ def test_three_interpolate():
[
3.8760e-01
,
1.0300e-02
,
8.3569e-09
,
3.8760e-01
,
3.8760e-01
,
1.9723e-01
]]]).
cuda
()
]]],
dtype
=
dtype
).
cuda
()
assert
torch
.
allclose
(
output
,
expected_output
,
1e-4
)
assert
torch
.
allclose
(
output
,
expected_output
,
1e-3
,
1e-4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment