OpenDAS / MMCV

Commit 91da9643, authored Aug 13, 2024 by limm

    support v2.1.0

Parent: 6f674c7e
Changes: 139
Showing 20 changed files with 1437 additions and 311 deletions (+1437, -311)
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp        +46   -62
mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp  +115  -0
mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp         +156  -0
mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp            +444  -0
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp               +33   -70
mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp              +45   -110
mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp           +99   -0
mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp          +174  -2
mmcv/ops/csrc/pytorch/nms.cpp                            +45   -0
mmcv/ops/csrc/pytorch/nms_rotated.cpp                    +14   -3
mmcv/ops/csrc/pytorch/npu/active_rotated_filter_npu.cpp  +36   -0
mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp          +28   -11
mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp        +47   -0
mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp             +13   -24
mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp   +3    -2
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp          +44   -0
mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp           +45   -0
mmcv/ops/csrc/pytorch/npu/nms_npu.cpp                    +15   -20
mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp            +8    -7
mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp     +27   -0
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp
@@ -9,32 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

void KernelRoiPointPool3dForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const int batch_size, const int pts_num,
    const int boxes_num, const int feature_in_len, const int sampled_pts_num,
    const void *xyz, const void *boxes3d, const void *pts_feature,
    void *pooled_features, int *pooled_empty_flag);

void KernelRoiPointPool3dLargeBoxesNumForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const int batch_size, const int pts_num,
    const int boxes_num, const int feature_in_len, const int sampled_pts_num,
    const void *xyz, const void *boxes3d, const void *pts_feature,
    void *pooled_features, int *pooled_empty_flag);

// policy function
static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  // start U1 task, occupy all available clusters
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
#include "mlu_common_helper.h"

void RoIPointPool3dForwardMLUKernelLauncher(
    int batch_size, int pts_num, int boxes_num, int feature_in_len,
...
@@ -98,50 +73,59 @@ void RoIPointPool3dForwardMLUKernelLauncher(
              "pts_feature element num should be less than 2^31, got ",
              pts_feature.numel(), ".");

  // calculate task dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncForward(&k_dim, &k_type);

  // get compute queue
  auto queue = torch_mlu::getCurQueue();

  // set contiguous
  auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      xyz, xyz.suggest_memory_format());
  auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_feature, pts_feature.suggest_memory_format());
  auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes3d, boxes3d.suggest_memory_format());
  auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pooled_features, pooled_features.suggest_memory_format());
  auto pooled_empty_flag_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pooled_empty_flag, pooled_empty_flag.suggest_memory_format());

  // get ptr of tensors
  // transpose points [B, N ,3] -> [3, B, N]
  auto xyz_ = xyz.permute({2, 0, 1}).contiguous();
  auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_);
  auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
  auto xyz_ptr = xyz_impl->cnnlMalloc();
  // transpose point_features [B, N, C] -> [B, C, N]
  auto pts_feature_ = pts_feature.permute({0, 2, 1}).contiguous();
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
  auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d);
  auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous);
  auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
  auto pooled_features_impl =
      torch_mlu::getMluTensorImpl(pooled_features_contiguous);
  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
  auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag);
  auto pooled_empty_flag_impl =
      torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
  auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();

  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(xyz_.dtype());

  // launch kernel
  if (boxes_num <= 10240) {
    CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dForward<<<" << k_dim.x
                << ", " << k_dim.y << ", " << k_dim.z << ">>>";
    KernelRoiPointPool3dForward(k_dim, k_type, queue, data_type, batch_size,
                                pts_num, boxes_num, feature_in_len,
                                sampled_pts_num, xyz_ptr, boxes3d_ptr,
                                pts_feature_ptr, pooled_features_ptr,
                                (int *)pooled_empty_flag_ptr);
  } else {
    CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<"
                << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
    KernelRoiPointPool3dLargeBoxesNumForward(
        k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num,
        feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr,
        pooled_features_ptr, (int *)pooled_empty_flag_ptr);
  }

  // create tensor descriptors
  MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc,
      pooled_features_desc, pooled_empty_flag_desc;
  xyz_desc.set(xyz_contiguous);
  pts_feature_desc.set(pts_feature_contiguous);
  boxes3d_desc.set(boxes3d_contiguous);
  pooled_features_desc.set(pooled_features_contiguous);
  pooled_empty_flag_desc.set(pooled_empty_flag_contiguous);

  // get workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(
      handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
      xyz_desc.desc(), pts_feature_desc.desc(), boxes3d_desc.desc(),
      pooled_features_desc.desc(), pooled_empty_flag_desc.desc(),
      &workspace_size));
  auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte));
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
  auto workspace_ptr = workspace_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpRoiPointPool3d(
      handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
      xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr,
      boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size,
      pooled_features_desc.desc(), pooled_features_ptr,
      pooled_empty_flag_desc.desc(), (int *)pooled_empty_flag_ptr));
}

void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num,
...
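
Note: every MLU launcher touched in this commit follows the same call sequence: make each tensor contiguous with torch_mlu::cnnl::ops::cnnl_contiguous, describe it with an MluOpTensorDescriptor, fetch the device pointer via torch_mlu::getMluTensorImpl(...)->cnnlMalloc(), query a workspace size, and then invoke the corresponding mluOp* operator on the current handle. The sketch below condenses that pattern for a single input/output pair; mluOpGetFooWorkspaceSize and mluOpFoo are hypothetical stand-ins for the concrete mlu-ops operators used in the files of this diff, not real library symbols.

#include "mlu_common_helper.h"  // MluOpTensorDescriptor, mluOpGetCurrentHandle, TORCH_MLUOP_CHECK

// Minimal sketch of the launcher pattern used throughout this commit.
// mluOpGetFooWorkspaceSize / mluOpFoo are placeholders for a concrete
// operator such as mluOpRoiPointPool3d above.
void FooMLUKernelLauncher(const Tensor input, Tensor output) {
  // 1. make tensors contiguous in the layout torch_mlu expects
  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      input, input.suggest_memory_format());
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, output.suggest_memory_format());

  // 2. describe the tensors for mlu-ops
  MluOpTensorDescriptor input_desc, output_desc;
  input_desc.set(input_contiguous);
  output_desc.set(output_contiguous);

  // 3. fetch raw device pointers
  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
  auto input_ptr = input_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  auto output_ptr = output_impl->cnnlMalloc();

  // 4. query the workspace size and allocate it as a byte tensor on device
  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  TORCH_MLUOP_CHECK(
      mluOpGetFooWorkspaceSize(handle, input_desc.desc(), &workspace_size));
  auto workspace = at::empty(workspace_size, input.options().dtype(at::kByte));
  auto workspace_ptr = torch_mlu::getMluTensorImpl(workspace)->cnnlMalloc();

  // 5. launch the library operator and check its return status
  TORCH_MLUOP_CHECK(mluOpFoo(handle, input_desc.desc(), input_ptr,
                             workspace_ptr, workspace_size,
                             output_desc.desc(), output_ptr));
}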
mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp
0 → 100644
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

void RotatedFeatureAlignForwardMLUKernelLauncher(const Tensor features,
                                                 const Tensor best_bboxes,
                                                 const float spatial_scale,
                                                 const int points,
                                                 Tensor output) {
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(features.dim());
  auto features_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(features, memory_format);
  auto best_bboxes_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      best_bboxes, best_bboxes.suggest_memory_format());
  auto output_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);

  MluOpTensorDescriptor features_desc, best_bboxes_desc, output_desc;
  features_desc.set_with_layout(features_, MLUOP_LAYOUT_NHWC);
  best_bboxes_desc.set(best_bboxes_contiguous);
  output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);

  // get ptr of tensors
  auto features_impl = torch_mlu::getMluTensorImpl(features_);
  auto features_ptr = features_impl->cnnlMalloc();
  auto best_bboxes_impl = torch_mlu::getMluTensorImpl(best_bboxes_contiguous);
  auto best_bboxes_ptr = best_bboxes_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  auto output_ptr = output_impl->cnnlMalloc();

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpRotatedFeatureAlignForward(
      handle, features_desc.desc(), features_ptr, best_bboxes_desc.desc(),
      best_bboxes_ptr, spatial_scale, points, output_desc.desc(), output_ptr));
  output.copy_(output_contiguous);
}

void RotatedFeatureAlignBackwardMLUKernelLauncher(const Tensor top_grad,
                                                  const Tensor best_bboxes,
                                                  const float spatial_scale,
                                                  const int points,
                                                  Tensor bottom_grad) {
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim());
  auto top_grad_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format);
  auto best_bboxes_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      best_bboxes, best_bboxes.suggest_memory_format());
  auto bottom_grad_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(bottom_grad, memory_format);

  // get ptr of tensors
  auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_);
  auto top_grad_ptr = top_grad_impl->cnnlMalloc();
  auto best_bboxes_impl = torch_mlu::getMluTensorImpl(best_bboxes_contiguous);
  auto best_bboxes_ptr = best_bboxes_impl->cnnlMalloc();
  auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_);
  auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();

  MluOpTensorDescriptor top_grad_desc, best_bboxes_desc, bottom_grad_desc;
  top_grad_desc.set_with_layout(top_grad_, MLUOP_LAYOUT_NHWC);
  best_bboxes_desc.set(best_bboxes_contiguous);
  bottom_grad_desc.set_with_layout(bottom_grad_, MLUOP_LAYOUT_NHWC);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpRotatedFeatureAlignBackward(
      handle, top_grad_desc.desc(), top_grad_ptr, best_bboxes_desc.desc(),
      best_bboxes_ptr, spatial_scale, points, bottom_grad_desc.desc(),
      bottom_grad_ptr));
  bottom_grad.copy_(bottom_grad_);
}

void rotated_feature_align_forward_mlu(const Tensor features,
                                       const Tensor best_bboxes,
                                       const float spatial_scale,
                                       const int points, Tensor output) {
  RotatedFeatureAlignForwardMLUKernelLauncher(features, best_bboxes,
                                              spatial_scale, points, output);
}

void rotated_feature_align_backward_mlu(const Tensor top_grad,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor bottom_grad) {
  RotatedFeatureAlignBackwardMLUKernelLauncher(top_grad, best_bboxes,
                                               spatial_scale, points,
                                               bottom_grad);
}

void rotated_feature_align_forward_impl(const Tensor features,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor output);

void rotated_feature_align_backward_impl(const Tensor top_grad,
                                         const Tensor best_bboxes,
                                         const float spatial_scale,
                                         const int points, Tensor bottom_grad);

REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, MLU,
                     rotated_feature_align_forward_mlu);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, MLU,
                     rotated_feature_align_backward_mlu);
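
The *_impl declarations and REGISTER_DEVICE_IMPL calls above plug the MLU launchers into MMCV's device registry (pytorch_device_registry.hpp). For context only, a sketch of the device-generic side is shown below; it is not part of this diff and is an assumption based on the registry convention used elsewhere in MMCV, where the shared operator source defines the _impl function through DISPATCH_DEVICE_IMPL so that the implementation registered here is selected when the input tensors live on an MLU device.

// Sketch (not in this diff): assumed device-generic dispatcher that routes to
// whichever backend was registered for the tensors' device, e.g. the MLU
// launcher registered above.
void rotated_feature_align_forward_impl(const Tensor features,
                                        const Tensor best_bboxes,
                                        const float spatial_scale,
                                        const int points, Tensor output) {
  DISPATCH_DEVICE_IMPL(rotated_feature_align_forward_impl, features,
                       best_bboxes, spatial_scale, points, output);
}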
mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp
0 → 100644
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(
    const Tensor &feats, const Tensor &coors, const reduce_t reduce_type) {
  // params check
  TORCH_CHECK(feats.scalar_type() == at::kFloat,
              "feats type should be Float, got ", feats.scalar_type());
  TORCH_CHECK(coors.scalar_type() == at::kInt,
              "coors type should be Int32, got ", coors.scalar_type());
  TORCH_CHECK(feats.size(0) == coors.size(0),
              "feats.dim(0) and coors.dim(0) should be same, got ",
              feats.size(0), " vs ", coors.size(0));

  const int num_input = feats.size(0);
  const int num_feats = feats.size(1);
  // zero-element check
  if (num_input == 0)
    return {feats.clone().detach(), coors.clone().detach(),
            coors.new_empty({0}, torch::kInt32),
            coors.new_empty({0}, torch::kInt32)};

  auto mlu_reduce_type = getMluOpReduceMode(reduce_type);
  auto reduced_feats = at::empty({num_input, num_feats}, feats.options());
  auto out_coors = at::empty({num_input, 3}, coors.options());
  auto coors_map = at::empty({num_input}, coors.options());
  auto reduce_count = at::empty({num_input}, coors.options());
  auto voxel_num = at::empty({1}, coors.options());

  INITIAL_MLU_PARAM_WITH_TENSOR(feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduced_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(out_coors);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors_map);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduce_count);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size;
  TORCH_MLUOP_CHECK(mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
      handle, feats_desc.desc(), coors_desc.desc(), &workspace_size));
  auto workspace_tensor =
      at::empty(workspace_size, feats.options().dtype(at::kByte));
  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpDynamicPointToVoxelForward(
      handle, mlu_reduce_type, feats_desc.desc(), feats_ptr, coors_desc.desc(),
      coors_ptr, workspace_tensor_ptr, workspace_size,
      reduced_feats_desc.desc(), reduced_feats_ptr, out_coors_desc.desc(),
      out_coors_ptr, coors_map_desc.desc(), coors_map_ptr,
      reduce_count_desc.desc(), reduce_count_ptr, voxel_num_desc.desc(),
      voxel_num_ptr));

  int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
  TORCH_CHECK(voxel_num_value <= feats.size(0),
              "voxel_num should be less than or equal to feats_num, got ",
              voxel_num_value, " vs ", feats.size(0));
  return {reduced_feats.slice(0, 0, voxel_num_value),
          out_coors.slice(0, 0, voxel_num_value), coors_map,
          reduce_count.slice(0, 0, voxel_num_value)};
}

void dynamic_point_to_voxel_backward_mlu(
    Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
    const Tensor &reduced_feats, const Tensor &coors_idx,
    const Tensor &reduce_count, const reduce_t reduce_type) {
  // params check
  TORCH_CHECK(grad_reduced_feats.scalar_type() == at::kFloat,
              "grad_reduced_feats type should be Float, got ",
              grad_reduced_feats.scalar_type());
  TORCH_CHECK(feats.scalar_type() == at::kFloat,
              "feats type should be Float, got ", feats.scalar_type());
  TORCH_CHECK(reduced_feats.scalar_type() == at::kFloat,
              "reduced_feats type should be Float, got ",
              reduced_feats.scalar_type());
  TORCH_CHECK(coors_idx.scalar_type() == at::kInt,
              "coors_idx type should be Int32, got ", coors_idx.scalar_type());
  TORCH_CHECK(reduce_count.scalar_type() == at::kInt,
              "reduce_count type should be Int32, got ",
              reduce_count.scalar_type());

  const int num_input = feats.size(0);
  const int num_reduced = reduced_feats.size(0);
  const int num_feats = feats.size(1);
  grad_feats.fill_(0);
  // zero-element check
  if (num_input == 0 || num_reduced == 0) return;
  // TODO(miaochen): remove this after mlu-ops supports other mode of reduce.
  TORCH_CHECK(reduce_type == reduce_t::MAX,
              "only supports max reduce in current version, got ",
              to_string(reduce_type));

  int voxel_num_value = reduced_feats.size(0);
  auto opts = torch::TensorOptions().dtype(torch::kInt32);
  auto voxel_num =
      torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
  auto mlu_reduce_type = getMluOpReduceMode(reduce_type);

  INITIAL_MLU_PARAM_WITH_TENSOR(grad_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_reduced_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduced_feats);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors_idx);
  INITIAL_MLU_PARAM_WITH_TENSOR(reduce_count);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size;
  TORCH_MLUOP_CHECK(mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
      handle, mlu_reduce_type, grad_feats_desc.desc(), feats_desc.desc(),
      grad_reduced_feats_desc.desc(), coors_idx_desc.desc(),
      reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size));
  auto workspace_tensor =
      at::empty(workspace_size, feats.options().dtype(at::kByte));
  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpDynamicPointToVoxelBackward(
      handle, mlu_reduce_type, grad_reduced_feats_desc.desc(),
      grad_reduced_feats_ptr, feats_desc.desc(), feats_ptr,
      reduced_feats_desc.desc(), reduced_feats_ptr, coors_idx_desc.desc(),
      coors_idx_ptr, reduce_count_desc.desc(), reduce_count_ptr,
      voxel_num_desc.desc(), voxel_num_ptr, workspace_tensor_ptr,
      workspace_size, grad_feats_desc.desc(), grad_feats_ptr));
}

std::vector<Tensor> dynamic_point_to_voxel_forward_impl(
    const Tensor &feats, const Tensor &coors, const reduce_t reduce_type);

void dynamic_point_to_voxel_backward_impl(
    Tensor &grad_feats, const Tensor &grad_reduced_feats, const Tensor &feats,
    const Tensor &reduced_feats, const Tensor &coors_idx,
    const Tensor &reduce_count, const reduce_t reduce_type);

REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MLU,
                     dynamic_point_to_voxel_forward_mlu);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MLU,
                     dynamic_point_to_voxel_backward_mlu);
mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
0 → 100644
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <torch/script.h>
#include <vector>
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
  // The following code is copied from
  // mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu to ensure the output is
  // available for network train. The outputs of this function have correct
  // shape but wrong value.
  auto numAct = indices.size(0);
  auto kernelVolume = kernelSize[0];
  int sub_m = (int)_subM;
  int transpose = (int)_transpose;
  int batch = (int)batchSize;
  auto coorDim = indices.size(1) - 1;

  for (int i = 1; i < kernelSize.size(); ++i) {
    kernelVolume *= kernelSize[i];
  }

  auto outputVolume = outSpatialShape[0];
  for (int i = 1; i < outSpatialShape.size(); ++i) {
    outputVolume *= outSpatialShape[i];
  }
  torch::Tensor indicePairs = at::full({kernelVolume, 2, numAct}, -1,
                                       indices.options().dtype(at::kInt));
  torch::Tensor indiceNum =
      at::zeros({kernelVolume}, indices.options().dtype(at::kInt));
  int out_size = sub_m == 1
                     ? numAct
                     : std::min(numAct * kernelVolume, batch * outputVolume);
  torch::Tensor out_indices =
      at::zeros({out_size, coorDim + 1}, indices.options().dtype(at::kInt));

  auto indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indices, at::MemoryFormat::Contiguous);
  auto indicePairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indicePairs, at::MemoryFormat::Contiguous);
  auto indiceNum_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indiceNum, at::MemoryFormat::Contiguous);
  auto out_indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      out_indices, at::MemoryFormat::Contiguous);

  std::vector<int> input_space;
  std::vector<int> filter_space;
  std::vector<int> output_space;
  std::vector<int> padding32;
  std::vector<int> stride32;
  std::vector<int> dilation32;
  for (int i = 0; i < NDim; i++) {
    input_space.push_back(spatialShape[i]);
    filter_space.push_back(kernelSize[i]);
    output_space.push_back(outSpatialShape[i]);
    padding32.push_back(padding[i]);
    stride32.push_back(stride[i]);
    dilation32.push_back(dilation[i]);
  }

  MluOpTensorDescriptor indices_desc, out_indices_desc, indicePairs_desc,
      indiceNum_desc;
  indices_desc.set(indices_contiguous);
  indicePairs_desc.set(indicePairs_contiguous);
  indiceNum_desc.set(indiceNum_contiguous);
  out_indices_desc.set(out_indices_contiguous);
  {
    mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
    mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
    std::vector<int> dims;
    dims = {numAct, coorDim + 1};
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        indices_desc.desc(), layout, dtype, dims.size(), dims.data()));
    dims = {kernelVolume, 2, numAct};
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        indicePairs_desc.desc(), layout, dtype, dims.size(), dims.data()));
    dims = {kernelVolume};
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        indiceNum_desc.desc(), layout, dtype, dims.size(), dims.data()));
    dims = {out_size, coorDim + 1};
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        out_indices_desc.desc(), layout, dtype, dims.size(), dims.data()));
  }

  mluOpSparseConvolutionDescriptor_t sparse_conv_desc;
  TORCH_MLUOP_CHECK(mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc));
  TORCH_MLUOP_CHECK(mluOpSetSparseConvolutionDescriptor(
      sparse_conv_desc, NDim + 2, batch, padding32.data(), stride32.data(),
      dilation32.data(), input_space.data(), filter_space.data(),
      output_space.data(), sub_m, transpose, 0));

  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  TORCH_MLUOP_CHECK(mluOpGetIndicePairsWorkspaceSize(
      handle, sparse_conv_desc, indices_desc.desc(), indicePairs_desc.desc(),
      out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size));
  auto indice_workspace_size =
      at::empty(workspace_size, indices.options().dtype(at::kByte));

  auto indices_impl = torch_mlu::getMluTensorImpl(indices_contiguous);
  auto out_indices_impl = torch_mlu::getMluTensorImpl(out_indices_contiguous);
  auto indicePairs_impl = torch_mlu::getMluTensorImpl(indicePairs_contiguous);
  auto indiceNum_impl = torch_mlu::getMluTensorImpl(indiceNum_contiguous);
  auto indice_workspace_impl =
      torch_mlu::getMluTensorImpl(indice_workspace_size);

  auto indices_ptr = indices_impl->cnnlMalloc();
  auto out_indices_ptr = out_indices_impl->cnnlMalloc();
  auto indicePairs_ptr = indicePairs_impl->cnnlMalloc();
  auto indiceNum_ptr = indiceNum_impl->cnnlMalloc();
  auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpGetIndicePairs(
      handle, sparse_conv_desc, indices_desc.desc(), indices_ptr,
      indice_workspace_ptr, workspace_size, indicePairs_desc.desc(),
      indicePairs_ptr, out_indices_desc.desc(), out_indices_ptr,
      indiceNum_desc.desc(), indiceNum_ptr));
  int num_act_out = 0;
  TORCH_MLUOP_CHECK(
      mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out));
  TORCH_MLUOP_CHECK(mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc));
  if (!sub_m) {
    return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
  } else {
    return {indices, indicePairs, indiceNum};
  }
}

torch::Tensor IndiceConvForwardMLUKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
    torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
    int64_t _subM) {
  auto indice_num_cpu = indiceNum.to({torch::kCPU});
  auto indice_num_cpu_64 = indice_num_cpu.to(torch::kInt64);
  auto indice_num = indice_num_cpu_64.data_ptr<int64_t>();

  // generate empty output
  int C = filters.dim() == 4 ? filters.size(3) : filters.size(4);
  torch::Tensor output =
      at::zeros({numActOut, C}, features.options().dtype(at::kFloat));

  // generate descriptor
  auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      features, at::MemoryFormat::Contiguous);
  auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      filters, at::MemoryFormat::Contiguous);
  auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indicePairs, at::MemoryFormat::Contiguous);
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, at::MemoryFormat::Contiguous);

  MluOpTensorDescriptor features_desc, filters_desc, indice_pairs_desc,
      output_desc;
  features_desc.set(features_contiguous);
  filters_desc.set(filters_contiguous);
  indice_pairs_desc.set(indice_pairs_contiguous);
  output_desc.set(output_contiguous);

  // set layout
  {
    mluOpTensorLayout_t layout;
    mluOpDataType_t dtype;
    int dim;
    int dims[8];

    // features_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(features_desc.desc(), &layout,
                                               &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));

    // filters_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(filters_desc.desc(), &layout,
                                               &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));

    // indice_pairs_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(indice_pairs_desc.desc(),
                                               &layout, &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));

    // output_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(output_desc.desc(), &layout,
                                               &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));
  }

  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  TORCH_MLUOP_CHECK(mluOpGetIndiceConvolutionForwardWorkspaceSize(
      handle, features_desc.desc(), filters_desc.desc(),
      indice_pairs_desc.desc(), output_desc.desc(), indice_num, numActOut,
      _inverse, _subM, &workspace_size));

  auto workspace =
      at::empty(workspace_size, features.options().dtype(at::kByte));

  auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
  auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
  auto indice_pairs_impl =
      torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);

  auto features_ptr = features_impl->cnnlMalloc();
  auto filters_ptr = filters_impl->cnnlMalloc();
  auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
  auto workspace_ptr = workspace_impl->cnnlMalloc();

  // outputs
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  auto output_ptr = output_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpIndiceConvolutionForward(
      handle, features_desc.desc(), features_ptr, filters_desc.desc(),
      filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      numActOut, _inverse, _subM, workspace_ptr, workspace_size,
      output_desc.desc(), output_ptr));

  return output;
}

std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM) {
  auto indice_num_cpu = indiceNum.to({torch::kCPU});
  auto indice_num_cpu_64 = indice_num_cpu.to(torch::kInt64);
  auto indice_num = indice_num_cpu_64.data_ptr<int64_t>();

  // generate empty input_grad
  torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)},
                                       features.options().dtype(at::kFloat));
  torch::Tensor filters_grad;
  if (filters.dim() == 4) {
    int h = filters.size(0);
    int w = filters.size(1);
    int c = filters.size(2);
    int n = filters.size(3);
    filters_grad = at::zeros({h, w, c, n}, filters.options().dtype(at::kFloat));
  } else if (filters.dim() == 5) {
    int d = filters.size(0);
    int h = filters.size(1);
    int w = filters.size(2);
    int c = filters.size(3);
    int n = filters.size(4);
    filters_grad =
        at::zeros({d, h, w, c, n}, filters.options().dtype(at::kFloat));
  }

  auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      features, at::MemoryFormat::Contiguous);
  auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      filters, at::MemoryFormat::Contiguous);
  auto output_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      outGrad, at::MemoryFormat::Contiguous);
  auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      indicePairs, at::MemoryFormat::Contiguous);
  auto input_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      features, at::MemoryFormat::Contiguous);
  auto filters_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      filters, at::MemoryFormat::Contiguous);

  MluOpTensorDescriptor features_desc, output_grad_desc, filters_desc,
      indice_pairs_desc, input_grad_desc, filters_grad_desc;
  features_desc.set(features_contiguous);
  filters_desc.set(filters_contiguous);
  output_grad_desc.set(output_grad_contiguous);
  indice_pairs_desc.set(indice_pairs_contiguous);
  input_grad_desc.set(input_grad_contiguous);
  filters_grad_desc.set(filters_grad_contiguous);

  // need to set desc layout with mluOp functions
  {
    mluOpTensorLayout_t layout;
    mluOpDataType_t dtype;
    int dim;
    int dims[8];

    // features_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(features_desc.desc(), &layout,
                                               &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));

    // filters_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(filters_desc.desc(), &layout,
                                               &dtype, &dim, dims));
    if (dim == 4) {
      TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
          filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype, dim, dims));
    } else {
      TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
          filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));
    }

    // output_grad_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(output_grad_desc.desc(),
                                               &layout, &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));

    // indice_pairs_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(indice_pairs_desc.desc(),
                                               &layout, &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));

    // input_grad_desc
    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout,
                                               &dtype, &dim, dims));
    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(
        input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims));
  }

  auto handle = mluOpGetCurrentHandle();
  size_t data_workspace_size = 0;
  mluOpGetIndiceConvolutionBackwardDataWorkspaceSize(
      handle, output_grad_desc.desc(), filters_desc.desc(),
      indice_pairs_desc.desc(), input_grad_desc.desc(), indice_num, _inverse,
      &data_workspace_size);

  size_t filters_workspace_size = 0;
  TORCH_MLUOP_CHECK(mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize(
      handle, features_desc.desc(), output_grad_desc.desc(),
      indice_pairs_desc.desc(), filters_grad_desc.desc(), indice_num, _inverse,
      _subM, &filters_workspace_size));

  auto indice_convbpdata_workspace =
      at::empty(data_workspace_size, features.options().dtype(at::kByte));
  auto indice_convbpfilter_workspace =
      at::empty(filters_workspace_size, filters.options().dtype(at::kByte));

  auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
  auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
  auto output_grad_impl = torch_mlu::getMluTensorImpl(output_grad_contiguous);
  auto indice_pairs_impl =
      torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
  auto indice_convbpdata_workspace_impl =
      torch_mlu::getMluTensorImpl(indice_convbpdata_workspace);
  auto indice_convbpfilter_workspace_impl =
      torch_mlu::getMluTensorImpl(indice_convbpfilter_workspace);

  auto features_ptr = features_impl->cnnlMalloc();
  auto filters_ptr = filters_impl->cnnlMalloc();
  auto output_grad_ptr = output_grad_impl->cnnlMalloc();
  auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
  auto indice_convbpdata_workspace_ptr =
      indice_convbpdata_workspace_impl->cnnlMalloc();
  auto indice_convbpfilter_workspace_ptr =
      indice_convbpfilter_workspace_impl->cnnlMalloc();

  // outputs
  auto input_grad_impl = torch_mlu::getMluTensorImpl(input_grad);
  auto input_grad_ptr = input_grad_impl->cnnlMalloc();
  auto filters_grad_impl = torch_mlu::getMluTensorImpl(filters_grad);
  auto filters_grad_ptr = filters_grad_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpIndiceConvolutionBackwardData(
      handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(),
      filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      _inverse, _subM, indice_convbpdata_workspace_ptr, data_workspace_size,
      input_grad_desc.desc(), input_grad_ptr));
  TORCH_MLUOP_CHECK(mluOpIndiceConvolutionBackwardFilter(
      handle, features_desc.desc(), features_ptr, output_grad_desc.desc(),
      output_grad_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      _inverse, _subM, indice_convbpfilter_workspace_ptr,
      filters_workspace_size, filters_grad_desc.desc(), filters_grad_ptr));

  std::vector<torch::Tensor> result;
  result.push_back(input_grad);
  result.push_back(filters_grad);
  return result;
}

torch::Tensor indice_conv_forward_mlu(torch::Tensor features,
                                      torch::Tensor filters,
                                      torch::Tensor indicePairs,
                                      torch::Tensor indiceNum,
                                      int64_t numActOut, int64_t _inverse,
                                      int64_t _subM) {
  return IndiceConvForwardMLUKernelLauncher(
      features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
}

std::vector<torch::Tensor> indice_conv_backward_mlu(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM) {
  return IndiceConvBackwardMLUKernelLauncher(
      features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
}

torch::Tensor indice_conv_forward_impl(torch::Tensor features,
                                       torch::Tensor filters,
                                       torch::Tensor indicePairs,
                                       torch::Tensor indiceNum,
                                       int64_t numActOut, int64_t _inverse,
                                       int64_t _subM);

std::vector<torch::Tensor> indice_conv_backward_impl(
    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
    int64_t _subM);

REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MLU, indice_conv_forward_mlu);
REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MLU, indice_conv_backward_mlu);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<2>(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<3>(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<4>(
    torch::Tensor indices, int64_t batchSize,
    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
    std::vector<int64_t> padding, std::vector<int64_t> dilation,
    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
@@ -9,84 +9,47 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, cnrtDataType_t data_type,
                          const void *unknown, const void *known, void *dist2,
                          int *idx, const int b, const int n, const int m);
#include "mlu_common_helper.h"

void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
                              const Tensor known, Tensor dist2, Tensor idx) {
  // Check dtype.
  TORCH_CHECK(
      unknown.scalar_type() == at::kFloat || unknown.scalar_type() == at::kHalf,
      "unknown type should be Float or Half, got ", unknown.scalar_type(), ".");
  TORCH_CHECK(unknown.scalar_type() == known.scalar_type(),
              "known should have the same type as unknown.");
  TORCH_CHECK(unknown.scalar_type() == dist2.scalar_type(),
              "dist2 should have the same type as unknown.");
  TORCH_CHECK(idx.scalar_type() == at::kInt, "idx type should be Int.");

  // Check shape.
  TORCH_CHECK(unknown.dim() == 3, "unknown should be 3d tensor, got ",
              unknown.dim(), "D.");
  TORCH_CHECK(known.dim() == 3, "known should be 3d tensor, got ", known.dim(),
              "D.");
  TORCH_CHECK(unknown.size(0) == known.size(0),
              "known.dim0 should be equal to unknown.dim0, got ", known.size(0),
              ".");
  TORCH_CHECK(unknown.size(2) == 3, "unknown dim2 should be 3, got ",
              unknown.size(2), ".");
  TORCH_CHECK(known.size(2) == 3, "known dim2 should be 3, got ", known.size(2),
              ".");

  // zero element check
  TORCH_CHECK(unknown.numel() > 0,
              "unknown.numel should greater than zero, got ", unknown.numel(),
              ".");
  if (known.numel() == 0) {
    // return if known zero element
    return;
  }

  // large tensor check
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(unknown.numel() < max_input_num,
              "unknown.numel() should be less than 2147483648, got ",
              unknown.numel(), ".");
  TORCH_CHECK(known.numel() < max_input_num,
              "known.numel() should be less than 2147483648, got ",
              known.numel(), ".");

  // get compute queue
  auto queue = torch_mlu::getCurQueue();

  // get ptr of tensors
  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown);
  auto unknown_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      unknown, unknown.suggest_memory_format());
  auto known_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      known, known.suggest_memory_format());
  auto dist2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      dist2, dist2.suggest_memory_format());
  auto idx_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(idx, idx.suggest_memory_format());

  MluOpTensorDescriptor unknown_desc, known_desc, dist2_desc, idx_desc;
  unknown_desc.set(unknown_contiguous);
  known_desc.set(known_contiguous);
  dist2_desc.set(dist2_contiguous);
  idx_desc.set(idx_contiguous);

  auto handle = mluOpGetCurrentHandle();
  size_t workspace_size = 0;
  TORCH_MLUOP_CHECK(mluOpGetThreeNNForwardWorkspaceSize(
      handle, known_desc.desc(), &workspace_size));
  auto known_workspace =
      at::empty(workspace_size, known.options().dtype(at::kByte));
  auto unknown_impl = torch_mlu::getMluTensorImpl(unknown_contiguous);
  auto known_impl = torch_mlu::getMluTensorImpl(known_contiguous);
  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2_contiguous);
  auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
  auto workspace_impl = torch_mlu::getMluTensorImpl(known_workspace);
  auto unknown_ptr = unknown_impl->cnnlMalloc();
  auto known_t = known.permute({0, 2, 1}).contiguous();
  auto known_impl = torch_mlu::getMluTensorImpl(known_t);
  auto known_ptr = known_impl->cnnlMalloc();
  auto dist2_impl = torch_mlu::getMluTensorImpl(dist2);
  auto dist2_ptr = dist2_impl->cnnlMalloc();
  auto idx_impl = torch_mlu::getMluTensorImpl(idx);
  auto idx_ptr = idx_impl->cnnlMalloc();
  auto workspace_ptr = workspace_impl->cnnlMalloc();

  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
  cnrtDim3_t k_dim;
  k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  k_dim.z = 1;
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(unknown.dtype());

  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelThreeNNForward<<<" << k_dim.x << ", "
              << k_dim.y << ", " << k_dim.z << ">>>.";
  KernelThreeNNForward(k_dim, k_type, queue, data_type, unknown_ptr, known_ptr,
                       dist2_ptr, (int *)idx_ptr, b, n, m);
  TORCH_MLUOP_CHECK(mluOpThreeNNForward(
      handle, unknown_desc.desc(), unknown_ptr, known_desc.desc(), known_ptr,
      workspace_ptr, workspace_size, dist2_desc.desc(), dist2_ptr,
      idx_desc.desc(), idx_ptr));
}

void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,
...
mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp
@@ -9,65 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

void KernelTinShiftForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *input, const void *shifts, void *output, const int batch_size,
    const int time_size, const int channel_size, const int hw_size,
    const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype, const int channel_per_core,
    const int max_number_hw_per_core, const int max_length_per_core);

void KernelTinShiftBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const void *grad_output, const void *shifts, void *grad_input,
    const int batch_size, const int time_size, const int channel_size,
    const int hw_size, const int group_size, const int group_channel,
    const cnrtDataType_t data_dtype, const int channel_per_core,
    const int max_number_hw_per_core, const int max_length_per_core);

// policy function
static void policyFunc(const Tensor &input, cnrtDim3_t *k_dim,
                       cnrtFunctionType_t *k_type, int *channel_per_core,
                       int *max_number_hw_per_core, int *max_length_per_core) {
  const int32_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  const int32_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
  const int core_num = core_limit * cluster_limit;
  const int batch_size = input.size(0);
  const int time_size = input.size(1);
  const int channel_size = input.size(2);
  const int hw_size = input.size(3);

  const size_t size_per_channel = time_size * hw_size * input.itemsize();
  *channel_per_core = nram_size / size_per_channel;
  int task_dim = 0;
  if (*channel_per_core == 0) {
    const size_t size_per_hw = hw_size * input.itemsize();
    *max_number_hw_per_core = nram_size / size_per_hw;
    if (*max_number_hw_per_core <= 0) {
      *max_length_per_core = nram_size / input.itemsize();
    }
    int tmp_max_number_hw_per_core =
        *max_number_hw_per_core > 0 ? *max_number_hw_per_core : 1;
    const int loop_time =
        (time_size / (tmp_max_number_hw_per_core)) +
        ((time_size % (tmp_max_number_hw_per_core)) > 0 ? 1 : 0);
    task_dim = batch_size * channel_size * loop_time < core_num
                   ? batch_size * channel_size * loop_time
                   : core_num;
  } else {
    task_dim = batch_size * channel_size < core_num
                   ? batch_size * channel_size
                   : core_num;
  }

  k_dim->x = core_limit;
  k_dim->y = (task_dim / core_limit) > 0 ? (task_dim / core_limit) : 1;
  k_dim->z = 1;
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
#include "mlu_common_helper.h"

void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
                                      Tensor output) {
...
@@ -89,40 +31,37 @@ void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
  if (input.size(1) == 0) {
    return;
  }
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  int channel_per_core = 0;
  int max_number_hw_per_core = 0;
  int max_length_per_core = 0;
  policyFunc(input, &k_dim, &k_type, &channel_per_core,
             &max_number_hw_per_core, &max_length_per_core);
  const int batch_size = input.size(0);
  const int time_size = input.size(1);
  const int channel_size = input.size(2);
  const int hw_size = input.size(3);
  const int group_size = shift.size(1);
  int group_channel = channel_size / group_size;

  // get tensor impl
  auto input_impl = torch_mlu::getMluTensorImpl(input);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift);
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  // set contiguous
  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      input, input.suggest_memory_format());
  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      shift, shift.suggest_memory_format());
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, output.suggest_memory_format());

  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get tensor impl
  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);

  // get the mlu ptr
  auto input_ptr = input_impl->cnnlMalloc();
  auto shift_ptr = shift_impl->cnnlMalloc();
  auto output_ptr = output_impl->cnnlMalloc();

  cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(input.dtype());
  // set tensor descriptor
  MluOpTensorDescriptor input_desc, shift_desc, output_desc;
  input_desc.set(input_contiguous);
  shift_desc.set(shift_contiguous);
  output_desc.set(output_contiguous);

  KernelTinShiftForward(k_dim, k_type, queue, input_ptr, shift_ptr, output_ptr,
                        batch_size, time_size, channel_size, hw_size,
                        group_size, group_channel, data_dtype,
                        channel_per_core, max_number_hw_per_core,
                        max_length_per_core);
  // get current handle
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr,
                                         shift_desc.desc(), shift_ptr,
                                         output_desc.desc(), output_ptr));
}

void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
...
@@ -148,41 +87,37 @@ void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
  if (grad_output.size(1) == 0) {
    return;
  }
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  int channel_per_core = 0;
  int max_number_hw_per_core = 0;
  int max_length_per_core = 0;
  policyFunc(grad_output, &k_dim, &k_type, &channel_per_core,
             &max_number_hw_per_core, &max_length_per_core);
  const int batch_size = grad_output.size(0);
  const int time_size = grad_output.size(1);
  const int channel_size = grad_output.size(2);
  const int hw_size = grad_output.size(3);
  const int group_size = shift.size(1);
  int group_channel = channel_size / group_size;

  // get tensor impl
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift);
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input);
  // set contiguous
  auto grad_output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_output, grad_output.suggest_memory_format());
  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      shift, shift.suggest_memory_format());
  auto grad_input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_input, grad_input.suggest_memory_format());

  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get tensor impl
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_contiguous);
  auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_contiguous);

  // get the mlu ptr
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto shift_ptr = shift_impl->cnnlMalloc();
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();

  cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(grad_output.dtype());
  // set tensor descriptor
  MluOpTensorDescriptor grad_output_desc, shift_desc, grad_input_desc;
  grad_output_desc.set(grad_output_contiguous);
  shift_desc.set(shift_contiguous);
  grad_input_desc.set(grad_input_contiguous);

  // get current handle
  auto handle = mluOpGetCurrentHandle();
  KernelTinShiftBackward(k_dim, k_type, queue, grad_output_ptr, shift_ptr,
                         grad_input_ptr, batch_size, time_size, channel_size,
                         hw_size, group_size, group_channel, data_dtype,
                         channel_per_core, max_number_hw_per_core,
                         max_length_per_core);
  TORCH_MLUOP_CHECK(mluOpTinShiftBackward(
      handle, grad_output_desc.desc(), grad_output_ptr, shift_desc.desc(),
      shift_ptr, grad_input_desc.desc(), grad_input_ptr));
}

void tin_shift_forward_mlu(Tensor input, Tensor shift, Tensor output) {
...
mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
0 → 100644
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
* _contiguous, _desc, _impl, _ptr will be automatically generated in
* this MACRO.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();
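// For reference, a sketch of what the macro above is expected to expand to for
// a call such as INITIAL_MLU_PARAM_WITH_TENSOR(points). The names follow the
// macro's ## concatenation; note the generated suffix is spelled "_contigous"
// exactly as written in the macro:
//
//   auto points_contigous = torch_mlu::cnnl::ops::cnnl_contiguous(
//       points, points.suggest_memory_format());
//   MluOpTensorDescriptor points_desc;
//   points_desc.set(points_contigous);
//   auto points_impl = torch_mlu::getMluTensorImpl(points_contigous);
//   auto points_ptr = points_impl->cnnlMalloc();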
int HardVoxelizeForwardMLUKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  std::vector<float> _voxel_size(voxel_size.begin(), voxel_size.end());
  std::vector<float> _coors_range(coors_range.begin(), coors_range.end());
  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
  auto voxel_size_tensor =
      torch::from_blob(_voxel_size.data(), {int64_t(_voxel_size.size())}, opts)
          .clone()
          .to(at::kMLU);
  auto coors_range_tensor =
      torch::from_blob(_coors_range.data(), {int64_t(_coors_range.size())},
                       opts)
          .clone()
          .to(at::kMLU);
  INITIAL_MLU_PARAM_WITH_TENSOR(points);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxels);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors);
  INITIAL_MLU_PARAM_WITH_TENSOR(num_points_per_voxel);
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_size_tensor);
  INITIAL_MLU_PARAM_WITH_TENSOR(coors_range_tensor);
  auto voxel_num_tensor =
      at::empty({1}, points.options().dtype(torch::kInt32));
  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num_tensor);

  size_t workspace_size;
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpGetVoxelizationWorkspaceSize(
      handle, points_desc.desc(), voxel_size_tensor_desc.desc(),
      coors_range_tensor_desc.desc(), max_points, max_voxels, NDim, true,
      voxels_desc.desc(), coors_desc.desc(), num_points_per_voxel_desc.desc(),
      voxel_num_tensor_desc.desc(), &workspace_size));
  auto workspace_tensor =
      at::empty(workspace_size, points.options().dtype(at::kByte));
  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

  TORCH_MLUOP_CHECK(mluOpVoxelization(
      handle, points_desc.desc(), points_ptr, voxel_size_tensor_desc.desc(),
      voxel_size_tensor_ptr, coors_range_tensor_desc.desc(),
      coors_range_tensor_ptr, max_points, max_voxels, NDim, true,
      workspace_tensor_ptr, workspace_size, voxels_desc.desc(), voxels_ptr,
      coors_desc.desc(), coors_ptr, num_points_per_voxel_desc.desc(),
      num_points_per_voxel_ptr, voxel_num_tensor_desc.desc(),
      voxel_num_tensor_ptr));

  auto voxel_num_cpu = voxel_num_tensor.to(at::kCPU);
  int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
  return voxel_num_int;
}

int hard_voxelize_forward_mlu(const at::Tensor &points, at::Tensor &voxels,
                              at::Tensor &coors,
                              at::Tensor &num_points_per_voxel,
                              const std::vector<float> voxel_size,
                              const std::vector<float> coors_range,
                              const int max_points, const int max_voxels,
                              const int NDim) {
  return HardVoxelizeForwardMLUKernelLauncher(
      points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
      max_points, max_voxels, NDim);
}

int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
                               at::Tensor &coors,
                               at::Tensor &num_points_per_voxel,
                               const std::vector<float> voxel_size,
                               const std::vector<float> coors_range,
                               const int max_points, const int max_voxels,
                               const int NDim);

REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MLU,
                     hard_voxelize_forward_mlu);
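As a usage sketch (not part of this commit): the launcher fills `voxels`, `coors` and `num_points_per_voxel` up to `max_voxels` entries and returns the number of voxels actually produced, so a hypothetical caller allocates the padded outputs first and then slices them by the returned count. `voxel_size`, `coors_range`, `max_points` and `max_voxels` are assumed to be defined by the caller.

// Hypothetical caller; shapes follow the usual hard-voxelize convention.
at::Tensor points = at::rand({20000, 4}, at::TensorOptions().device(at::kMLU));
at::Tensor voxels =
    at::zeros({max_voxels, max_points, points.size(1)}, points.options());
at::Tensor coors =
    at::zeros({max_voxels, 3}, points.options().dtype(at::kInt));
at::Tensor num_points_per_voxel =
    at::zeros({max_voxels}, points.options().dtype(at::kInt));
int voxel_num = hard_voxelize_forward_mlu(
    points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
    max_points, max_voxels, 3);
// Only the first voxel_num rows are valid.
voxels = voxels.slice(0, 0, voxel_num);
coors = coors.slice(0, 0, voxel_num);
num_points_per_voxel = num_points_per_voxel.slice(0, 0, voxel_num);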
mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp
View file @
91da9643
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>
#include "csrc_dipu/diopirt/diopirt_impl.h"
using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif
void modulated_deformable_im2col_impl(
    const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
...
...
@@ -45,7 +55,7 @@ void modulated_deformable_col2im_coord_impl(
      dilation_w, deformable_group, grad_offset, grad_mask);
}
-void modulated_deform_conv_forward(
+void modulated_deform_conv_forward_fallthrough(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
...
...
@@ -123,7 +133,7 @@ void modulated_deform_conv_forward(
  }
}

-void modulated_deform_conv_backward(
+void modulated_deform_conv_backward_fallthrough(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
...
...
@@ -235,3 +245,165 @@ void modulated_deform_conv_backward(
                     grad_output.size(2), grad_output.size(3),
                     grad_output.size(4)});
}
#ifdef MMCV_WITH_DIOPI
void modulated_deform_conv_forward_diopi(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias) {
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {
    modulated_deform_conv_forward_fallthrough(
        input, weight, bias, ones, offset, mask, output, columns, kernel_h,
        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
        group, deformable_group, with_bias);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  auto weight_p = toDiopiTensorHandle(weight);
  auto bias_p = toDiopiTensorHandle(bias);
  auto ones_p = toDiopiTensorHandle(ones);
  auto offset_p = toDiopiTensorHandle(offset);
  auto mask_p = toDiopiTensorHandle(mask);
  auto output_p = toDiopiTensorHandle(output);
  auto columns_p = toDiopiTensorHandle(columns);
  if (reinterpret_cast<void *>(diopiModulatedDeformConvMmcv) != nullptr) {
    auto ret = diopiModulatedDeformConvMmcv(
        ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
        mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
        dilation_h, dilation_w, group, deformable_group, with_bias);
    if (ret == diopiSuccess) return;
  }
  LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
  auto input_cpu = input.cpu();
  auto weight_cpu = weight.cpu();
  auto bias_cpu = bias.cpu();
  auto ones_cpu = ones.cpu();
  auto offset_cpu = offset.cpu();
  auto mask_cpu = mask.cpu();
  auto output_cpu = output.cpu();
  auto columns_cpu = columns.cpu();
  modulated_deform_conv_forward_fallthrough(
      input_cpu, weight_cpu, bias_cpu, ones_cpu, offset_cpu, mask_cpu,
      output_cpu, columns_cpu, kernel_h, kernel_w, stride_h, stride_w, pad_h,
      pad_w, dilation_h, dilation_w, group, deformable_group, with_bias);
  output.copy_(output_cpu);
  return;
}
void modulated_deform_conv_backward_diopi(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {
    modulated_deform_conv_backward_fallthrough(
        input, weight, bias, ones, offset, mask, columns, grad_input,
        grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
        group, deformable_group, with_bias);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  auto weight_p = toDiopiTensorHandle(weight);
  auto bias_p = toDiopiTensorHandle(bias);
  auto ones_p = toDiopiTensorHandle(ones);
  auto offset_p = toDiopiTensorHandle(offset);
  auto mask_p = toDiopiTensorHandle(mask);
  auto columns_p = toDiopiTensorHandle(columns);
  auto grad_input_p = toDiopiTensorHandle(grad_input);
  auto grad_weight_p = toDiopiTensorHandle(grad_weight);
  auto grad_bias_p = toDiopiTensorHandle(grad_bias);
  auto grad_offset_p = toDiopiTensorHandle(grad_offset);
  auto grad_mask_p = toDiopiTensorHandle(grad_mask);
  auto grad_output_p = toDiopiTensorHandle(grad_output);
  if (reinterpret_cast<void *>(diopiModulatedDeformConvBackwardMmcv) !=
      nullptr) {
    auto ret = diopiModulatedDeformConvBackwardMmcv(
        ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
        grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
        columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
        pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
        with_bias);
    if (ret == diopiSuccess) return;
  }
  LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
  auto input_cpu = input.cpu();
  auto weight_cpu = weight.cpu();
  auto bias_cpu = bias.cpu();
  auto ones_cpu = ones.cpu();
  auto offset_cpu = offset.cpu();
  auto mask_cpu = mask.cpu();
  auto columns_cpu = columns.cpu();
  auto grad_input_cpu = grad_input.cpu();
  auto grad_weight_cpu = grad_weight.cpu();
  auto grad_bias_cpu = grad_bias.cpu();
  auto grad_offset_cpu = grad_offset.cpu();
  auto grad_mask_cpu = grad_mask.cpu();
  auto grad_output_cpu = grad_output.cpu();
  modulated_deform_conv_backward_fallthrough(
      input_cpu, weight_cpu, bias_cpu, ones_cpu, offset_cpu, mask_cpu,
      columns_cpu, grad_input_cpu, grad_weight_cpu, grad_bias_cpu,
      grad_offset_cpu, grad_mask_cpu, grad_output_cpu, kernel_h, kernel_w,
      stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
      deformable_group, with_bias);
  grad_input.copy_(grad_input_cpu);
  grad_weight.copy_(grad_weight_cpu);
  grad_bias.copy_(grad_bias_cpu);
  grad_offset.copy_(grad_offset_cpu);
  grad_mask.copy_(grad_mask_cpu);
  return;
}
#endif
void modulated_deform_conv_forward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias) {
#ifdef MMCV_WITH_DIOPI
  modulated_deform_conv_forward_diopi(
      input, weight, bias, ones, offset, mask, output, columns, kernel_h,
      kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
      group, deformable_group, with_bias);
#else
  modulated_deform_conv_forward_fallthrough(
      input, weight, bias, ones, offset, mask, output, columns, kernel_h,
      kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
      group, deformable_group, with_bias);
#endif
}
void modulated_deform_conv_backward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
#ifdef MMCV_WITH_DIOPI
  modulated_deform_conv_backward_diopi(
      input, weight, bias, ones, offset, mask, columns, grad_input,
      grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
      kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
      group, deformable_group, with_bias);
#else
  modulated_deform_conv_backward_fallthrough(
      input, weight, bias, ones, offset, mask, columns, grad_input,
      grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
      kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
      group, deformable_group, with_bias);
#endif
}
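Both wrappers above follow the same three-step pattern: route host tensors to the reference implementation, try the DIOPI entry point on the device, and fall back to CPU (copying results back) when the DIOPI call is unavailable or fails. A minimal, hypothetical sketch of that pattern; the `some_op_*` and `diopiSomeOpMmcv` names are placeholders, not functions from this commit:

void some_op(Tensor input, Tensor output) {
#ifdef MMCV_WITH_DIOPI
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {  // plain CPU tensors: use the reference path
    some_op_fallthrough(input, output);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  if (reinterpret_cast<void *>(diopiSomeOpMmcv) != nullptr &&
      diopiSomeOpMmcv(ch, toDiopiTensorHandle(output), input_p) ==
          diopiSuccess)
    return;  // the device kernel handled it
  LOG(WARNING) << "Fallback to cpu: mmcv ext op some_op";
  auto input_cpu = input.cpu();
  auto output_cpu = output.cpu();
  some_op_fallthrough(input_cpu, output_cpu);
  output.copy_(output_cpu);  // copy the CPU result back to the device tensor
#else
  some_op_fallthrough(input, output);
#endif
}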
mmcv/ops/csrc/pytorch/nms.cpp
View file @
91da9643
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>
#include "csrc_dipu/base/basedef.h"
#include "csrc_dipu/diopirt/diopirt_impl.h"
using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif
Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
  return DISPATCH_DEVICE_IMPL(nms_impl, boxes, scores, iou_threshold, offset);
...
...
@@ -18,8 +29,42 @@ std::vector<std::vector<int> > nms_match_impl(Tensor dets,
  return DISPATCH_DEVICE_IMPL(nms_match_impl, dets, iou_threshold);
}
#ifdef MMCV_WITH_DIOPI
Tensor nms_diopi(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
  auto boxes_p = toDiopiTensorHandle(boxes);
  diopiDevice_t device;
  diopiGetTensorDevice(boxes_p, &device);
  if (device == diopi_host) {
    return nms_impl(boxes, scores, iou_threshold, offset);
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  Tensor out;
  auto outp = toDiopiTensorHandle(out);
  diopiTensorHandle_t *outhandle = &outp;
  auto scores_p = toDiopiTensorHandle(scores);
  bool is_mock_cuda = boxes.device().type() == dipu::DIPU_DEVICE_TYPE;
  if (is_mock_cuda && reinterpret_cast<void *>(diopiNmsMmcv) != nullptr) {
    auto ret =
        diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
    if (ret == diopiSuccess) {
      auto tensorhandle = reinterpret_cast<Tensor *>(*outhandle);
      return *tensorhandle;
    }
  }
  LOG(WARNING) << "Fallback to cpu: mmcv ext op nms";
  auto boxes_cpu = boxes.cpu();
  auto scores_cpu = scores.cpu();
  return nms_impl(boxes_cpu, scores_cpu, iou_threshold, offset);
}
#endif
Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
#ifdef MMCV_WITH_DIOPI
  return nms_diopi(boxes, scores, iou_threshold, offset);
#else
  return nms_impl(boxes, scores, iou_threshold, offset);
#endif
}
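For reference, a minimal sketch of calling this entry point from C++ (illustrative only; the tensors below are made up): `boxes` is an N x 4 float tensor of `(x1, y1, x2, y2)` corners, `scores` is an N-element float tensor, and the result holds the indices of the kept boxes.

at::Tensor boxes = torch::tensor({{0.f, 0.f, 10.f, 10.f},
                                  {1.f, 1.f, 11.f, 11.f},
                                  {50.f, 50.f, 60.f, 60.f}});
at::Tensor scores = torch::tensor({0.9f, 0.8f, 0.7f});
at::Tensor keep = nms(boxes, scores, /*iou_threshold=*/0.5f, /*offset=*/0);
// keep is expected to contain {0, 2}: box 1 overlaps box 0 above the threshold.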
Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,
...
...
mmcv/ops/csrc/pytorch/nms_rotated.cpp
View file @
91da9643
...
...
@@ -17,6 +17,11 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
                       const Tensor labels, const float iou_threshold);
#endif

#ifdef MMCV_WITH_MLU
Tensor nms_rotated_mlu(const Tensor dets, const Tensor scores,
                       const float iou_threshold);
#endif
// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
...
...
@@ -31,11 +36,17 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
#else
    AT_ERROR("Not compiled with GPU support");
#endif
#ifdef MMCV_WITH_XLA
  } else if (dets.device().type() == at::kXLA) {
#ifdef MMCV_WITH_NPU
    return nms_rotated_npu(dets, scores, labels, iou_threshold);
#else
    AT_ERROR("Not compiled with NPU support");
#endif
#ifdef MMCV_WITH_KPRIVATE
  } else if (dets.device().type() == at::kPrivateUse1) {
    return nms_rotated_npu(dets, scores, labels, iou_threshold);
#endif
#ifdef MMCV_WITH_MLU
  } else if (dets.device().type() == at::kMLU) {
    return nms_rotated_mlu(dets, scores, iou_threshold);
#endif
  }
...
mmcv/ops/csrc/pytorch/npu/active_rotated_filter_npu.cpp
0 → 100644
View file @
91da9643
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output);

void active_rotated_filter_backward_impl(const Tensor grad_out,
                                         const Tensor indices, Tensor grad_in);

void active_rotated_filter_forward_npu(const Tensor input,
                                       const Tensor indices, Tensor output) {
  OpCommand cmd;
  cmd.Name("ActiveRotatedFilter")
      .Input(input)
      .Input(indices)
      .Output(output)
      .Run();
}

void active_rotated_filter_backward_npu(const Tensor grad_out,
                                        const Tensor indices, Tensor grad_in) {
  OpCommand cmd;
  cmd.Name("ActiveRotatedFilterGrad")
      .Input(grad_out)
      .Input(indices)
      .Output(grad_in)
      .Run();
}

REGISTER_NPU_IMPL(active_rotated_filter_forward_impl,
                  active_rotated_filter_forward_npu);
REGISTER_NPU_IMPL(active_rotated_filter_backward_impl,
                  active_rotated_filter_backward_npu);
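The `*_impl` declarations above are the device-dispatch hooks: the generic op file declares them and routes through `DISPATCH_DEVICE_IMPL`, while each backend binds its own function to them, as this file does with `REGISTER_NPU_IMPL`. A simplified sketch of how the two sides fit together (illustrative, not a verbatim copy of the generic source file):

// Generic side (e.g. mmcv/ops/csrc/pytorch/active_rotated_filter.cpp):
void active_rotated_filter_forward_impl(const Tensor input,
                                        const Tensor indices, Tensor output) {
  // Picks the registered implementation that matches input's device type.
  DISPATCH_DEVICE_IMPL(active_rotated_filter_forward_impl, input, indices,
                       output);
}

// Backend side (this file): REGISTER_NPU_IMPL binds the NPU function to the
// same key, so NPU tensors reach active_rotated_filter_forward_npu.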
mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp
View file @
91da9643
...
...
@@ -12,23 +12,40 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
   if (mode == 1) {
     modeStr = "iof";
   }
-  float offset_ = 1;
-  if (offset == 0) {
-    offset_ = 0.01;
-  }
+  bool swap_flag = false;
+  at::Tensor bboxesFP32 = bboxes2;
+  at::Tensor gtboxesFP32 = bboxes1;
+  if (bboxes2.size(0) < bboxes1.size(0)) {
+    swap_flag = true;
+    bboxesFP32 = bboxes1;
+    gtboxesFP32 = bboxes2;
+  }
-  at::Tensor bboxes = at::ones_like(bboxes2);
-  at::Tensor gtboxes = at::ones_like(bboxes1);
-  bboxes = aligned ? bboxes2.transpose(0, 1) : bboxes2;
-  gtboxes = aligned ? bboxes1.transpose(0, 1) : bboxes1;
+  if (bboxes2.scalar_type() != at::kFloat) {
+    bboxesFP32 = bboxesFP32.to(at::kFloat);
+    gtboxesFP32 = gtboxesFP32.to(at::kFloat);
+  }
+  c10::SmallVector<int64_t, SIZE> iousSize = {gtboxesFP32.size(0),
+                                              bboxesFP32.size(0)};
+  if (aligned) {
+    iousSize = {gtboxesFP32.size(0), 1};
+  }
+  at::Tensor iousFP32 = at::empty(iousSize, bboxesFP32.options());
+  bboxesFP32 = aligned ? bboxesFP32.transpose(0, 1) : bboxesFP32;
+  gtboxesFP32 = aligned ? gtboxesFP32.transpose(0, 1) : gtboxesFP32;
   OpCommand cmd;
   cmd.Name("Iou")
-      .Input(bboxes)
-      .Input(gtboxes)
-      .Output(ious)
+      .Input(bboxesFP32)
+      .Input(gtboxesFP32)
+      .Output(iousFP32)
       .Attr("mode", modeStr)
-      .Attr("eps", offset_)
+      .Attr("eps", (float)offset)
       .Attr("aligned", aligned)
       .Run();
+  if (bboxes2.scalar_type() != at::kFloat) {
+    iousFP32 = iousFP32.to(at::kHalf);
+  }
+  iousFP32 = swap_flag ? iousFP32.transpose(0, 1) : iousFP32;
+  ious.copy_(iousFP32);
 }

 REGISTER_NPU_IMPL(bbox_overlaps_impl, bbox_overlaps_npu);
mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp
0 → 100644
View file @
91da9643
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2,
                          Tensor ious, const int mode_flag,
                          const bool aligned);

void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                         const int mode_flag, const bool aligned) {
  at::Tensor boxes = at::ones_like(boxes1);
  at::Tensor query_boxes = at::ones_like(boxes2);
  boxes = boxes1.transpose(0, 1).unsqueeze(0);
  query_boxes = boxes2.transpose(0, 1).unsqueeze(0);

  bool is_trans = false;
  string modeStr = "iou";
  if (mode_flag == 1) {
    modeStr = "iof";
  }
  bool is_cross = true;
  if (aligned) {
    is_cross = false;
  }
  float v_threshold = 0;
  float e_threshold = 0;

  OpCommand cmd;
  cmd.Name("RotatedIou")
      .Input(boxes)
      .Input(query_boxes)
      .Output(ious)
      .Attr("trans", is_trans)
      .Attr("mode", modeStr)
      .Attr("is_cross", is_cross)
      .Attr("v_threshold", v_threshold)
      .Attr("e_threshold", e_threshold)
      .Run();

  if (is_cross) {
    ious = ious.view({boxes1.size(0), boxes2.size(0)});
  } else {
    ious = ious.view({boxes1.size(0), 1});
  }
}

REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu);
mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
View file @
91da9643
...
...
@@ -12,15 +12,13 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   } else {
-    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+    target_y = at::one_hot(target, n_class);
   }
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(
-        weight, input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
...
...
@@ -46,18 +44,16 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
   if (n_class == 1) {
     target_y = at::reshape(target, input.sizes());
   } else {
-    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+    target_y = at::one_hot(target, n_class);
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   }
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(
-        weight, input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
...
...
@@ -80,15 +76,12 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
   int64_t n_class = input.size(1);
-  at::Tensor target_y =
-      at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  at::Tensor target_y = at::one_hot(target, n_class);
+  target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(
-        weight, input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   at::Tensor op_output = at::ones_like(input);
   OpCommand cmd;
...
...
@@ -107,8 +100,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
   c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
   at::IntArrayRef offset = at::IntArrayRef(offsets);
   at::IntArrayRef size = at::IntArrayRef(sizes);
-  at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size,
-                                                    output);
+  at_npu::native::custom_ops::npu_slice_out(op_output, offset, size, output);
 }

 void softmax_focal_loss_forward_impl(Tensor input, Tensor target,
                                      Tensor weight,
...
...
@@ -119,16 +111,13 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
   int64_t n_class = input.size(1);
-  at::Tensor target_y =
-      at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  at::Tensor target_y = at::one_hot(target, n_class);
+  target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(
-        weight, input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
...
...
mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
View file @
91da9643
...
...
@@ -25,8 +25,9 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
     }
   }
   at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
-  at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
-      bias_tmp, input.sizes());
+  // at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
+  //     bias_tmp, input.sizes());
+  at::Tensor bias_ = at::broadcast_to(bias_tmp, input.sizes());
   OpCommand cmd;
   cmd.Name("FusedBiasLeakyRelu")
       .Input(input)
...
mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
View file @
91da9643
...
...
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
      .Attr("batch_dims", batch_dims)
      .Run();
}

void gather_points_backward_npu(int b, int c, int n, int npoints,
                                const Tensor grad_out, const Tensor idx,
                                Tensor grad_points) {
  at::Tensor indices = idx;
  if (idx.scalar_type() != at::ScalarType::Int) {
    indices = idx.to(at::kInt);
  }
  if (idx.dim() == 0) {
    indices.unsqueeze_(0);
  }
  int64_t dim = 0;
  at::SmallVector<int64_t, N> pad_size = array_to_small_vector(idx.sizes());
  at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
  at::Tensor grad_points_view = trans_grad_points.view(
      {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
       trans_grad_points.sizes()[2]});
  at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
  trans_grad_out = trans_grad_out.view(
      {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
       trans_grad_out.sizes()[2]});
  auto index = at::arange(0, b);
  index = index.to(grad_out.device());
  index = at::mul(index, n);
  index = index.view({b, 1});
  index = at::broadcast_to(index, pad_size);
  indices = at::add(index, indices);
  indices = indices.view({-1});
  OpCommand cmd;
  cmd.Name("InplaceIndexAdd")
      .Input(grad_points_view)
      .Input(indices)
      .Input(trans_grad_out)
      .Output(grad_points_view)
      .Attr("axis", dim)
      .Run();
  at::Tensor grad_points_result =
      grad_points_view.view(trans_grad_points.sizes());
  grad_points_result = grad_points_result.transpose(1, 2);
  grad_points.copy_(grad_points_result);
}

void gather_points_forward_impl(int b, int c, int n, int npoints,
                                const Tensor points, const Tensor idx,
                                Tensor out);

void gather_points_backward_impl(int b, int c, int n, int npoints,
                                 const Tensor grad_out, const Tensor idx,
                                 Tensor grad_points);

REGISTER_NPU_IMPL(gather_points_forward_impl, gather_points_forward_npu);
REGISTER_NPU_IMPL(gather_points_backward_impl, gather_points_backward_npu);
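The backward pass above emulates a batched index_add by flattening the batch into the row dimension: each sample contributes an offset of `batch_index * n`, so indices from different samples land in disjoint row ranges of the flattened `(b * n, c)` gradient buffer. A tiny worked example of just that index arithmetic (values chosen purely for illustration):

// Suppose b = 2, n = 4 and idx = [[0, 2], [1, 3]] (npoints = 2 per sample).
// index   = arange(0, b) * n        -> [0, 4], reshaped to [[0], [4]]
// index   broadcast to idx's shape  -> [[0, 0], [4, 4]]
// indices = index + idx             -> [[0, 2], [5, 7]], flattened to [0, 2, 5, 7]
// Rows 0..3 of the flattened (b * n, c) buffer belong to sample 0 and rows
// 4..7 to sample 1, so InplaceIndexAdd never mixes gradients across samples.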
mmcv/ops/csrc/pytorch/npu/group_points_npu.cpp
0 → 100644
View file @
91da9643
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

void group_points_forward_npu(int b, int c, int n, int npoints, int nsample,
                              const Tensor points, const Tensor idx,
                              Tensor out) {
  // b, c, n, and npoints do not need to be passed into gatherv2,
  // b, c, n, and npoints are calculated inside the operator
  // gatherv2 operator in ascend needs to set axis to 0, batch_dims is 0
  c10::SmallVector<int64_t, N> axis = {0};
  int64_t batch_dims = 0;

  auto index = at::arange(0, b);
  index = index.to(points.device());
  index = index.view({-1, 1, 1});
  index = at::mul(index, n);
  at::Tensor indices = at::add(index, idx);
  indices = indices.view({-1});

  at::Tensor trans_features = points.transpose(1, 2);
  at::Tensor features = NpuUtils::format_contiguous(trans_features);
  features = features.view({b * n, c});

  OpCommand cmd;
  cmd.Name("GatherV2")
      .Input(features)
      .Input(indices)
      .Input(axis)
      .Output(out)
      .Attr("batch_dims", batch_dims)
      .Run();

  at::Tensor output =
      out.view({b, npoints, nsample, c}).transpose(1, 3).transpose(2, 3);
  at::Tensor res = NpuUtils::format_contiguous(output);
  out.copy_(res);
}

void group_points_forward_impl(int b, int c, int n, int npoints, int nsample,
                               const Tensor points, const Tensor idx,
                               Tensor out);

REGISTER_NPU_IMPL(group_points_forward_impl, group_points_forward_npu);
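The final reshaping deserves a note, since the gathered rows come back flattened. The shape flow sketched below follows directly from the view/transpose calls above and is shown only as a reading aid:

// GatherV2 returns the gathered rows of the (b * n, c) feature table, i.e.
// b * npoints * nsample rows of c channels each.
// out.view({b, npoints, nsample, c})   -> (b, npoints, nsample, c)
//    .transpose(1, 3)                  -> (b, c, nsample, npoints)
//    .transpose(2, 3)                  -> (b, c, npoints, nsample)
// which is the channels-first layout the grouped features are stored in.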
mmcv/ops/csrc/pytorch/npu/nms_npu.cpp
View file @
91da9643
...
...
@@ -4,24 +4,19 @@ using namespace NPU_NAME_SPACE;
using namespace std;

Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
-  at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes);
-  at::Tensor ones_tensor =
-      at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1);
-  at::add_out(boxed_offest, boxes, ones_tensor, offset);
-  at::Tensor iou_threshold_y =
-      at_npu::native::OpPreparation::ApplyTensor(
-          {}, boxes.options().dtype(at::kFloat), boxes)
-          .fill_(iou_threshold);
+  TORCH_CHECK((boxes.scalar_type() == at::ScalarType::Float),
+              "The type of boxes tensor passed in nms_npu should be float");
+  int64_t offset_64 = offset;
+  at::Tensor iou_threshold_y =
+      at::empty({}, boxes.options().dtype(at::kFloat)).fill_(iou_threshold);
-  at::Tensor scores_threshold_y =
-      at_npu::native::OpPreparation::ApplyTensor(
-          {}, boxes.options().dtype(at::kFloat), boxes)
-          .fill_(0);
-  at::Tensor max_outputsize_y =
-      at_npu::native::OpPreparation::ApplyTensor(
-          {}, boxes.options().dtype(at::kInt), boxes)
-          .fill_(boxes.size(0));
+  at::Tensor scores_threshold_y =
+      at::empty({}, boxes.options().dtype(at::kFloat)).fill_(0);
+  at::Tensor max_outputsize_y =
+      at::empty({}, boxes.options().dtype(at::kInt)).fill_(boxes.size(0));
   c10::SmallVector<int64_t, SIZE> outputsize = {boxes.size(0)};
-  at::Tensor output =
-      at_npu::native::OpPreparation::ApplyTensor(
-          outputsize, boxes.options().dtype(at::kInt), boxes)
-          .fill_(-1);
+  at::Tensor output =
+      at::empty(outputsize, boxes.options().dtype(at::kInt)).fill_(-1);
   OpCommand cmd;
   cmd.Name("NonMaxSuppressionV3")
       .Input(boxes)
...
...
@@ -29,14 +24,14 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
       .Input(max_outputsize_y)
       .Input(iou_threshold_y)
       .Input(scores_threshold_y)
       .Attr("offset", offset_64)
       .Output(output)
       .Run();
   auto outputsizeBool = at::gt(output, -1);
-  auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int);
-  auto countLen = at::sum(outputsizeInt, at::ScalarType::Int);
+  auto outputsizeInt = outputsizeBool.to(at::kInt);
+  auto countLen = at::sum(outputsizeInt, at::kInt);
   at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong());
-  actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
-      actual_output, at::kLong);
+  actual_output = actual_output.to(at::kLong);
   return actual_output;
 }
...
...
mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp
View file @
91da9643
...
...
@@ -7,14 +7,15 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
   auto originDtype = dets.scalar_type();
   at::Tensor detsCast = dets;
   at::Tensor scoresCast = scores;
-  if (originDtype != at::ScalarType::Float) {
-    detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat);
-    scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat);
+  if (originDtype != at::kFloat) {
+    detsCast = detsCast.to(at::kFloat);
+    scoresCast = scoresCast.to(at::kFloat);
   }
   c10::SmallVector<int64_t, SIZE> selectedIndexSize = {dets.size(0)};
-  at::Tensor selectedBox = OpPreparation::ApplyTensor(dets);
-  at::Tensor selectedIndex = OpPreparation::ApplyTensor(
-      selectedIndexSize, dets.options().dtype(at::kInt), dets);
+  at::Tensor selectedBox = at::empty_like(dets);
+  at::Tensor selectedIndex =
+      at::empty(selectedIndexSize, dets.options().dtype(at::kInt));
   c10::SmallVector<int64_t, N> output_sync_idx = {0, 1};
   OpCommand cmd;
...
...
@@ -27,6 +28,6 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
       .Output(selectedIndex)
       .Attr("iou_threshold", (float)iou_threshold)
       .Run();
-  selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong);
+  selectedIndex = selectedIndex.to(at::kLong);
   return selectedIndex;
 }
mmcv/ops/csrc/pytorch/npu/points_in_polygons_npu.cpp
0 → 100644
View file @
91da9643
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

constexpr int32_t MAX_POLYGONS_BATCH = 2800;

void points_in_polygons_npu(const Tensor points, Tensor polygons,
                            Tensor output, const int rows, const int cols) {
  TORCH_CHECK(
      (polygons.sizes()[0] <= MAX_POLYGONS_BATCH),
      "The batch of polygons tensor must be less than MAX_POLYGONS_BATCH");
  at::Tensor trans_polygons = polygons.transpose(0, 1);
  OpCommand cmd;
  at::Tensor new_trans_polygons = NpuUtils::format_contiguous(trans_polygons);
  cmd.Name("PointsInPolygons")
      .Input(points, (string)"points")
      .Input(new_trans_polygons, (string)"polygons")
      .Output(output)
      .Run();
}

void points_in_polygons_forward_impl(const Tensor points, Tensor polygons,
                                     Tensor output, const int rows,
                                     const int cols);

REGISTER_NPU_IMPL(points_in_polygons_forward_impl, points_in_polygons_npu);