Unverified Commit 615d2817 authored by xinliang123's avatar xinliang123 Committed by GitHub
Browse files

[Feature] Add Ascend support for RoIPoolGrad op (#2569)

* add roipoolgrad adapter

* amend
parent 04e8727a
...@@ -11,7 +11,6 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output, ...@@ -11,7 +11,6 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
int64_t pooled_channel = 1; int64_t pooled_channel = 1;
at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor( at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
{}, rois.options().dtype(at::kInt), rois); {}, rois.options().dtype(at::kInt), rois);
OpCommand cmd; OpCommand cmd;
cmd.Name("RoiPoolingWithArgMax") cmd.Name("RoiPoolingWithArgMax")
.Input(input) .Input(input)
...@@ -27,8 +26,38 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output, ...@@ -27,8 +26,38 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
.Run(); .Run();
} }
// Backward pass of RoI pooling on Ascend NPU: scatters the pooled gradients
// back into grad_input via the CANN "RoiPoolingGradWithArgMax" operator.
void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
                           Tensor grad_input, int pooled_height,
                           int pooled_width, float spatial_scale) {
  // The operator attributes are 64-bit; widen the int parameters up front.
  const int64_t pooled_h = pooled_height;
  const int64_t pooled_w = pooled_width;
  const int64_t pool_channel = 1;
  // Auxiliary int tensor mirroring the forward op's roi_actual_num input
  // (presumably the per-image valid-ROI count — same construction as the
  // forward path above).
  at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
      {}, rois.options().dtype(at::kInt), rois);
  // Placeholder feature-map input the grad operator requires; only its
  // shape/dtype matter, so an all-ones tensor shaped like grad_input is used.
  at::Tensor x = at::ones_like(grad_input);
  // Build and launch the NPU op. Input/attribute order follows the
  // operator's signature and must not be changed.
  OpCommand cmd;
  cmd.Name("RoiPoolingGradWithArgMax")
      .Input(grad_output)
      .Input(x)
      .Input(rois)
      .Input(roi_actual_num)
      .Input(argmax)
      .Output(grad_input)
      .Attr("pooled_h", pooled_h)
      .Attr("pooled_w", pooled_w)
      .Attr("spatial_scale_h", spatial_scale)
      .Attr("spatial_scale_w", spatial_scale)
      .Attr("pool_channel", pool_channel)
      .Run();
}
void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height, int pooled_width, Tensor argmax, int pooled_height, int pooled_width,
float spatial_scale); float spatial_scale);
// Forward declaration of the device-dispatch entry point for the RoI pool
// backward pass; the NPU kernel is bound to it via REGISTER_NPU_IMPL below.
void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
                            Tensor grad_input, int pooled_height,
                            int pooled_width, float spatial_scale);
REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu); REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu);
REGISTER_NPU_IMPL(roi_pool_backward_impl, roi_pool_backward_npu);
...@@ -69,20 +69,13 @@ class TestRoiPool: ...@@ -69,20 +69,13 @@ class TestRoiPool:
np_output = np.array(output[0]) np_output = np.array(output[0])
np_grad = np.array(output[1]) np_grad = np.array(output[1])
if device == 'npu': x = torch.tensor(
import torch_npu # noqa: F401 np_input, dtype=dtype, device=device, requires_grad=True)
x = torch.tensor(np_input, dtype=dtype).npu() rois = torch.tensor(np_rois, dtype=dtype, device=device)
rois = torch.tensor(np_rois, dtype=dtype).npu() output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) output.backward(torch.ones_like(output))
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
else: assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
x = torch.tensor(
np_input, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor(np_rois, dtype=dtype, device=device)
output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
output.backward(torch.ones_like(output))
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
@pytest.mark.parametrize('device', [ @pytest.mark.parametrize('device', [
pytest.param( pytest.param(
...@@ -103,8 +96,8 @@ class TestRoiPool: ...@@ -103,8 +96,8 @@ class TestRoiPool:
pytest.param( pytest.param(
torch.double, torch.double,
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
IS_MLU_AVAILABLE, IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')), reason='MLU, NPU does not support for 64-bit floating point')),
torch.half torch.half
]) ])
def test_roipool_allclose(self, device, dtype): def test_roipool_allclose(self, device, dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment