Unverified commit d28aa8a9, authored by Danielmic, committed by GitHub

[Feature] Add the implementation of diff_iou_rotated with mlu-ops (#2840)

parent 10c8b9e7
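For context: `diff_iou_rotated_2d` / `diff_iou_rotated_3d` compute a differentiable IoU between rotated boxes, and this patch adds a Cambricon mlu-ops backend so the ops also run on MLU devices. A minimal usage sketch, assuming an MMCV build with torch_mlu support (the box values are illustrative):

```python
import torch
from mmcv.ops import diff_iou_rotated_2d

# Rotated boxes are (x, y, w, h, angle); both inputs have shape (B, N, 5).
boxes1 = torch.tensor([[[0.5, 0.5, 1.0, 1.0, 0.0]]], device='mlu')
boxes2 = torch.tensor([[[1.0, 0.5, 1.0, 1.0, 0.0]]], device='mlu')

ious = diff_iou_rotated_2d(boxes1, boxes2)  # shape (B, N), differentiable
print(ious.cpu())  # unit boxes offset by 0.5 overlap with IoU = 1/3
```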
@@ -20,7 +20,7 @@ We implement common ops used in detection, segmentation, etc.
 | Correlation | | √ | | | |
 | Deformable Convolution v1/v2 | √ | √ | | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
-| DiffIoURotated | | √ | | | |
+| DiffIoURotated | | √ | √ | | |
 | DynamicScatter | | √ | √ | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
......
@@ -20,7 +20,7 @@ MMCV provides the operators commonly used in detection, segmentation, and other tasks
 | Correlation | | √ | | | |
 | Deformable Convolution v1/v2 | √ | √ | | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
-| DiffIoURotated | | √ | | | |
+| DiffIoURotated | | √ | √ | | |
 | DynamicScatter | | √ | √ | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
......
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
Tensor num_valid) {
// params check
TORCH_CHECK(vertices.scalar_type() == at::kFloat,
"vertices type should be Float, got ", vertices.scalar_type());
TORCH_CHECK(mask.scalar_type() == at::kBool, "mask should be Bool, got ",
mask.scalar_type());
TORCH_CHECK(num_valid.scalar_type() == at::kInt,
"num_valid type should be Int32, got ", num_valid.scalar_type());
TORCH_CHECK(vertices.size(2) == 24, "vertices.dim(2) should be 24, got ",
vertices.size(2));
TORCH_CHECK(mask.size(2) == 24, "mask.dim(2) should be 24, got ",
mask.size(2));
// zero-element check
if (vertices.numel() == 0) {
return at::empty({0}, num_valid.options().dtype(at::kInt));
}
auto idx = at::empty({vertices.size(0), vertices.size(1), 9},
num_valid.options().dtype(at::kInt));
INITIAL_MLU_PARAM_WITH_TENSOR(vertices);
INITIAL_MLU_PARAM_WITH_TENSOR(mask);
INITIAL_MLU_PARAM_WITH_TENSOR(num_valid);
INITIAL_MLU_PARAM_WITH_TENSOR(idx);
// get compute handle
auto handle = mluOpGetCurrentHandle();
// launch kernel
mluOpDiffIouRotatedSortVerticesForward(
handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
return idx;
}
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MLU,
diff_iou_rotated_sort_vertices_forward_mlu);
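A note on dispatch: `REGISTER_DEVICE_IMPL` binds this kernel to the device-generic `diff_iou_rotated_sort_vertices_forward_impl`, so the existing Python wrapper reaches it automatically when its inputs live on an MLU device. A shape sketch of the contract the checks above enforce (hypothetical sizes, MLU build assumed):

```python
import torch

B, N = 2, 8  # batch size and number of boxes (illustrative)
# Inputs mirroring the TORCH_CHECKs: Float vertices, Bool mask, Int32 num_valid.
vertices = torch.rand(B, N, 24, 2, device='mlu')
mask = torch.zeros(B, N, 24, dtype=torch.bool, device='mlu')
num_valid = torch.zeros(B, N, dtype=torch.int32, device='mlu')
# The kernel returns idx with shape (B, N, 9), dtype Int32.
```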
@@ -34,7 +34,7 @@
   auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
   auto NAME##_ptr = NAME##_impl->cnnlMalloc();

-enum class reduce_t{ SUM = 0, MEAN = 1, MAX = 2 };
+enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };

 inline std::string to_string(reduce_t reduce_type) {
   if (reduce_type == reduce_t::MAX) {
......
@@ -11,16 +11,16 @@
  *************************************************************************/

 #include "mlu_common_helper.h"

-std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
-                                                       const Tensor &coors,
-                                                       const reduce_t reduce_type) {
+std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(
+    const Tensor &feats, const Tensor &coors, const reduce_t reduce_type) {
   // params check
   TORCH_CHECK(feats.scalar_type() == at::kFloat,
               "feats type should be Float, got ", feats.scalar_type());
   TORCH_CHECK(coors.scalar_type() == at::kInt,
               "coors type should be Int32, got ", coors.scalar_type());
   TORCH_CHECK(feats.size(0) == coors.size(0),
-              "feats.dim(0) and coors.dim(0) should be same, got ", feats.size(0), " vs ", coors.size(0));
+              "feats.dim(0) and coors.dim(0) should be same, got ",
+              feats.size(0), " vs ", coors.size(0));

   const int num_input = feats.size(0);
   const int num_feats = feats.size(1);
@@ -49,59 +49,48 @@ std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
   auto handle = mluOpGetCurrentHandle();

   size_t workspace_size;
-  mluOpGetDynamicPointToVoxelForwardWorkspaceSize(handle,
-                                                  feats_desc.desc(),
-                                                  coors_desc.desc(),
-                                                  &workspace_size);
+  mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
+      handle, feats_desc.desc(), coors_desc.desc(), &workspace_size);
   auto workspace_tensor =
       at::empty(workspace_size, feats.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

   // launch kernel
-  mluOpDynamicPointToVoxelForward(handle,
-                                  mlu_reduce_type,
-                                  feats_desc.desc(),
-                                  feats_ptr,
-                                  coors_desc.desc(),
-                                  coors_ptr,
-                                  workspace_tensor_ptr,
-                                  workspace_size,
-                                  reduced_feats_desc.desc(),
-                                  reduced_feats_ptr,
-                                  out_coors_desc.desc(),
-                                  out_coors_ptr,
-                                  coors_map_desc.desc(),
-                                  coors_map_ptr,
-                                  reduce_count_desc.desc(),
-                                  reduce_count_ptr,
-                                  voxel_num_desc.desc(),
-                                  voxel_num_ptr);
+  mluOpDynamicPointToVoxelForward(
+      handle, mlu_reduce_type, feats_desc.desc(), feats_ptr, coors_desc.desc(),
+      coors_ptr, workspace_tensor_ptr, workspace_size,
+      reduced_feats_desc.desc(), reduced_feats_ptr, out_coors_desc.desc(),
+      out_coors_ptr, coors_map_desc.desc(), coors_map_ptr,
+      reduce_count_desc.desc(), reduce_count_ptr, voxel_num_desc.desc(),
+      voxel_num_ptr);

   int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
   TORCH_CHECK(voxel_num_value <= feats.size(0),
-              "voxel_num should be less than or equal to feats_num, got ", voxel_num_value, " vs ", feats.size(0));
-  return {reduced_feats.slice(0, 0, voxel_num_value), out_coors.slice(0, 0, voxel_num_value),
-          coors_map, reduce_count.slice(0, 0, voxel_num_value)};
+              "voxel_num should be less than or equal to feats_num, got ",
+              voxel_num_value, " vs ", feats.size(0));
+  return {reduced_feats.slice(0, 0, voxel_num_value),
+          out_coors.slice(0, 0, voxel_num_value), coors_map,
+          reduce_count.slice(0, 0, voxel_num_value)};
 }

-void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
-                                         const Tensor &grad_reduced_feats,
-                                         const Tensor &feats,
-                                         const Tensor &reduced_feats,
-                                         const Tensor &coors_idx,
-                                         const Tensor &reduce_count,
-                                         const reduce_t reduce_type) {
+void dynamic_point_to_voxel_backward_mlu(
+    Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
+    const Tensor &reduced_feats, const Tensor &coors_idx,
+    const Tensor &reduce_count, const reduce_t reduce_type) {
   // params check
   TORCH_CHECK(grad_reduced_feats.scalar_type() == at::kFloat,
-              "grad_reduced_feats type should be Float, got ", grad_reduced_feats.scalar_type());
+              "grad_reduced_feats type should be Float, got ",
+              grad_reduced_feats.scalar_type());
   TORCH_CHECK(feats.scalar_type() == at::kFloat,
               "feats type should be Float, got ", feats.scalar_type());
   TORCH_CHECK(reduced_feats.scalar_type() == at::kFloat,
-              "reduced_feats type should be Float, got ", reduced_feats.scalar_type());
+              "reduced_feats type should be Float, got ",
+              reduced_feats.scalar_type());
   TORCH_CHECK(coors_idx.scalar_type() == at::kInt,
               "coors_idx type should be Int32, got ", coors_idx.scalar_type());
   TORCH_CHECK(reduce_count.scalar_type() == at::kInt,
-              "reduce_count type should be Int32, got ", reduce_count.scalar_type());
+              "reduce_count type should be Int32, got ",
+              reduce_count.scalar_type());

   const int num_input = feats.size(0);
   const int num_reduced = reduced_feats.size(0);
@@ -114,11 +103,13 @@ void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
   // TODO(miaochen): remove this after mlu-ops supports other mode of reduce.
   TORCH_CHECK(reduce_type == reduce_t::MAX,
-              "only supports max reduce in current version, got ", to_string(reduce_type));
+              "only supports max reduce in current version, got ",
+              to_string(reduce_type));

   int voxel_num_value = reduced_feats.size(0);
   auto opts = torch::TensorOptions().dtype(torch::kInt32);
-  auto voxel_num = torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
+  auto voxel_num =
+      torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
   auto mlu_reduce_type = getMluOpReduceMode(reduce_type);

   INITIAL_MLU_PARAM_WITH_TENSOR(grad_feats);
@@ -134,43 +125,30 @@ void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
   size_t workspace_size;
   mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
-      handle, mlu_reduce_type,
-      grad_feats_desc.desc(),
-      feats_desc.desc(),
-      grad_reduced_feats_desc.desc(),
-      coors_idx_desc.desc(),
-      reduce_count_desc.desc(),
-      voxel_num_desc.desc(),
-      &workspace_size);
+      handle, mlu_reduce_type, grad_feats_desc.desc(), feats_desc.desc(),
+      grad_reduced_feats_desc.desc(), coors_idx_desc.desc(),
+      reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size);
   auto workspace_tensor =
       at::empty(workspace_size, feats.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

   // launch kernel
   mluOpDynamicPointToVoxelBackward(
-      handle, mlu_reduce_type,
-      grad_reduced_feats_desc.desc(),
-      grad_reduced_feats_ptr,
-      feats_desc.desc(), feats_ptr,
-      reduced_feats_desc.desc(), reduced_feats_ptr,
-      coors_idx_desc.desc(), coors_idx_ptr,
-      reduce_count_desc.desc(), reduce_count_ptr,
-      voxel_num_desc.desc(), voxel_num_ptr,
-      workspace_tensor_ptr, workspace_size,
-      grad_feats_desc.desc(), grad_feats_ptr);
+      handle, mlu_reduce_type, grad_reduced_feats_desc.desc(),
+      grad_reduced_feats_ptr, feats_desc.desc(), feats_ptr,
+      reduced_feats_desc.desc(), reduced_feats_ptr, coors_idx_desc.desc(),
+      coors_idx_ptr, reduce_count_desc.desc(), reduce_count_ptr,
+      voxel_num_desc.desc(), voxel_num_ptr, workspace_tensor_ptr,
+      workspace_size, grad_feats_desc.desc(), grad_feats_ptr);
 }

-std::vector<Tensor> dynamic_point_to_voxel_forward_impl(const Tensor &feats,
-                                                        const Tensor &coors,
-                                                        const reduce_t reduce_type);
+std::vector<Tensor> dynamic_point_to_voxel_forward_impl(
+    const Tensor &feats, const Tensor &coors, const reduce_t reduce_type);

-void dynamic_point_to_voxel_backward_impl(Tensor &grad_feats,
-                                          const Tensor &grad_reduced_feats,
-                                          const Tensor &feats,
-                                          const Tensor &reduced_feats,
-                                          const Tensor &coors_idx,
-                                          const Tensor &reduce_count,
-                                          const reduce_t reduce_type);
+void dynamic_point_to_voxel_backward_impl(
+    Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
+    const Tensor &reduced_feats, const Tensor &coors_idx,
+    const Tensor &reduce_count, const reduce_t reduce_type);

 REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MLU,
                      dynamic_point_to_voxel_forward_mlu);
......
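Since the backward kernel above currently supports only `max` reduction (see the TODO), here is a minimal sketch of the public entry point these functions back, assuming an MLU-enabled build (values are illustrative):

```python
import torch
from mmcv.ops import dynamic_scatter

# Five points with four features each, plus an integer voxel coordinate per point.
feats = torch.rand(5, 4, device='mlu')
coors = torch.tensor([[0, 0, 0], [0, 0, 0], [0, 1, 0], [1, 1, 1], [1, 1, 1]],
                     dtype=torch.int32, device='mlu')

# 'max' is the only reduce mode whose MLU backward is supported in this version.
voxel_feats, voxel_coors = dynamic_scatter(feats, coors, 'max')
```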
@@ -4,11 +4,23 @@ import pytest
 import torch

 from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
+if IS_MLU_AVAILABLE:
+    torch.backends.mlu.matmul.allow_tf32 = False

-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_diff_iou_rotated_2d():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_diff_iou_rotated_2d(device):
     np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
                              [0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
                              [0.5, 0.5, 1., 1., .0]]],
@@ -19,17 +31,25 @@ def test_diff_iou_rotated_2d():
                             [1.5, 1.5, 1., 1., .0]]],
                            dtype=np.float32)

-    boxes1 = torch.from_numpy(np_boxes1).cuda()
-    boxes2 = torch.from_numpy(np_boxes2).cuda()
+    boxes1 = torch.from_numpy(np_boxes1).to(device)
+    boxes2 = torch.from_numpy(np_boxes2).to(device)

     np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]])
     ious = diff_iou_rotated_2d(boxes1, boxes2)
     assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)

-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_diff_iou_rotated_3d():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_diff_iou_rotated_3d(device):
     np_boxes1 = np.asarray(
         [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
           [.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
@@ -41,8 +61,8 @@ def test_diff_iou_rotated_3d():
           [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]],
          dtype=np.float32)

-    boxes1 = torch.from_numpy(np_boxes1).cuda()
-    boxes2 = torch.from_numpy(np_boxes2).cuda()
+    boxes1 = torch.from_numpy(np_boxes1).to(device)
+    boxes2 = torch.from_numpy(np_boxes2).to(device)

     np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]])
     ious = diff_iou_rotated_3d(boxes1, boxes2)
......
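The device flags used by these tests come from `mmcv.utils`; a quick way to confirm which parametrized variants will actually run (a sketch, assuming torch_mlu is installed for the MLU case):

```python
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

# A parametrized variant is skipped when its availability flag is False.
print('CUDA available:', IS_CUDA_AVAILABLE)
print('MLU available:', IS_MLU_AVAILABLE)
```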