Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ca99624f
Unverified
Commit
ca99624f
authored
Sep 19, 2023
by
sherie
Committed by
GitHub
Sep 19, 2023
Browse files
[Fix] Fix the support for nms_rotated in Ascend (#2931)
parent
b361a81a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
9 deletions
+62
-9
mmcv/ops/csrc/common/pytorch_npu_helper.hpp
mmcv/ops/csrc/common/pytorch_npu_helper.hpp
+10
-5
mmcv/ops/csrc/pytorch/nms_rotated.cpp
mmcv/ops/csrc/pytorch/nms_rotated.cpp
+5
-3
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
+44
-0
setup.py
setup.py
+3
-1
No files found.
mmcv/ops/csrc/common/pytorch_npu_helper.hpp
View file @
ca99624f
...
@@ -27,16 +27,21 @@
...
@@ -27,16 +27,21 @@
#define NPU_NAME_SPACE at_npu::native
#define NPU_NAME_SPACE at_npu::native
#if MMCV_WITH_XLA
#if
def
MMCV_WITH_XLA
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#else
#else
#define REGISTER_NPU_IMPL(key, value) \
#define REGISTER_NPU_IMPL(key, value) \
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
#endif
#endif
#define CHECK_NPU(x) \
#ifdef MMCV_WITH_XLA
TORCH_CHECK( \
#define CHECK_NPU(x) \
x.device().type() == at::kXLA || x.device().type() == at::kPrivateUse1, \
TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
#x " must be a NPU tensor")
#else
#define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x \
" must be a NPU " \
"tensor")
#endif
#endif // PYTORCH_NPU_HELPER_HPP_
#endif // PYTORCH_NPU_HELPER_HPP_
mmcv/ops/csrc/pytorch/nms_rotated.cpp
View file @
ca99624f
...
@@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
...
@@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
#else
#else
AT_ERROR
(
"Not compiled with GPU support"
);
AT_ERROR
(
"Not compiled with GPU support"
);
#endif
#endif
#ifdef MMCV_WITH_XLA
}
else
if
(
dets
.
device
().
type
()
==
at
::
kXLA
)
{
}
else
if
(
dets
.
device
().
type
()
==
at
::
kXLA
)
{
#ifdef MMCV_WITH_NPU
return
nms_rotated_npu
(
dets
,
scores
,
labels
,
iou_threshold
);
return
nms_rotated_npu
(
dets
,
scores
,
labels
,
iou_threshold
);
#else
#endif
AT_ERROR
(
"Not compiled with NPU support"
);
#ifdef MMCV_WITH_KPRIVATE
}
else
if
(
dets
.
device
().
type
()
==
at
::
kPrivateUse1
)
{
return
nms_rotated_npu
(
dets
,
scores
,
labels
,
iou_threshold
);
#endif
#endif
#ifdef MMCV_WITH_MLU
#ifdef MMCV_WITH_MLU
}
else
if
(
dets
.
device
().
type
()
==
at
::
kMLU
)
{
}
else
if
(
dets
.
device
().
type
()
==
at
::
kMLU
)
{
...
...
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
View file @
ca99624f
...
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
...
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
.
Attr
(
"batch_dims"
,
batch_dims
)
.
Attr
(
"batch_dims"
,
batch_dims
)
.
Run
();
.
Run
();
}
}
void
gather_points_backward_npu
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
Tensor
grad_out
,
const
Tensor
idx
,
Tensor
grad_points
)
{
at
::
Tensor
indices
=
idx
;
if
(
idx
.
scalar_type
()
!=
at
::
ScalarType
::
Int
)
{
indices
=
idx
.
to
(
at
::
kInt
);
}
if
(
idx
.
dim
()
==
0
)
{
indices
.
unsqueeze_
(
0
);
}
int64_t
dim
=
0
;
at
::
SmallVector
<
int64_t
,
N
>
pad_size
=
array_to_small_vector
(
idx
.
sizes
());
at
::
Tensor
trans_grad_points
=
grad_points
.
transpose
(
1
,
2
).
contiguous
();
at
::
Tensor
grad_points_view
=
trans_grad_points
.
view
(
{
trans_grad_points
.
sizes
()[
0
]
*
trans_grad_points
.
sizes
()[
1
],
trans_grad_points
.
sizes
()[
2
]});
at
::
Tensor
trans_grad_out
=
grad_out
.
transpose
(
1
,
2
).
contiguous
();
trans_grad_out
=
trans_grad_out
.
view
(
{
trans_grad_out
.
sizes
()[
0
]
*
trans_grad_out
.
sizes
()[
1
],
trans_grad_out
.
sizes
()[
2
]});
auto
index
=
at
::
arange
(
0
,
b
);
index
=
index
.
to
(
grad_out
.
device
());
index
=
at
::
mul
(
index
,
n
);
index
=
index
.
view
({
b
,
1
});
index
=
at
::
broadcast_to
(
index
,
pad_size
);
indices
=
at
::
add
(
index
,
indices
);
indices
=
indices
.
view
({
-
1
});
OpCommand
cmd
;
cmd
.
Name
(
"InplaceIndexAdd"
)
.
Input
(
grad_points_view
)
.
Input
(
indices
)
.
Input
(
trans_grad_out
)
.
Output
(
grad_points_view
)
.
Attr
(
"axis"
,
dim
)
.
Run
();
at
::
Tensor
grad_points_result
=
grad_points_view
.
view
(
trans_grad_points
.
sizes
());
grad_points_result
=
grad_points_result
.
transpose
(
1
,
2
);
grad_points
.
copy_
(
grad_points_result
);
}
void
gather_points_forward_impl
(
int
b
,
int
c
,
int
n
,
int
npoints
,
void
gather_points_forward_impl
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
Tensor
points
,
const
Tensor
idx
,
const
Tensor
points
,
const
Tensor
idx
,
Tensor
out
);
Tensor
out
);
void
gather_points_backward_impl
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
Tensor
grad_out
,
const
Tensor
idx
,
Tensor
grad_points
);
REGISTER_NPU_IMPL
(
gather_points_forward_impl
,
gather_points_forward_npu
);
REGISTER_NPU_IMPL
(
gather_points_forward_impl
,
gather_points_forward_npu
);
REGISTER_NPU_IMPL
(
gather_points_backward_impl
,
gather_points_backward_npu
);
setup.py
View file @
ca99624f
...
@@ -396,8 +396,10 @@ def get_extensions():
...
@@ -396,8 +396,10 @@ def get_extensions():
from
torch_npu.utils.cpp_extension
import
NpuExtension
from
torch_npu.utils.cpp_extension
import
NpuExtension
define_macros
+=
[(
'MMCV_WITH_NPU'
,
None
)]
define_macros
+=
[(
'MMCV_WITH_NPU'
,
None
)]
extension
=
NpuExtension
extension
=
NpuExtension
if
parse_version
(
torch
.
__version__
)
>
=
parse_version
(
'2.0.0'
):
if
parse_version
(
torch
.
__version__
)
<
=
parse_version
(
'2.0.0'
):
define_macros
+=
[(
'MMCV_WITH_XLA'
,
None
)]
define_macros
+=
[(
'MMCV_WITH_XLA'
,
None
)]
if
parse_version
(
torch
.
__version__
)
>
parse_version
(
'2.0.0'
):
define_macros
+=
[(
'MMCV_WITH_KPRIVATE'
,
None
)]
except
Exception
:
except
Exception
:
raise
ImportError
(
'can not find any torch_npu'
)
raise
ImportError
(
'can not find any torch_npu'
)
# src
# src
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment