OpenDAS / MMCV / Commits / 0c23eb02

Unverified commit 0c23eb02, authored May 19, 2023 by bdf, committed by GitHub on May 19, 2023.

Sync main with mmcv1.x branch (#2800)

Parent: 59c1418e
Changes: 25 files in total; this page shows 5 changed files with 109 additions and 258 deletions (+109, -258).
Changed files on this page:
  mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp   +56  -225
  setup.py                                          +8   -3
  tests/test_ops/test_box_iou_rotated.py            +29  -14
  tests/test_ops/test_roi_align.py                  +9   -9
  tests/test_ops/test_roiaware_pool3d.py            +7   -7
mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
@@ -9,238 +9,69 @@
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#include "mlu_common_helper.h"

#define MIN(a, b) (((a) < (b)) ? (a) : (b))

void KernelDynamicVoxelize(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *points, void *coors, const float voxel_x, const float voxel_y,
    const float voxel_z, const float coors_x_min, const float coors_y_min,
    const float coors_z_min, const float coors_x_max, const float coors_y_max,
    const float coors_z_max, const int32_t grid_x, const int32_t grid_y,
    const int32_t grid_z, const int32_t num_points,
    const int32_t num_features);

void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                       cnrtQueue_t queue, void *coors, void *point_to_pointidx,
                       void *point_to_voxelidx, const int32_t num_points,
                       const int32_t max_points);

void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                              cnrtQueue_t queue, void *point_to_pointidx,
                              void *point_to_voxelidx, void *coor_to_voxelidx,
                              void *num_points_per_voxel, void *voxel_num,
                              const int32_t max_voxels,
                              const int32_t num_points);

void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                             cnrtQueue_t queue, const void *points,
                             void *temp_coors, void *point_to_voxelidx,
                             void *coor_to_voxelidx, void *voxels, void *coors,
                             const int32_t max_points, const int32_t num_points,
                             const int32_t num_features);

// policy function
static void policyFuncDefault(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
                              const int num_points) {
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->y = MIN((num_points + k_dim->x - 1) / k_dim->x,
                 torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}

// policy function
static void policyFuncCalcPointsPerVoxel(cnrtDim3_t *k_dim,
                                         cnrtFunctionType_t *k_type,
                                         const int num_points) {
  k_dim->x = 1;
  k_dim->y = 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_BLOCK;
}

/*************************************************************************
 * This MACRO contains operations of simple tensor to mlu-tensor.
 * _contiguous, _desc, _impl, _ptr will be automatically generated in
 * this MACRO.
 *************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME)                         \
  auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous(    \
      NAME, NAME.suggest_memory_format());                          \
  MluOpTensorDescriptor NAME##_desc;                                \
  NAME##_desc.set(NAME##_contigous);                                \
  auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
  auto NAME##_ptr = NAME##_impl->cnnlMalloc();
int HardVoxelizeForwardMLUKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  // check datatype
  TORCH_CHECK(points.scalar_type() == at::kFloat,
              "points type should be Float, got ", points.scalar_type(), ".");
  TORCH_CHECK(voxels.scalar_type() == at::kFloat,
              "voxels type should be Float, got ", voxels.scalar_type(), ".");
  TORCH_CHECK(coors.scalar_type() == at::kInt,
              "coors type should be Int, got ", coors.scalar_type(), ".");
  TORCH_CHECK(num_points_per_voxel.scalar_type() == at::kInt,
              "num_points_per_voxel type should be Int, got ",
              num_points_per_voxel.scalar_type(), ".");

  // check shape
  TORCH_CHECK(points.dim() == 2, "points should be a 2d tensor, got ",
              points.dim(), "D.");
  TORCH_CHECK(voxels.dim() == 3, "voxels should be a 3d tensor, got ",
              voxels.dim(), "D.");
  TORCH_CHECK(coors.dim() == 2, "coors should be a 2d tensor, got ",
              coors.dim(), "D.");
  TORCH_CHECK(num_points_per_voxel.dim() == 1,
              "num_points_per_voxel should be a 1d tensor, got ",
              num_points_per_voxel.dim(), "D.");

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  TORCH_CHECK(points.size(0) == num_points,
              "the 1st dimension of points should be num_points, got ",
              points.size(0), ".");
  TORCH_CHECK(points.size(1) == num_features,
              "the 2nd dimension of points should be num_features, got ",
              points.size(1), ".");
  TORCH_CHECK(voxels.size(0) == max_voxels,
              "the 1st dimension of voxels should be max_voxels, got ",
              voxels.size(0), ".");
  TORCH_CHECK(voxels.size(1) == max_points,
              "the 2nd dimension of voxels should be max_points, got ",
              voxels.size(1), ".");
  TORCH_CHECK(voxels.size(2) == num_features,
              "the 3rd dimension of voxels should be num_features, got ",
              voxels.size(2), ".");
  TORCH_CHECK(coors.size(0) == max_voxels,
              "the 1st dimension of coors should be max_voxels, got ",
              coors.size(0), ".");
  TORCH_CHECK(coors.size(1) == 3,
              "the 2nd dimension of coors should be 3, got ", coors.size(1),
              ".");
  TORCH_CHECK(num_points_per_voxel.size(0) == max_voxels,
              "the 1st dimension of num_points_per_voxel should be "
              "max_voxels, got ", num_points_per_voxel.size(0), ".");

  // large tensor check
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(points.numel() < max_input_size,
              "points element num should be less than 2^31, got ",
              points.numel(), ".");
  TORCH_CHECK(voxels.numel() < max_input_size,
              "voxels element num should be less than 2^31, got ",
              voxels.numel(), ".");
  TORCH_CHECK(coors.numel() < max_input_size,
              "coors element num should be less than 2^31, got ",
              coors.numel(), ".");

  // check zero element
  if (max_points == 0 || max_voxels == 0) {
    return 0;
  }
  // get compute queue
  auto queue = torch_mlu::getCurQueue();

  // get ptr of tensors
  auto points_ = points.contiguous();
  auto points_impl = torch_mlu::getMluTensorImpl(points_);
  auto points_ptr = points_impl->cnnlMalloc();
  auto voxels_ = voxels.contiguous();
  auto voxels_impl = torch_mlu::getMluTensorImpl(voxels_);
  auto voxels_ptr = voxels_impl->cnnlMalloc();
  auto coors_ = coors.contiguous();
  auto coors_impl = torch_mlu::getMluTensorImpl(coors_);
  auto coors_ptr = coors_impl->cnnlMalloc();
  auto num_points_per_voxel_ = num_points_per_voxel.contiguous();
  auto num_points_per_voxel_impl =
      torch_mlu::getMluTensorImpl(num_points_per_voxel_);
  auto num_points_per_voxel_ptr = num_points_per_voxel_impl->cnnlMalloc();

  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncDefault(&k_dim, &k_type, num_points);

  // 1. link point to corresponding voxel coors
  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];
  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  auto temp_coors =
      at::zeros({NDim, num_points}, points.options().dtype(at::kInt))
          .contiguous();
  auto temp_coors_impl = torch_mlu::getMluTensorImpl(temp_coors);
  auto temp_coors_ptr = temp_coors_impl->cnnlMalloc();

  KernelDynamicVoxelize(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
                        voxel_x, voxel_y, voxel_z, coors_x_min, coors_y_min,
                        coors_z_min, coors_x_max, coors_y_max, coors_z_max,
                        grid_x, grid_y, grid_z, num_points, num_features);

  // 2. map point to the idx of the corresponding voxel, find duplicate coor
  auto point_to_pointidx =
      at::zeros({num_points, }, points.options().dtype(at::kInt)).contiguous();
  auto point_to_pointidx_impl = torch_mlu::getMluTensorImpl(point_to_pointidx);
  auto point_to_pointidx_ptr = point_to_pointidx_impl->cnnlMalloc();
  auto point_to_voxelidx =
      at::zeros({num_points, }, points.options().dtype(at::kInt)).contiguous();
  auto point_to_voxelidx_impl = torch_mlu::getMluTensorImpl(point_to_voxelidx);
  auto point_to_voxelidx_ptr = point_to_voxelidx_impl->cnnlMalloc();

  KernelPoint2Voxel(k_dim, k_type, queue, temp_coors_ptr,
                    point_to_pointidx_ptr, point_to_voxelidx_ptr, num_points,
                    max_points);

  // calculate task dimension
  cnrtDim3_t k_dim_calc_points_per_voxel;
  cnrtFunctionType_t k_type_calc_points_per_voxel;
  policyFuncCalcPointsPerVoxel(&k_dim_calc_points_per_voxel,
                               &k_type_calc_points_per_voxel, num_points);

  // 3. determine voxel num and voxel's coor index
  auto coor_to_voxelidx =
      at::zeros({num_points, }, points.options().dtype(at::kInt)).contiguous();
  auto coor_to_voxelidx_impl = torch_mlu::getMluTensorImpl(coor_to_voxelidx);
  auto coor_to_voxelidx_ptr = coor_to_voxelidx_impl->cnnlMalloc();
  auto voxel_num =
      at::zeros({1, }, points.options().dtype(at::kInt)).contiguous();
  auto voxel_num_impl = torch_mlu::getMluTensorImpl(voxel_num);
  auto voxel_num_ptr = voxel_num_impl->cnnlMalloc();

  KernelCalcPointsPerVoxel(k_dim_calc_points_per_voxel,
                           k_type_calc_points_per_voxel, queue,
                           point_to_pointidx_ptr, point_to_voxelidx_ptr,
                           coor_to_voxelidx_ptr, num_points_per_voxel_ptr,
                           voxel_num_ptr, max_voxels, num_points);

  // 4. copy point features and coors of each voxels to voxels
  KernelAssignVoxelsCoors(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
                          point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
                          voxels_ptr, coors_ptr, max_points, num_points,
                          num_features);

  auto voxel_num_cpu = voxel_num.to(at::kCPU);
  std::vector<float> _voxel_size(voxel_size.begin(), voxel_size.end());
  std::vector<float> _coors_range(coors_range.begin(), coors_range.end());
  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
  auto voxel_size_tensor =
      torch::from_blob(_voxel_size.data(), {int64_t(_voxel_size.size())}, opts)
          .clone()
          .to(at::kMLU);
  auto coors_range_tensor =
      torch::from_blob(_coors_range.data(), {int64_t(_coors_range.size())},
                       opts)
          .clone()
          .to(at::kMLU);

  INITIAL_MLU_PARAM_WITH_TENSOR(points);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxels);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors);
  INITIAL_MLU_PARAM_WITH_TENSOR(num_points_per_voxel);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_size_tensor);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors_range_tensor);

  auto voxel_num_tensor =
      at::empty({1}, points.options().dtype(torch::kInt32));
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num_tensor);

  size_t workspace_size;
  auto handle = mluOpGetCurrentHandle();
  mluOpGetVoxelizationWorkspaceSize(
      handle, points_desc.desc(), voxel_size_tensor_desc.desc(),
      coors_range_tensor_desc.desc(), max_points, max_voxels, NDim, true,
      voxels_desc.desc(), coors_desc.desc(), num_points_per_voxel_desc.desc(),
      voxel_num_tensor_desc.desc(), &workspace_size);
  auto workspace_tensor =
      at::empty(workspace_size, points.options().dtype(at::kByte));
  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

  mluOpVoxelization(handle, points_desc.desc(), points_ptr,
                    voxel_size_tensor_desc.desc(), voxel_size_tensor_ptr,
                    coors_range_tensor_desc.desc(), coors_range_tensor_ptr,
                    max_points, max_voxels, NDim, true, workspace_tensor_ptr,
                    workspace_size, voxels_desc.desc(), voxels_ptr,
                    coors_desc.desc(), coors_ptr,
                    num_points_per_voxel_desc.desc(), num_points_per_voxel_ptr,
                    voxel_num_tensor_desc.desc(), voxel_num_tensor_ptr);

  auto voxel_num_cpu = voxel_num_tensor.to(at::kCPU);
  int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
  return voxel_num_int;
}
@@ -254,7 +85,7 @@ int hard_voxelize_forward_mlu(const at::Tensor &points, at::Tensor &voxels,
  return HardVoxelizeForwardMLUKernelLauncher(points, voxels, coors,
                                              num_points_per_voxel, voxel_size,
                                              coors_range, max_points,
                                              max_voxels, NDim);
};

int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
                               at::Tensor &coors,
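For context on how the launcher above is reached from Python, here is a minimal usage sketch of the hard-voxelization op exposed as mmcv.ops.Voxelization. It assumes an MLU-enabled build of mmcv; the voxel size, point-cloud range, and the random point cloud are illustrative values, not taken from this commit.

import torch
from mmcv.ops import Voxelization

# Illustrative parameters; any values satisfying the shape checks in the
# launcher above would do.
voxel_layer = Voxelization(
    voxel_size=[0.5, 0.5, 0.5],
    point_cloud_range=[0.0, 0.0, 0.0, 20.0, 20.0, 4.0],
    max_num_points=5,   # hard voxelization: keep at most 5 points per voxel
    max_voxels=1000)

# (num_points, num_features) point cloud; moving it to an MLU device makes the
# dispatcher pick the MLU kernel shown in this file (assumes an MLU build).
points = torch.rand(100, 4).to('mlu')

# Returns the filled voxels, their integer grid coordinates, and the number of
# points kept in each voxel.
voxels, coors, num_points_per_voxel = voxel_layer(points)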
setup.py
@@ -212,6 +212,7 @@ def get_extensions():
        include_dirs = []
        extra_objects = []
        extra_link_args = []
        is_rocm_pytorch = False
        try:
            from torch.utils.cpp_extension import ROCM_HOME
@@ -325,8 +326,11 @@ def get_extensions():
                './mlu-ops/bangc-ops/kernels/**/*.cpp', recursive=True) + \
                glob.glob(
                    './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
            extra_objects = glob.glob(
                './mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o')
            extra_link_args = [
                '-Wl,--whole-archive',
                './mlu-ops/bangc-ops/kernels/kernel_wrapper/lib/libextops.a',
                '-Wl,--no-whole-archive'
            ]
            extension = MLUExtension
            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
@@ -393,7 +397,8 @@ def get_extensions():
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_objects=extra_objects,
            extra_compile_args=extra_compile_args)
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args)
        extensions.append(ext_ops)
    return extensions
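The MLU build branch above can be read on its own as the following sketch. It mirrors the globbing and linker flags from the diff and only adds an explanation of --whole-archive; the directory layout is an assumption about a checkout with the mlu-ops sources prepared, not something verified here.

import glob

# Recursive globs pick up both the C++ wrappers and the BANG (.mlu) kernels.
mlu_op_srcs = glob.glob(
    './mlu-ops/bangc-ops/kernels/**/*.cpp', recursive=True) + \
    glob.glob('./mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)

# Prebuilt kernel-wrapper objects are linked into the extension directly.
prebuilt_objs = glob.glob('./mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o')

# --whole-archive keeps every object from libextops.a in the extension even if
# no symbol is referenced at compile time, so kernels that are only resolved
# at runtime are not dropped by the linker.
link_args = [
    '-Wl,--whole-archive',
    './mlu-ops/bangc-ops/kernels/kernel_wrapper/lib/libextops.a',
    '-Wl,--no-whole-archive'
]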
tests/test_ops/test_box_iou_rotated.py
@@ -3,11 +3,13 @@ import numpy as np
import pytest
import torch

from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


class TestBoxIoURotated:

    def test_box_iou_rotated_cpu(self):
        from mmcv.ops import box_iou_rotated
        np_boxes1 = np.asarray(
            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
             [7.0, 7.0, 8.0, 8.0, 0.4]],
@@ -44,10 +46,17 @@ class TestBoxIoURotated:
        assert np.allclose(
            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires CUDA support')
    def test_box_iou_rotated_cuda(self):
        from mmcv.ops import box_iou_rotated
    @pytest.mark.parametrize('device', [
        pytest.param(
            'cuda',
            marks=pytest.mark.skipif(
                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
        pytest.param(
            'mlu',
            marks=pytest.mark.skipif(
                not IS_MLU_AVAILABLE, reason='requires MLU support'))
    ])
    def test_box_iou_rotated(self, device):
        np_boxes1 = np.asarray(
            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
             [7.0, 7.0, 8.0, 8.0, 0.4]],
@@ -63,8 +72,8 @@ class TestBoxIoURotated:
        np_expect_ious_aligned = np.asarray(
            [0.3708, 0.4487, 0.3622], dtype=np.float32)

        boxes1 = torch.from_numpy(np_boxes1).cuda()
        boxes2 = torch.from_numpy(np_boxes2).cuda()
        boxes1 = torch.from_numpy(np_boxes1).to(device)
        boxes2 = torch.from_numpy(np_boxes2).to(device)

        # test cw angle definition
        ious = box_iou_rotated(boxes1, boxes2)
@@ -85,7 +94,6 @@ class TestBoxIoURotated:
            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

    def test_box_iou_rotated_iof_cpu(self):
        from mmcv.ops import box_iou_rotated
        np_boxes1 = np.asarray(
            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
             [7.0, 7.0, 8.0, 8.0, 0.4]],
@@ -121,10 +129,17 @@ class TestBoxIoURotated:
        assert np.allclose(
            ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires CUDA support')
    def test_box_iou_rotated_iof_cuda(self):
        from mmcv.ops import box_iou_rotated
    @pytest.mark.parametrize('device', [
        pytest.param(
            'cuda',
            marks=pytest.mark.skipif(
                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
        pytest.param(
            'mlu',
            marks=pytest.mark.skipif(
                not IS_MLU_AVAILABLE, reason='requires MLU support'))
    ])
    def test_box_iou_rotated_iof(self, device):
        np_boxes1 = np.asarray(
            [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
             [7.0, 7.0, 8.0, 8.0, 0.4]],
@@ -140,8 +155,8 @@ class TestBoxIoURotated:
        np_expect_ious_aligned = np.asarray(
            [0.4959, 0.5420, 0.4404], dtype=np.float32)

        boxes1 = torch.from_numpy(np_boxes1).cuda()
        boxes2 = torch.from_numpy(np_boxes2).cuda()
        boxes1 = torch.from_numpy(np_boxes1).to(device)
        boxes2 = torch.from_numpy(np_boxes2).to(device)

        # test cw angle definition
        ious = box_iou_rotated(boxes1, boxes2, mode='iof')
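The refactored tests above exercise box_iou_rotated through a device parameter instead of a CUDA-only path. Below is a condensed, self-contained sketch of the same call pattern: np_boxes1 is taken from the test snippet, while np_boxes2 (elided in the diff) and the 'cpu' device are illustrative assumptions so the snippet runs anywhere.

import numpy as np
import torch
from mmcv.ops import box_iou_rotated

# Boxes are (x_center, y_center, width, height, angle).
np_boxes1 = np.asarray([[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
                        [7.0, 7.0, 8.0, 8.0, 0.4]], dtype=np.float32)
# Made-up second set of boxes; the real test values are not shown in the diff.
np_boxes2 = np.asarray([[0.5, 1.0, 3.0, 4.0, 0.4], [2.0, 2.0, 3.0, 4.0, 0.3],
                        [7.0, 7.5, 8.0, 8.0, 0.5]], dtype=np.float32)

device = 'cpu'  # the tests swap in 'cuda' or 'mlu' via pytest.param
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)

ious = box_iou_rotated(boxes1, boxes2)                        # (3, 3) pairwise IoU
ious_aligned = box_iou_rotated(boxes1, boxes2, aligned=True)  # (3,) element-wise IoU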
tests/test_ops/test_roi_align.py
@@ -93,6 +93,15 @@ def _test_roialign_allclose(device, dtype):
        x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)


@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
            reason='MLU and NPU do not support for 64-bit floating point')),
    torch.half
])
@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
@@ -108,15 +117,6 @@ def _test_roialign_allclose(device, dtype):
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
            reason='MLU and NPU do not support for 64-bit floating point')),
    torch.half
])
def test_roialign(device, dtype):
    # check double only
    if dtype is torch.double:
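For reference, here is a minimal sketch of the RoIAlign call that these parametrized tests ultimately exercise. The feature-map size, the single ROI, and the pooling parameters are illustrative assumptions, not values from the test file.

import torch
from mmcv.ops import RoIAlign

# Small feature map and one ROI; shapes are made-up demo values.
x = torch.rand(1, 1, 4, 4, requires_grad=True)    # (N, C, H, W)
rois = torch.tensor([[0.0, 0.0, 0.0, 3.0, 3.0]])  # (batch_idx, x1, y1, x2, y2)

roi_layer = RoIAlign(output_size=(2, 2), spatial_scale=1.0, sampling_ratio=2)
pooled = roi_layer(x, rois)                        # (num_rois, C, 2, 2)
pooled.sum().backward()                            # gradients reach x, as the tests check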
tests/test_ops/test_roiaware_pool3d.py
@@ -8,6 +8,13 @@ from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


@pytest.mark.parametrize('dtype', [
    torch.float, torch.half,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE, reason='MLU does not support for double'))
])
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
@@ -18,13 +25,6 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float, torch.half,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE, reason='MLU does not support for double'))
])
def test_RoIAwarePool3d(device, dtype):
    roiaware_pool3d_max = RoIAwarePool3d(
        out_size=4, max_pts_per_voxel=128, mode='max')
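Analogously, a small sketch of RoIAwarePool3d being constructed and called as in test_RoIAwarePool3d: the constructor arguments mirror the snippet above, while the ROI, points, and features are made-up shapes. The op is only parametrized over 'cuda' and 'mlu' in the tests, so the sketch assumes such a device is available.

import torch
from mmcv.ops import RoIAwarePool3d

device = 'cuda'  # or 'mlu' on a Cambricon build; assumed available
roiaware_pool3d_max = RoIAwarePool3d(
    out_size=4, max_pts_per_voxel=128, mode='max')

# One 3D ROI as (x, y, z, dx, dy, dz, yaw); points and per-point features are
# illustrative shapes only.
rois = torch.tensor([[2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 0.0]], device=device)
pts = torch.rand(32, 3, device=device) * 4.0        # (num_points, 3)
pts_feature = torch.rand(32, 16, device=device)     # (num_points, num_channels)

pooled = roiaware_pool3d_max(rois, pts, pts_feature)
# pooled: (num_rois, out_size, out_size, out_size, num_channels)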