Commit cdfbdc0b authored by liuduanhui's avatar liuduanhui Committed by Zaida Zhou
Browse files

[Feature] Support ThreeNN with cambricon MLU backend (#2215)

parent b091e4d2
......@@ -53,7 +53,7 @@ We implement common ops used in detection, segmentation, etc.
| Sparse Convolution | | √ | | |
| Synchronized BatchNorm | | √ | | |
| ThreeInterpolate | | √ | | |
| ThreeNN | | √ | | |
| ThreeNN | | √ | | |
| TINShift | | √ | √ | |
| UpFirDn2d | | √ | | |
| Voxelization | √ | √ | | |
......
......@@ -53,7 +53,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| Sparse Convolution | | √ | | |
| Synchronized BatchNorm | | √ | | |
| ThreeInterpolate | | √ | | |
| ThreeNN | | √ | | |
| ThreeNN | | √ | | |
| TINShift | | √ | √ | |
| UpFirDn2d | | √ | | |
| Voxelization | √ | √ | | |
......
This diff is collapsed.
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *unknown, const void *known, void *dist2,
int *idx, const int b, const int n, const int m);
void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
// Check dtype.
TORCH_CHECK(
unknown.scalar_type() == at::kFloat || unknown.scalar_type() == at::kHalf,
"unknown type should be Float or Half, got ", unknown.scalar_type(), ".");
TORCH_CHECK(unknown.scalar_type() == known.scalar_type(),
"known should have the same type as unknown.");
TORCH_CHECK(unknown.scalar_type() == dist2.scalar_type(),
"dist2 should have the same type as unknown.");
TORCH_CHECK(idx.scalar_type() == at::kInt, "idx type should be Int.");
// Check shape.
TORCH_CHECK(unknown.dim() == 3, "unknown should be 3d tensor, got ",
unknown.dim(), "D.");
TORCH_CHECK(known.dim() == 3, "known should be 3d tensor, got ", known.dim(),
"D.");
TORCH_CHECK(unknown.size(0) == known.size(0),
"known.dim0 should be equal to unknown.dim0, got ", known.size(0),
".");
TORCH_CHECK(unknown.size(2) == 3, "unknown dim2 should be 3, got ",
unknown.size(2), ".");
TORCH_CHECK(known.size(2) == 3, "known dim2 should be 3, got ", known.size(2),
".");
// zero element check
TORCH_CHECK(unknown.numel() > 0,
"unknown.numel should greater than zero, got ", unknown.numel(),
".");
if (known.numel() == 0) {
// return if known zero element
return;
}
// large tensor check
const size_t max_input_num = 2147483648; // 2^31, 2G num
TORCH_CHECK(unknown.numel() < max_input_num,
"unknown.numel() should be less than 2147483648, got ",
unknown.numel(), ".");
TORCH_CHECK(known.numel() < max_input_num,
"known.numel() should be less than 2147483648, got ",
known.numel(), ".");
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto unknown_impl = torch_mlu::getMluTensorImpl(unknown);
auto unknown_ptr = unknown_impl->cnnlMalloc();
auto known_t = known.permute({0, 2, 1}).contiguous();
auto known_impl = torch_mlu::getMluTensorImpl(known_t);
auto known_ptr = known_impl->cnnlMalloc();
auto dist2_impl = torch_mlu::getMluTensorImpl(dist2);
auto dist2_ptr = dist2_impl->cnnlMalloc();
auto idx_impl = torch_mlu::getMluTensorImpl(idx);
auto idx_ptr = idx_impl->cnnlMalloc();
cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
cnrtDim3_t k_dim;
k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim.z = 1;
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(unknown.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelThreeNNForward<<<" << k_dim.x << ", "
<< k_dim.y << ", " << k_dim.z << ">>>.";
KernelThreeNNForward(k_dim, k_type, queue, data_type, unknown_ptr, known_ptr,
dist2_ptr, (int *)idx_ptr, b, n, m);
}
void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
ThreeNNMLUKernelLauncher(b, n, m, unknown, known, dist2, idx);
}
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx);
REGISTER_DEVICE_IMPL(three_nn_forward_impl, MLU, three_nn_forward_mlu);
......@@ -34,8 +34,8 @@ class ThreeNN(Function):
B, N, _ = target.size()
m = source.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
dist2 = torch.FloatTensor(B, N, 3).to(target.device)
idx = torch.IntTensor(B, N, 3).to(target.device)
ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
if torch.__version__ != 'parrots':
......
......@@ -3,70 +3,61 @@ import pytest
import torch
from mmcv.ops import three_nn
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_three_nn():
known = torch.tensor([[[-1.8373, 3.5605,
-0.7867], [0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622],
[-1.8373, 3.5605, -0.7867],
[-1.8373, 3.5605, -0.7867]],
[[-1.3399, 1.9991, -0.3698],
[-0.0799, 0.9698,
-0.8457], [0.0858, 2.4721, -0.1928],
[-1.3399, 1.9991, -0.3698],
[-1.3399, 1.9991, -0.3698]]]).cuda()
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_three_nn(device):
known = torch.tensor(
[[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622], [-1.8373, 3.5605, -0.7867],
[-1.8373, 3.5605, -0.7867]],
[[-1.3399, 1.9991, -0.3698], [-0.0799, 0.9698, -0.8457],
[0.0858, 2.4721, -0.1928], [-1.3399, 1.9991, -0.3698],
[-1.3399, 1.9991, -0.3698]]],
device=device)
unknown = torch.tensor([[[-1.8373, 3.5605, -0.7867],
[0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622],
[-1.5237, 2.3976, -0.8097],
[-0.0722, 3.4017, -0.2880],
[0.5198, 3.0661, -0.4605],
[-2.0185, 3.5019, -0.3236],
[0.5098, 3.1020, 0.5799],
[-1.6137, 3.8443, -0.5269],
[0.7341, 2.9626, -0.3189]],
[[-1.3399, 1.9991, -0.3698],
[-0.0799, 0.9698, -0.8457],
[0.0858, 2.4721, -0.1928],
[-0.9022, 1.6560, -1.3090],
[0.1156, 1.6901, -0.4366],
[-0.6477, 2.3576, -0.1563],
[-0.8482, 1.1466, -1.2704],
[-0.8753, 2.0845, -0.3460],
[-0.5621, 1.4233, -1.2858],
[-0.5883, 1.3114, -1.2899]]]).cuda()
unknown = torch.tensor(
[[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314],
[-0.6503, 3.6637, -1.0622], [-1.5237, 2.3976, -0.8097],
[-0.0722, 3.4017, -0.2880], [0.5198, 3.0661, -0.4605],
[-2.0185, 3.5019, -0.3236], [0.5098, 3.1020, 0.5799],
[-1.6137, 3.8443, -0.5269], [0.7341, 2.9626, -0.3189]],
[[-1.3399, 1.9991, -0.3698], [-0.0799, 0.9698, -0.8457],
[0.0858, 2.4721, -0.1928], [-0.9022, 1.6560, -1.3090],
[0.1156, 1.6901, -0.4366], [-0.6477, 2.3576, -0.1563],
[-0.8482, 1.1466, -1.2704], [-0.8753, 2.0845, -0.3460],
[-0.5621, 1.4233, -1.2858], [-0.5883, 1.3114, -1.2899]]],
device=device)
dist, idx = three_nn(unknown, known)
expected_dist = torch.tensor([[[0.0000, 0.0000, 0.0000],
[0.0000, 2.0463, 2.8588],
[0.0000, 1.2229, 1.2229],
[1.2047, 1.2047, 1.2047],
[1.0011, 1.0845, 1.8411],
[0.7433, 1.4451, 2.4304],
[0.5007, 0.5007, 0.5007],
[0.4587, 2.0875, 2.7544],
[0.4450, 0.4450, 0.4450],
[0.5514, 1.7206, 2.6811]],
[[0.0000, 0.0000, 0.0000],
[0.0000, 1.6464, 1.6952],
[0.0000, 1.5125, 1.5125],
[1.0915, 1.0915, 1.0915],
[0.8197, 0.8511, 1.4894],
[0.7433, 0.8082, 0.8082],
[0.8955, 1.3340, 1.3340],
[0.4730, 0.4730, 0.4730],
[0.7949, 1.3325, 1.3325],
[0.7566, 1.3727, 1.3727]]]).cuda()
expected_idx = torch.tensor([[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
[2, 1, 0], [1, 2, 0], [0, 3, 4], [1, 2, 0],
[0, 3, 4], [1, 2, 0]],
[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4],
[2, 1, 0], [2, 0, 3], [1, 0, 3], [0, 3, 4],
[1, 0, 3], [1, 0, 3]]]).cuda()
expected_dist = torch.tensor(
[[[0.0000, 0.0000, 0.0000], [0.0000, 2.0463, 2.8588],
[0.0000, 1.2229, 1.2229], [1.2047, 1.2047, 1.2047],
[1.0011, 1.0845, 1.8411], [0.7433, 1.4451, 2.4304],
[0.5007, 0.5007, 0.5007], [0.4587, 2.0875, 2.7544],
[0.4450, 0.4450, 0.4450], [0.5514, 1.7206, 2.6811]],
[[0.0000, 0.0000, 0.0000], [0.0000, 1.6464, 1.6952],
[0.0000, 1.5125, 1.5125], [1.0915, 1.0915, 1.0915],
[0.8197, 0.8511, 1.4894], [0.7433, 0.8082, 0.8082],
[0.8955, 1.3340, 1.3340], [0.4730, 0.4730, 0.4730],
[0.7949, 1.3325, 1.3325], [0.7566, 1.3727, 1.3727]]],
device=device)
expected_idx = torch.tensor(
[[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0], [1, 2, 0],
[0, 3, 4], [1, 2, 0], [0, 3, 4], [1, 2, 0]],
[[0, 3, 4], [1, 2, 0], [2, 0, 3], [0, 3, 4], [2, 1, 0], [2, 0, 3],
[1, 0, 3], [0, 3, 4], [1, 0, 3], [1, 0, 3]]],
device=device)
assert torch.allclose(dist, expected_dist, 1e-4)
assert torch.allclose(dist, expected_dist, atol=1e-4)
assert torch.all(idx == expected_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment