Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
8b8bf5e1
Unverified
Commit
8b8bf5e1
authored
Aug 28, 2023
by
Chris Jiang
Committed by
GitHub
Aug 28, 2023
Browse files
[Refactor] Replace roipoint_pool3d op of MLU backend with mlu-ops (#2875)
parent
099ee24d
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
1142 deletions
+42
-1142
mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu
...common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu
+0
-536
mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu
mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu
+0
-544
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp
+42
-62
No files found.
mmcv/ops/csrc/common/mlu/roipoint_pool3d_large_boxes_num_mlu_kernel.mlu
deleted
100644 → 0
View file @
099ee24d
This diff is collapsed.
Click to expand it.
mmcv/ops/csrc/common/mlu/roipoint_pool3d_mlu_kernel.mlu
deleted
100644 → 0
View file @
099ee24d
This diff is collapsed.
Click to expand it.
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp
View file @
8b8bf5e1
...
@@ -9,32 +9,7 @@
...
@@ -9,32 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "mlu_common_helper.h"
#include "pytorch_mlu_helper.hpp"
void
KernelRoiPointPool3dForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
int
batch_size
,
const
int
pts_num
,
const
int
boxes_num
,
const
int
feature_in_len
,
const
int
sampled_pts_num
,
const
void
*
xyz
,
const
void
*
boxes3d
,
const
void
*
pts_feature
,
void
*
pooled_features
,
int
*
pooled_empty_flag
);
void
KernelRoiPointPool3dLargeBoxesNumForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
int
batch_size
,
const
int
pts_num
,
const
int
boxes_num
,
const
int
feature_in_len
,
const
int
sampled_pts_num
,
const
void
*
xyz
,
const
void
*
boxes3d
,
const
void
*
pts_feature
,
void
*
pooled_features
,
int
*
pooled_empty_flag
);
// policy function
static
void
policyFuncForward
(
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
)
{
// start U1 task, occupy all available clusters
k_dim
->
x
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
k_dim
->
y
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
k_dim
->
z
=
1
;
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
}
void
RoIPointPool3dForwardMLUKernelLauncher
(
void
RoIPointPool3dForwardMLUKernelLauncher
(
int
batch_size
,
int
pts_num
,
int
boxes_num
,
int
feature_in_len
,
int
batch_size
,
int
pts_num
,
int
boxes_num
,
int
feature_in_len
,
...
@@ -98,50 +73,55 @@ void RoIPointPool3dForwardMLUKernelLauncher(
...
@@ -98,50 +73,55 @@ void RoIPointPool3dForwardMLUKernelLauncher(
"pts_feature element num should be less than 2^31, got "
,
"pts_feature element num should be less than 2^31, got "
,
pts_feature
.
numel
(),
"."
);
pts_feature
.
numel
(),
"."
);
// calculate task dimension
// set contiguous
cnrtDim3_t
k_dim
;
auto
xyz_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
cnrtFunctionType_t
k_type
;
xyz
,
xyz
.
suggest_memory_format
());
policyFuncForward
(
&
k_dim
,
&
k_type
);
auto
pts_feature_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
pts_feature
,
pts_feature
.
suggest_memory_format
());
// get compute queue
auto
boxes3d_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
auto
queue
=
torch_mlu
::
getCurQueue
();
boxes3d
,
boxes3d
.
suggest_memory_format
());
auto
pooled_features_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
pooled_features
,
pooled_features
.
suggest_memory_format
());
auto
pooled_empty_flag_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
pooled_empty_flag
,
pooled_empty_flag
.
suggest_memory_format
());
// get ptr of tensors
// get ptr of tensors
// transpose points [B, N ,3] -> [3, B, N]
auto
xyz_impl
=
torch_mlu
::
getMluTensorImpl
(
xyz_contiguous
);
auto
xyz_
=
xyz
.
permute
({
2
,
0
,
1
}).
contiguous
();
auto
xyz_impl
=
torch_mlu
::
getMluTensorImpl
(
xyz_
);
auto
xyz_ptr
=
xyz_impl
->
cnnlMalloc
();
auto
xyz_ptr
=
xyz_impl
->
cnnlMalloc
();
// transpose point_features [B, N, C] -> [B, C, N]
auto
pts_feature_impl
=
torch_mlu
::
getMluTensorImpl
(
pts_feature_contiguous
);
auto
pts_feature_
=
pts_feature
.
permute
({
0
,
2
,
1
}).
contiguous
();
auto
pts_feature_impl
=
torch_mlu
::
getMluTensorImpl
(
pts_feature_
);
auto
pts_feature_ptr
=
pts_feature_impl
->
cnnlMalloc
();
auto
pts_feature_ptr
=
pts_feature_impl
->
cnnlMalloc
();
auto
boxes3d_impl
=
torch_mlu
::
getMluTensorImpl
(
boxes3d
);
auto
boxes3d_impl
=
torch_mlu
::
getMluTensorImpl
(
boxes3d
_contiguous
);
auto
boxes3d_ptr
=
boxes3d_impl
->
cnnlMalloc
();
auto
boxes3d_ptr
=
boxes3d_impl
->
cnnlMalloc
();
auto
pooled_features_impl
=
torch_mlu
::
getMluTensorImpl
(
pooled_features
);
auto
pooled_features_impl
=
torch_mlu
::
getMluTensorImpl
(
pooled_features
_contiguous
);
auto
pooled_features_ptr
=
pooled_features_impl
->
cnnlMalloc
();
auto
pooled_features_ptr
=
pooled_features_impl
->
cnnlMalloc
();
auto
pooled_empty_flag_impl
=
torch_mlu
::
getMluTensorImpl
(
pooled_empty_flag
);
auto
pooled_empty_flag_impl
=
torch_mlu
::
getMluTensorImpl
(
pooled_empty_flag
_contiguous
);
auto
pooled_empty_flag_ptr
=
pooled_empty_flag_impl
->
cnnlMalloc
();
auto
pooled_empty_flag_ptr
=
pooled_empty_flag_impl
->
cnnlMalloc
();
// get compute dtype of input
// create tensor descriptors
cnrtDataType_t
data_type
=
torch_mlu
::
toCnrtDtype
(
xyz_
.
dtype
());
MluOpTensorDescriptor
xyz_desc
,
pts_feature_desc
,
boxes3d_desc
,
pooled_features_desc
,
pooled_empty_flag_desc
;
xyz_desc
.
set
(
xyz_contiguous
);
// launch kernel
pts_feature_desc
.
set
(
pts_feature_contiguous
);
if
(
boxes_num
<=
10240
)
{
boxes3d_desc
.
set
(
boxes3d_contiguous
);
CNLOG
(
INFO
)
<<
"Launch Kernel MLUKernelRoiPointPool3dForward<<<"
<<
k_dim
.
x
pooled_features_desc
.
set
(
pooled_features_contiguous
);
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
pooled_empty_flag_desc
.
set
(
pooled_empty_flag_contiguous
);
KernelRoiPointPool3dForward
(
k_dim
,
k_type
,
queue
,
data_type
,
batch_size
,
pts_num
,
boxes_num
,
// get workspace
feature_in_len
,
sampled_pts_num
,
xyz_ptr
,
boxes3d_ptr
,
pts_feature_ptr
,
size_t
workspace_size
=
0
;
pooled_features_ptr
,
(
int
*
)
pooled_empty_flag_ptr
);
auto
handle
=
mluOpGetCurrentHandle
();
}
else
{
TORCH_MLUOP_CHECK
(
mluOpGetRoiPointPool3dWorkspaceSize
(
handle
,
batch_size
,
CNLOG
(
INFO
)
pts_num
,
boxes_num
,
feature_in_len
,
sampled_pts_num
,
xyz_desc
.
desc
(),
<<
"Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<"
pts_feature_desc
.
desc
(),
boxes3d_desc
.
desc
(),
pooled_features_desc
.
desc
(),
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
pooled_empty_flag_desc
.
desc
(),
&
workspace_size
));
KernelRoiPointPool3dLargeBoxesNumForward
(
k_dim
,
k_type
,
queue
,
data_type
,
batch_size
,
pts_num
,
boxes_num
,
auto
workspace
=
at
::
empty
(
workspace_size
,
xyz
.
options
().
dtype
(
at
::
kByte
));
feature_in_len
,
sampled_pts_num
,
xyz_ptr
,
boxes3d_ptr
,
pts_feature_ptr
,
auto
workspace_impl
=
torch_mlu
::
getMluTensorImpl
(
workspace
);
pooled_features_ptr
,
(
int
*
)
pooled_empty_flag_ptr
);
auto
workspace_ptr
=
workspace_impl
->
cnnlMalloc
();
}
TORCH_MLUOP_CHECK
(
mluOpRoiPointPool3d
(
handle
,
batch_size
,
pts_num
,
boxes_num
,
feature_in_len
,
sampled_pts_num
,
xyz_desc
.
desc
(),
xyz_ptr
,
pts_feature_desc
.
desc
(),
pts_feature_ptr
,
boxes3d_desc
.
desc
(),
boxes3d_ptr
,
workspace_ptr
,
workspace_size
,
pooled_features_desc
.
desc
(),
pooled_features_ptr
,
pooled_empty_flag_desc
.
desc
(),
(
int
*
)
pooled_empty_flag_ptr
));
}
}
void
roipoint_pool3d_forward_mlu
(
int
batch_size
,
int
pts_num
,
int
boxes_num
,
void
roipoint_pool3d_forward_mlu
(
int
batch_size
,
int
pts_num
,
int
boxes_num
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment