Unverified Commit 2810718a authored by sherie's avatar sherie Committed by GitHub
Browse files

[Feature] Add Ascend support for RoIPool op (#2483)


Co-authored-by: default avatarwangxiaoxin_sherie <wangxiaoxin7@huawei.com>
parent 48ea88ab
...@@ -43,7 +43,7 @@ We implement common ops used in detection, segmentation, etc. ...@@ -43,7 +43,7 @@ We implement common ops used in detection, segmentation, etc.
| PSAMask | √ | √ | √ | | √ | | PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | | | | | RotatedFeatureAlign | √ | √ | | | |
| RoIPointPool3d | | √ | √ | | | | RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | | | RoIPool | | √ | √ | | |
| RoIAlignRotated | √ | √ | √ | | | | RoIAlignRotated | √ | √ | √ | | |
| RiRoIAlignRotated | | √ | | | | | RiRoIAlignRotated | | √ | | | |
| RoIAlign | √ | √ | √ | | | | RoIAlign | √ | √ | √ | | |
......
...@@ -43,7 +43,7 @@ MMCV 提供了检测、分割等任务中常用的算子 ...@@ -43,7 +43,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| PSAMask | √ | √ | √ | | √ | | PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | | | | | RotatedFeatureAlign | √ | √ | | | |
| RoIPointPool3d | | √ | √ | | | | RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | | | RoIPool | | √ | √ | | |
| RoIAlignRotated | √ | √ | √ | | | | RoIAlignRotated | √ | √ | √ | | |
| RiRoIAlignRotated | | √ | | | | | RiRoIAlignRotated | | √ | | | |
| RoIAlign | √ | √ | √ | | | | RoIAlign | √ | √ | √ | | |
......
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;
void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
                          Tensor argmax, int pooled_height, int pooled_width,
                          float spatial_scale) {
  // Forward pass of RoIPool on Ascend NPU, delegated to the CANN operator
  // "RoiPoolingWithArgMax". Pooled features are written into `output` and
  // the per-bin max positions into `argmax`.
  //
  // The operator attributes are 64-bit, so widen the int parameters first.
  const int64_t pooled_h = pooled_height;
  const int64_t pooled_w = pooled_width;
  // NOTE(review): pool_channel is hard-coded to 1 here — presumably the
  // operator's per-channel pooling granularity; confirm against CANN docs.
  const int64_t pool_channel = 1;

  // Auxiliary int tensor the operator requires as its third input
  // (actual RoI count), allocated with the same sizes as `rois`.
  at::Tensor actual_roi_count = at_npu::native::OpPreparation::ApplyTensor(
      {}, rois.options().dtype(at::kInt), rois);

  // Build and launch the NPU command. Input/Output order is positional
  // and must match the operator definition.
  OpCommand npu_cmd;
  npu_cmd.Name("RoiPoolingWithArgMax")
      .Input(input)
      .Input(rois)
      .Input(actual_roi_count)
      .Output(output)
      .Output(argmax)
      .Attr("pooled_h", pooled_h)
      .Attr("pooled_w", pooled_w)
      .Attr("spatial_scale_h", spatial_scale)
      .Attr("spatial_scale_w", spatial_scale)
      .Attr("pool_channel", pool_channel)
      .Run();
}
// Forward declaration of the device-dispatched entry point; the dispatcher
// resolves it to a concrete backend implementation at runtime.
void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
                           Tensor argmax, int pooled_height, int pooled_width,
                           float spatial_scale);

// Register roi_pool_forward_npu as the NPU backend for roi_pool_forward_impl.
REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu);
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
_USING_PARROTS = True _USING_PARROTS = True
try: try:
...@@ -69,10 +69,16 @@ class TestRoiPool: ...@@ -69,10 +69,16 @@ class TestRoiPool:
np_output = np.array(output[0]) np_output = np.array(output[0])
np_grad = np.array(output[1]) np_grad = np.array(output[1])
if device == 'npu':
import torch_npu # noqa: F401
x = torch.tensor(np_input, dtype=dtype).npu()
rois = torch.tensor(np_rois, dtype=dtype).npu()
output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
else:
x = torch.tensor( x = torch.tensor(
np_input, dtype=dtype, device=device, requires_grad=True) np_input, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor(np_rois, dtype=dtype, device=device) rois = torch.tensor(np_rois, dtype=dtype, device=device)
output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
output.backward(torch.ones_like(output)) output.backward(torch.ones_like(output))
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
...@@ -86,7 +92,11 @@ class TestRoiPool: ...@@ -86,7 +92,11 @@ class TestRoiPool:
pytest.param( pytest.param(
'mlu', 'mlu',
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')) not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
]) ])
@pytest.mark.parametrize('dtype', [ @pytest.mark.parametrize('dtype', [
torch.float, torch.float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment