Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
43c5c76f
Unverified
Commit
43c5c76f
authored
Jul 06, 2023
by
long11111111
Committed by
GitHub
Jul 06, 2023
Browse files
[Feature] Add support of points_in_polyogns for Ascend device (#2848)
parent
d28aa8a9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
12 deletions
+51
-12
docs/en/understand_mmcv/ops.md
docs/en/understand_mmcv/ops.md
+1
-1
mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
+27
-0
mmcv/ops/points_in_polygons.py
mmcv/ops/points_in_polygons.py
+5
-2
tests/test_ops/test_points_in_polygons.py
tests/test_ops/test_points_in_polygons.py
+18
-9
No files found.
docs/en/understand_mmcv/ops.md
View file @
43c5c76f
...
@@ -39,7 +39,7 @@ We implement common ops used in detection, segmentation, etc.
...
@@ -39,7 +39,7 @@ We implement common ops used in detection, segmentation, etc.
| NMSQuadri | √ | √ | | | |
| NMSQuadri | √ | √ | | | |
| PixelGroup | √ | | | | |
| PixelGroup | √ | | | | |
| PointsInBoxes | √ | √ | | | |
| PointsInBoxes | √ | √ | | | |
| PointsInPolygons | | √ | | |
|
| PointsInPolygons | | √ | | |
√
|
| PSAMask | √ | √ | √ | | √ |
| PSAMask | √ | √ | √ | | √ |
| RotatedFeatureAlign | √ | √ | √ | | |
| RotatedFeatureAlign | √ | √ | √ | | |
| RoIPointPool3d | | √ | √ | | |
| RoIPointPool3d | | √ | √ | | |
...
...
mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
0 → 100644
View file @
43c5c76f
#include "pytorch_npu_helper.hpp"
using
namespace
NPU_NAME_SPACE
;
using
namespace
std
;
constexpr
int32_t
MAX_POLYGONS_BATCH
=
2800
;
void
points_in_polygons_npu
(
const
Tensor
points
,
Tensor
polygons
,
Tensor
output
,
const
int
rows
,
const
int
cols
)
{
TORCH_CHECK
(
(
polygons
.
sizes
()[
0
]
<=
MAX_POLYGONS_BATCH
),
"The batch of polygons tensor must be less than MAX_POLYGONS_BATCH"
);
at
::
Tensor
trans_polygons
=
polygons
.
transpose
(
0
,
1
);
OpCommand
cmd
;
at
::
Tensor
new_trans_polygons
=
NpuUtils
::
format_contiguous
(
trans_polygons
);
cmd
.
Name
(
"PointsInPolygons"
)
.
Input
(
points
,
(
string
)
"points"
)
.
Input
(
new_trans_polygons
,
(
string
)
"polygons"
)
.
Output
(
output
)
.
Run
();
}
void
points_in_polygons_forward_impl
(
const
Tensor
points
,
Tensor
polygons
,
Tensor
output
,
const
int
rows
,
const
int
cols
);
REGISTER_NPU_IMPL
(
points_in_polygons_forward_impl
,
points_in_polygons_npu
);
mmcv/ops/points_in_polygons.py
View file @
43c5c76f
...
@@ -31,8 +31,11 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor:
...
@@ -31,8 +31,11 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor:
assert
polygons
.
shape
[
1
]
==
8
,
\
assert
polygons
.
shape
[
1
]
==
8
,
\
'polygons dimension should be 8, '
\
'polygons dimension should be 8, '
\
f
'but got unexpected shape
{
polygons
.
shape
[
1
]
}
'
f
'but got unexpected shape
{
polygons
.
shape
[
1
]
}
'
output
=
torch
.
full
([
points
.
shape
[
0
],
polygons
.
shape
[
0
]],
output
=
torch
.
zeros
(
0.
).
cuda
().
float
()
points
.
shape
[
0
],
polygons
.
shape
[
0
],
dtype
=
torch
.
float32
,
device
=
points
.
device
)
ext_module
.
points_in_polygons_forward
(
points
.
contiguous
(),
ext_module
.
points_in_polygons_forward
(
points
.
contiguous
(),
polygons
.
contiguous
(),
output
)
polygons
.
contiguous
(),
output
)
return
output
return
output
tests/test_ops/test_points_in_polygons.py
View file @
43c5c76f
...
@@ -4,20 +4,29 @@ import pytest
...
@@ -4,20 +4,29 @@ import pytest
import
torch
import
torch
from
mmcv.ops
import
points_in_polygons
from
mmcv.ops
import
points_in_polygons
from
mmcv.utils
import
IS_CUDA_AVAILABLE
,
IS_NPU_AVAILABLE
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
parametrize
(
'device'
,
[
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
pytest
.
param
(
def
test_points_in_polygons
():
'cuda'
,
marks
=
pytest
.
mark
.
skipif
(
not
IS_CUDA_AVAILABLE
,
reason
=
'requires CUDA support'
)),
pytest
.
param
(
'npu'
,
marks
=
pytest
.
mark
.
skipif
(
not
IS_NPU_AVAILABLE
,
reason
=
'requires NPU support'
))
])
def
test_points_in_polygons
(
device
):
points
=
np
.
array
([[
300.
,
300.
],
[
400.
,
400.
],
[
100.
,
100
],
[
300
,
250
],
points
=
np
.
array
([[
300.
,
300.
],
[
400.
,
400.
],
[
100.
,
100
],
[
300
,
250
],
[
100
,
0
]])
[
100
,
0
]])
polygons
=
np
.
array
([[
200.
,
200.
,
400.
,
400.
,
500.
,
200.
,
400.
,
100.
],
polygons
=
np
.
array
([[
200.
,
200.
,
400.
,
400.
,
500.
,
200.
,
400.
,
100.
],
[
400.
,
400.
,
500.
,
500.
,
600.
,
300.
,
500.
,
200.
],
[
400.
,
400.
,
500.
,
500.
,
600.
,
300.
,
500.
,
200.
],
[
300.
,
300.
,
600.
,
700.
,
700.
,
700.
,
700.
,
100.
]])
[
300.
,
300.
,
600.
,
700.
,
700.
,
700.
,
700.
,
100.
]])
expected_output
=
np
.
array
([[
0.
,
0.
,
0.
],
[
0.
,
0.
,
1.
],
[
0.
,
0.
,
0.
],
expected_output
=
np
.
array
([[
0.
,
0.
,
0.
],
[
0.
,
0.
,
1.
],
[
0.
,
0.
,
0.
],
[
1.
,
0.
,
0.
],
[
0.
,
0.
,
0.
]])
[
1.
,
0.
,
0.
],
[
0.
,
0.
,
0.
]])
.
astype
(
np
.
float32
)
points
=
torch
.
from_numpy
(
points
).
cuda
().
float
(
)
points
=
torch
.
tensor
(
points
,
dtype
=
torch
.
float32
,
device
=
device
)
polygons
=
torch
.
from_numpy
(
polygons
).
cuda
().
float
(
)
polygons
=
torch
.
tensor
(
polygons
,
dtype
=
torch
.
float32
,
device
=
device
)
expected_output
=
torch
.
from_numpy
(
expected_output
).
cuda
().
float
()
assert
np
.
allclose
(
assert
torch
.
allclose
(
points_in_polygons
(
points
,
polygons
).
cpu
().
numpy
(),
expected_output
,
points_in_polygons
(
points
,
polygons
),
expected_output
,
1e-3
)
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment