Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ca99624f
You need to sign in or sign up before continuing.
Unverified
Commit
ca99624f
authored
Sep 19, 2023
by
sherie
Committed by
GitHub
Sep 19, 2023
Browse files
[Fix] Fix the support for nms_rotated in Ascend (#2931)
parent
b361a81a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
9 deletions
+62
-9
mmcv/ops/csrc/common/pytorch_npu_helper.hpp
mmcv/ops/csrc/common/pytorch_npu_helper.hpp
+10
-5
mmcv/ops/csrc/pytorch/nms_rotated.cpp
mmcv/ops/csrc/pytorch/nms_rotated.cpp
+5
-3
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
+44
-0
setup.py
setup.py
+3
-1
No files found.
mmcv/ops/csrc/common/pytorch_npu_helper.hpp
View file @
ca99624f
...
@@ -27,16 +27,21 @@
...
@@ -27,16 +27,21 @@
#define NPU_NAME_SPACE at_npu::native
#define NPU_NAME_SPACE at_npu::native
#if MMCV_WITH_XLA
#if
def
MMCV_WITH_XLA
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#else
#else
#define REGISTER_NPU_IMPL(key, value) \
#define REGISTER_NPU_IMPL(key, value) \
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
#endif
#endif
#ifdef MMCV_WITH_XLA
#define CHECK_NPU(x) \
#define CHECK_NPU(x) \
TORCH_CHECK( \
TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
x.device().type() == at::kXLA || x.device().type() == at::kPrivateUse1, \
#else
#x " must be a NPU tensor")
#define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x \
" must be a NPU " \
"tensor")
#endif
#endif // PYTORCH_NPU_HELPER_HPP_
#endif // PYTORCH_NPU_HELPER_HPP_
mmcv/ops/csrc/pytorch/nms_rotated.cpp
View file @
ca99624f
...
@@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
...
@@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
#else
#else
AT_ERROR
(
"Not compiled with GPU support"
);
AT_ERROR
(
"Not compiled with GPU support"
);
#endif
#endif
#ifdef MMCV_WITH_XLA
}
else
if
(
dets
.
device
().
type
()
==
at
::
kXLA
)
{
}
else
if
(
dets
.
device
().
type
()
==
at
::
kXLA
)
{
#ifdef MMCV_WITH_NPU
return
nms_rotated_npu
(
dets
,
scores
,
labels
,
iou_threshold
);
return
nms_rotated_npu
(
dets
,
scores
,
labels
,
iou_threshold
);
#else
#endif
AT_ERROR
(
"Not compiled with NPU support"
);
#ifdef MMCV_WITH_KPRIVATE
}
else
if
(
dets
.
device
().
type
()
==
at
::
kPrivateUse1
)
{
return
nms_rotated_npu
(
dets
,
scores
,
labels
,
iou_threshold
);
#endif
#endif
#ifdef MMCV_WITH_MLU
#ifdef MMCV_WITH_MLU
}
else
if
(
dets
.
device
().
type
()
==
at
::
kMLU
)
{
}
else
if
(
dets
.
device
().
type
()
==
at
::
kMLU
)
{
...
...
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
View file @
ca99624f
...
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
...
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
.
Attr
(
"batch_dims"
,
batch_dims
)
.
Attr
(
"batch_dims"
,
batch_dims
)
.
Run
();
.
Run
();
}
}
/**
 * NPU backward kernel for gather_points.
 *
 * Adds the rows of `grad_out` into `grad_points` at the positions named by
 * `idx`, using the "InplaceIndexAdd" operator along axis 0 of a flattened
 * view, and writes the result back into `grad_points` with `copy_`.
 *
 * @param b            Batch count; used to build the per-batch row offsets.
 * @param c            Not read by this function.
 * @param n            Multiplier applied to the batch number to form each
 *                     batch's row offset.
 * @param npoints      Not read by this function.
 * @param grad_out     Incoming gradient; dims 1 and 2 are swapped and the
 *                     first two dims are merged before the op runs.
 *                     (presumably shaped (b, c, npoints) - TODO confirm)
 * @param idx          Indices; cast to int32 when stored as another dtype.
 *                     (presumably shaped (b, npoints) - TODO confirm)
 * @param grad_points  Destination gradient, updated in place via `copy_`.
 *                     (presumably shaped (b, c, n) - TODO confirm)
 */
void gather_points_backward_npu(int b, int c, int n, int npoints,
                                const Tensor grad_out, const Tensor idx,
                                Tensor grad_points) {
  // Work on a handle of our own; cast to int32 only when needed.
  at::Tensor indices = idx;
  if (idx.scalar_type() != at::ScalarType::Int) {
    indices = idx.to(at::kInt);
  }
  if (idx.dim() == 0) {
    // Out-of-place on purpose: when no cast happened above, `indices` still
    // refers to the same tensor as `idx`, and an in-place unsqueeze_ would
    // change the shape of the caller's input.
    indices = indices.unsqueeze(0);
  }
  int64_t dim = 0;
  // Target shape for broadcasting the per-batch offsets. Taken from
  // `indices` so it reflects the unsqueeze above; for idx.dim() >= 1 this is
  // the same shape as `idx`.
  at::SmallVector<int64_t, N> pad_size =
      array_to_small_vector(indices.sizes());

  // Swap dims 1 and 2, then merge the first two dims into a single row axis.
  at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
  at::Tensor grad_points_view = trans_grad_points.view(
      {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
       trans_grad_points.sizes()[2]});
  at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
  trans_grad_out = trans_grad_out.view(
      {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
       trans_grad_out.sizes()[2]});

  // Row offset of each batch inside the flattened view: batch k starts at
  // k * n. The offsets are broadcast to the index shape and added so the
  // indices address rows of the flattened tensor.
  // NOTE(review): at::arange(0, b) is created without dtype options, so the
  // sum may be promoted away from int32 - confirm the op accepts that.
  auto index = at::arange(0, b);
  index = index.to(grad_out.device());
  index = at::mul(index, n);
  index = index.view({b, 1});
  index = at::broadcast_to(index, pad_size);
  indices = at::add(index, indices);
  indices = indices.view({-1});

  // grad_points_view is both an input and the output of the operator.
  OpCommand cmd;
  cmd.Name("InplaceIndexAdd")
      .Input(grad_points_view)
      .Input(indices)
      .Input(trans_grad_out)
      .Output(grad_points_view)
      .Attr("axis", dim)
      .Run();

  // Undo the flatten and the transpose, then write back into the caller's
  // tensor.
  at::Tensor grad_points_result =
      grad_points_view.view(trans_grad_points.sizes());
  grad_points_result = grad_points_result.transpose(1, 2);
  grad_points.copy_(grad_points_result);
}
void
gather_points_forward_impl
(
int
b
,
int
c
,
int
n
,
int
npoints
,
void
gather_points_forward_impl
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
Tensor
points
,
const
Tensor
idx
,
const
Tensor
points
,
const
Tensor
idx
,
Tensor
out
);
Tensor
out
);
// Declaration of the gather_points backward implementation hook. It is only
// declared here; in this file it is passed to REGISTER_NPU_IMPL together with
// gather_points_backward_npu, which has the same parameter list.
void gather_points_backward_impl(int b, int c, int n, int npoints,
                                 const Tensor grad_out, const Tensor idx,
                                 Tensor grad_points);
REGISTER_NPU_IMPL
(
gather_points_forward_impl
,
gather_points_forward_npu
);
REGISTER_NPU_IMPL
(
gather_points_forward_impl
,
gather_points_forward_npu
);
// Bind gather_points_backward_npu to gather_points_backward_impl through the
// REGISTER_NPU_IMPL macro.
REGISTER_NPU_IMPL(gather_points_backward_impl, gather_points_backward_npu);
setup.py
View file @
ca99624f
...
@@ -396,8 +396,10 @@ def get_extensions():
...
@@ -396,8 +396,10 @@ def get_extensions():
from
torch_npu.utils.cpp_extension
import
NpuExtension
from
torch_npu.utils.cpp_extension
import
NpuExtension
define_macros
+=
[(
'MMCV_WITH_NPU'
,
None
)]
define_macros
+=
[(
'MMCV_WITH_NPU'
,
None
)]
extension
=
NpuExtension
extension
=
NpuExtension
if
parse_version
(
torch
.
__version__
)
>
=
parse_version
(
'2.0.0'
):
if
parse_version
(
torch
.
__version__
)
<
=
parse_version
(
'2.0.0'
):
define_macros
+=
[(
'MMCV_WITH_XLA'
,
None
)]
define_macros
+=
[(
'MMCV_WITH_XLA'
,
None
)]
if
parse_version
(
torch
.
__version__
)
>
parse_version
(
'2.0.0'
):
define_macros
+=
[(
'MMCV_WITH_KPRIVATE'
,
None
)]
except
Exception
:
except
Exception
:
raise
ImportError
(
'can not find any torch_npu'
)
raise
ImportError
(
'can not find any torch_npu'
)
# src
# src
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment