Unverified Commit 8b8bf5e1 authored by Chris Jiang's avatar Chris Jiang Committed by GitHub
Browse files

[Refactor] Replace roipoint_pool3d op of MLU backend with mlu-ops (#2875)

parent 099ee24d
......@@ -9,32 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const int batch_size, const int pts_num,
const int boxes_num, const int feature_in_len,
const int sampled_pts_num, const void *xyz,
const void *boxes3d, const void *pts_feature,
void *pooled_features, int *pooled_empty_flag);
void KernelRoiPointPool3dLargeBoxesNumForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const int batch_size, const int pts_num,
const int boxes_num, const int feature_in_len, const int sampled_pts_num,
const void *xyz, const void *boxes3d, const void *pts_feature,
void *pooled_features, int *pooled_empty_flag);
// policy function
static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
// start U1 task, occupy all available clusters
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_UNION1;
}
#include "mlu_common_helper.h"
void RoIPointPool3dForwardMLUKernelLauncher(
int batch_size, int pts_num, int boxes_num, int feature_in_len,
......@@ -98,50 +73,55 @@ void RoIPointPool3dForwardMLUKernelLauncher(
"pts_feature element num should be less than 2^31, got ",
pts_feature.numel(), ".");
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncForward(&k_dim, &k_type);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// set contiguous
auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
xyz, xyz.suggest_memory_format());
auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pts_feature, pts_feature.suggest_memory_format());
auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
boxes3d, boxes3d.suggest_memory_format());
auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pooled_features, pooled_features.suggest_memory_format());
auto pooled_empty_flag_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
pooled_empty_flag, pooled_empty_flag.suggest_memory_format());
// get ptr of tensors
// transpose points [B, N ,3] -> [3, B, N]
auto xyz_ = xyz.permute({2, 0, 1}).contiguous();
auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_);
auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
auto xyz_ptr = xyz_impl->cnnlMalloc();
// transpose point_features [B, N, C] -> [B, C, N]
auto pts_feature_ = pts_feature.permute({0, 2, 1}).contiguous();
auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d);
auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous);
auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features_contiguous);
auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag);
auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();
// get compute dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(xyz_.dtype());
// launch kernel
if (boxes_num <= 10240) {
CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dForward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelRoiPointPool3dForward(
k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num,
feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr,
pooled_features_ptr, (int *)pooled_empty_flag_ptr);
} else {
CNLOG(INFO)
<< "Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelRoiPointPool3dLargeBoxesNumForward(
k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num,
feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr,
pooled_features_ptr, (int *)pooled_empty_flag_ptr);
}
// create tensor descriptors
MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc, pooled_features_desc, pooled_empty_flag_desc;
xyz_desc.set(xyz_contiguous);
pts_feature_desc.set(pts_feature_contiguous);
boxes3d_desc.set(boxes3d_contiguous);
pooled_features_desc.set(pooled_features_contiguous);
pooled_empty_flag_desc.set(pooled_empty_flag_contiguous);
// get workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(handle, batch_size,
pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz_desc.desc(),
pts_feature_desc.desc(), boxes3d_desc.desc(), pooled_features_desc.desc(),
pooled_empty_flag_desc.desc(), &workspace_size));
auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte));
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
auto workspace_ptr = workspace_impl->cnnlMalloc();
TORCH_MLUOP_CHECK(mluOpRoiPointPool3d(
handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr,
boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size,
pooled_features_desc.desc(), pooled_features_ptr, pooled_empty_flag_desc.desc(),
(int *)pooled_empty_flag_ptr));
}
void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment