Unverified Commit 8b8bf5e1 authored by Chris Jiang, committed by GitHub

[Refactor] Replace roipoint_pool3d op of MLU backend with mlu-ops (#2875)

parent 099ee24d
@@ -9,32 +9,7 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
-
-void KernelRoiPointPool3dForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                                 cnrtQueue_t queue, const cnrtDataType_t d_type,
-                                 const int batch_size, const int pts_num,
-                                 const int boxes_num, const int feature_in_len,
-                                 const int sampled_pts_num, const void *xyz,
-                                 const void *boxes3d, const void *pts_feature,
-                                 void *pooled_features, int *pooled_empty_flag);
-
-void KernelRoiPointPool3dLargeBoxesNumForward(
-    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-    const cnrtDataType_t d_type, const int batch_size, const int pts_num,
-    const int boxes_num, const int feature_in_len, const int sampled_pts_num,
-    const void *xyz, const void *boxes3d, const void *pts_feature,
-    void *pooled_features, int *pooled_empty_flag);
-
-// policy function
-static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
-  // start U1 task, occupy all available clusters
-  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  k_dim->z = 1;
-  *k_type = CNRT_FUNC_TYPE_UNION1;
-}
+#include "mlu_common_helper.h"

 void RoIPointPool3dForwardMLUKernelLauncher(
     int batch_size, int pts_num, int boxes_num, int feature_in_len,
@@ -98,50 +73,55 @@ void RoIPointPool3dForwardMLUKernelLauncher(
               "pts_feature element num should be less than 2^31, got ",
               pts_feature.numel(), ".");
-  // calculate task dimension
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  policyFuncForward(&k_dim, &k_type);
-
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-
-  // get ptr of tensors
-  // transpose points [B, N ,3] -> [3, B, N]
-  auto xyz_ = xyz.permute({2, 0, 1}).contiguous();
-  auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_);
-  auto xyz_ptr = xyz_impl->cnnlMalloc();
-  // transpose point_features [B, N, C] -> [B, C, N]
-  auto pts_feature_ = pts_feature.permute({0, 2, 1}).contiguous();
-  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
-  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
-  auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d);
-  auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
-  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
-  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
-  auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag);
-  auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();
-
-  // get compute dtype of input
-  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(xyz_.dtype());
-
-  // launch kernel
-  if (boxes_num <= 10240) {
-    CNLOG(INFO) << "Launch Kernel MLUKernelRoiPointPool3dForward<<<" << k_dim.x
-                << ", " << k_dim.y << ", " << k_dim.z << ">>>";
-    KernelRoiPointPool3dForward(
-        k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num,
-        feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr,
-        pooled_features_ptr, (int *)pooled_empty_flag_ptr);
-  } else {
-    CNLOG(INFO)
-        << "Launch Kernel MLUKernelRoiPointPool3dLargeBoxesNumForward<<<"
-        << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
-    KernelRoiPointPool3dLargeBoxesNumForward(
-        k_dim, k_type, queue, data_type, batch_size, pts_num, boxes_num,
-        feature_in_len, sampled_pts_num, xyz_ptr, boxes3d_ptr, pts_feature_ptr,
-        pooled_features_ptr, (int *)pooled_empty_flag_ptr);
-  }
+  // set contiguous
+  auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      xyz, xyz.suggest_memory_format());
+  auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      pts_feature, pts_feature.suggest_memory_format());
+  auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      boxes3d, boxes3d.suggest_memory_format());
+  auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      pooled_features, pooled_features.suggest_memory_format());
+  auto pooled_empty_flag_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      pooled_empty_flag, pooled_empty_flag.suggest_memory_format());
+
+  // get ptr of tensors
+  auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
+  auto xyz_ptr = xyz_impl->cnnlMalloc();
+  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
+  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
+  auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous);
+  auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
+  auto pooled_features_impl =
+      torch_mlu::getMluTensorImpl(pooled_features_contiguous);
+  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
+  auto pooled_empty_flag_impl =
+      torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
+  auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();
+
+  // create tensor descriptors
+  MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc,
+      pooled_features_desc, pooled_empty_flag_desc;
+  xyz_desc.set(xyz_contiguous);
+  pts_feature_desc.set(pts_feature_contiguous);
+  boxes3d_desc.set(boxes3d_contiguous);
+  pooled_features_desc.set(pooled_features_contiguous);
+  pooled_empty_flag_desc.set(pooled_empty_flag_contiguous);
+
+  // get workspace
+  size_t workspace_size = 0;
+  auto handle = mluOpGetCurrentHandle();
+  TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(
+      handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
+      xyz_desc.desc(), pts_feature_desc.desc(), boxes3d_desc.desc(),
+      pooled_features_desc.desc(), pooled_empty_flag_desc.desc(),
+      &workspace_size));
+
+  auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte));
+  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
+  auto workspace_ptr = workspace_impl->cnnlMalloc();
+
+  TORCH_MLUOP_CHECK(mluOpRoiPointPool3d(
+      handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
+      xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr,
+      boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size,
+      pooled_features_desc.desc(), pooled_features_ptr,
+      pooled_empty_flag_desc.desc(), (int *)pooled_empty_flag_ptr));
 }

 void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num,
......
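For context on how the refactored launcher is exercised, below is a minimal Python-side usage sketch of the op through mmcv's RoIPointPool3d wrapper. This sketch is not part of the commit: it assumes an mmcv build with MLU support and a torch_mlu installation that registers the `mlu` device; the tensor shapes follow the op's documented layout.

# Hedged usage sketch (not part of this commit): calls the op whose MLU backend
# the diff above moves onto mlu-ops. Assumes mmcv is built with MLU support and
# that importing torch_mlu registers the `mlu` device.
import torch
import torch_mlu  # assumption: registers the `mlu` device type
from mmcv.ops import RoIPointPool3d

B, N, C, M = 2, 1024, 16, 8                      # batch, points, channels, boxes
points = torch.rand(B, N, 3).to('mlu')           # xyz coordinates
point_features = torch.rand(B, N, C).to('mlu')   # per-point features
boxes3d = torch.rand(B, M, 7).to('mlu')          # (x, y, z, dx, dy, dz, heading)

pool = RoIPointPool3d(num_sampled_points=512)
# pooled_features: (B, M, 512, 3 + C); pooled_empty_flag: (B, M)
pooled_features, pooled_empty_flag = pool(points, point_features, boxes3d)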