Unverified Commit 73856344 authored by DaGaiBa's avatar DaGaiBa Committed by GitHub
Browse files

[Feature] Support PSAMask op for Ascend device (#2487)

parent fdc052e8
...@@ -40,7 +40,7 @@ We implement common ops used in detection, segmentation, etc. ...@@ -40,7 +40,7 @@ We implement common ops used in detection, segmentation, etc.
| PixelGroup | √ | | | | | | PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | | | PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | | | PointsInPolygons | | √ | | | |
| PSAMask | √ | √ | √ | | | | PSAMask | √ | √ | √ | | |
| RotatedFeatureAlign | √ | √ | | | | | RotatedFeatureAlign | √ | √ | | | |
| RoIPointPool3d | | √ | √ | | | | RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | | | RoIPool | | √ | √ | | |
......
...@@ -40,7 +40,7 @@ MMCV 提供了检测、分割等任务中常用的算子 ...@@ -40,7 +40,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| PixelGroup | √ | | | | | | PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | | | PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | | | | PointsInPolygons | | √ | | | |
| PSAMask | √ | √ | √ | | | | PSAMask | √ | √ | √ | | |
| RotatedFeatureAlign | √ | √ | | | | | RotatedFeatureAlign | √ | √ | | | |
| RoIPointPool3d | | √ | √ | | | | RoIPointPool3d | | √ | √ | | |
| RoIPool | | √ | √ | | | | RoIPool | | √ | √ | | |
......
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;
void psamask_forward_npu(const int psa_type,
const Tensor x,
Tensor y,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask) {
int64_t psa_type_i64 = psa_type;
int64_t num_i64 = num;
int64_t h_feature_i64 = h_feature;
int64_t w_feature_i64 = w_feature;
int64_t h_mask_i64 = h_mask;
int64_t w_mask_i64 = w_mask;
int64_t half_h_mask_i64 = half_h_mask;
int64_t half_w_mask_i64 = half_w_mask;
OpCommand cmd;
cmd.Name("PSAMask")
.Input(x)
.Output(y)
.Attr("psa_type", psa_type_i64)
.Attr("num", num_i64)
.Attr("h_feature", h_feature_i64)
.Attr("w_feature", w_feature_i64)
.Attr("h_mask", h_mask_i64)
.Attr("w_mask", w_mask_i64)
.Attr("half_h_mask", half_h_mask_i64)
.Attr("half_w_mask", half_w_mask_i64)
.Run();
}
void psamask_forward_impl(const int psa_type,
const Tensor x,
Tensor y,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask);
void psamask_backward_npu(const int psa_type,
const Tensor y_grad,
Tensor x_grad,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask) {
int64_t psa_type_i64 = psa_type;
int64_t num_i64 = num;
int64_t h_feature_i64 = h_feature;
int64_t w_feature_i64 = w_feature;
int64_t h_mask_i64 = h_mask;
int64_t w_mask_i64 = w_mask;
int64_t half_h_mask_i64 = half_h_mask;
int64_t half_w_mask_i64 = half_w_mask;
OpCommand cmd;
cmd.Name("PSAMaskGrad")
.Input(y_grad)
.Output(x_grad)
.Attr("psa_type", psa_type_i64)
.Attr("num", num_i64)
.Attr("h_feature", h_feature_i64)
.Attr("w_feature", w_feature_i64)
.Attr("h_mask", h_mask_i64)
.Attr("w_mask", w_mask_i64)
.Attr("half_h_mask", half_h_mask_i64)
.Attr("half_w_mask", half_w_mask_i64)
.Run();
}
void psamask_backward_impl(const int psa_type,
const Tensor y_grad,
Tensor x_grad,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask);
REGISTER_NPU_IMPL(psamask_forward_impl, psamask_forward_npu);
REGISTER_NPU_IMPL(psamask_backward_impl, psamask_backward_npu);
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
class Loss(nn.Module): class Loss(nn.Module):
...@@ -28,7 +28,11 @@ class TestPSAMask: ...@@ -28,7 +28,11 @@ class TestPSAMask:
pytest.param( pytest.param(
'mlu', 'mlu',
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')) not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
]) ])
def test_psa_mask_collect(self, device): def test_psa_mask_collect(self, device):
from mmcv.ops import PSAMask from mmcv.ops import PSAMask
...@@ -76,7 +80,11 @@ class TestPSAMask: ...@@ -76,7 +80,11 @@ class TestPSAMask:
pytest.param( pytest.param(
'mlu', 'mlu',
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')) not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
]) ])
def test_psa_mask_distribute(self, device): def test_psa_mask_distribute(self, device):
from mmcv.ops import PSAMask from mmcv.ops import PSAMask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment