Unverified Commit 99cb8535 authored by qirun-uiuc, committed by GitHub

[Refactor] Replace focal_loss_sigmoid op of MLU backend with mlu-ops (#2855)

parent ee93530a
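Reviewer note: this commit removes the hand-written BANG kernel path (the policyFunc* task-dimension heuristics and the NRAM partitioning below) and dispatches to the prebuilt mlu-ops library instead. For orientation while reading the diff: element-wise, the op computes the sigmoid focal loss of Lin et al., FL = -alpha_t * (1 - p_t)^gamma * log(p_t) with p = sigmoid(x). The following CPU reference is a reviewer's sketch, not part of the commit; the helper name sigmoid_focal_loss_forward_ref, the flat row-major layout, and the per-class weight indexed by the sample's label are illustrative assumptions that mirror the math used by mmcv's other backends.

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Illustrative reference only: N x C logits (row-major), integer labels;
// target[n] == c marks the positive slot for sample n.
std::vector<float> sigmoid_focal_loss_forward_ref(
    const std::vector<float> &input, const std::vector<int> &target,
    const std::vector<float> &weight,  // per-class weights; empty if unused
    int N, int C, float alpha, float gamma) {
  std::vector<float> output(static_cast<size_t>(N) * C, 0.f);
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      // p = sigmoid(x); clamp log arguments with FLT_MIN to avoid -inf.
      const float p = 1.f / (1.f + std::exp(-input[n * C + c]));
      const float term_p =
          std::pow(1.f - p, gamma) * std::log(std::max(p, FLT_MIN));
      const float term_n =
          std::pow(p, gamma) * std::log(std::max(1.f - p, FLT_MIN));
      // alpha balances the single positive slot against the negative slots.
      float loss =
          (target[n] == c) ? -alpha * term_p : -(1.f - alpha) * term_n;
      if (!weight.empty()) loss *= weight[target[n]];
      output[n * C + c] = loss;
    }
  }
  return output;
}

With gamma = 2 and alpha = 0.25 (the common detection defaults), a confidently wrong logit of 4.0 on a negative slot costs roughly 0.75 * 0.982^2 * 4.02 ≈ 2.9, while a confidently correct one costs almost nothing; this down-weighting of easy examples is what the op implements.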
@@ -12,87 +12,11 @@
 #include <string>
 #include <vector>
 #include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
+#include "mlu_common_helper.h"
-void KernelFocalLossSigmoidForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                                   cnrtQueue_t queue,
-                                   const cnrtDataType_t d_type,
-                                   const void *input, const void *target,
-                                   const void *weight, const int32_t N,
-                                   const int32_t C, const float alpha,
-                                   const float gamma, void *output);
-void KernelFocalLossSigmoidBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                                    cnrtQueue_t queue,
-                                    const cnrtDataType_t d_type,
-                                    const void *input, const void *target,
-                                    const void *weight, const float gamma,
-                                    const float alpha, const int32_t dim_n,
-                                    const int32_t deal_n, const int32_t dim_c,
-                                    void *output);
-// Policy Function for Forward
-static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
-                              const Tensor &input, const Tensor &target,
-                              const Tensor &weight) {
-  auto N = input.size(0);
-  auto C = input.size(1);
-  const size_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const size_t c_align_size = PAD_UP((C * input.itemsize()), NFU_ALIGN_SIZE);
-  const int split_target_num = 2;
-  const int split_pipeline_num = 6;
-  const int has_weight = weight.data_ptr() != nullptr;
-  const int target_data_width = target.scalar_type() == at::kLong
-                                    ? target.itemsize() / 2
-                                    : target.itemsize();
-  const int threshold_c =
-      PAD_DOWN((nram_size - split_target_num * sizeof(int)) /
-                   (split_pipeline_num + has_weight),
-               NFU_ALIGN_SIZE) /
-      input.itemsize();
-  int n_seg = 1;
-  if (C <= threshold_c) {
-    int c_size = C * input.itemsize();
-    int reserved_align_size =
-        (split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE;
-    int weight_size = 0;
-    if (has_weight) {
-      c_size = c_align_size;
-      reserved_align_size = split_target_num * NFU_ALIGN_SIZE;
-      weight_size = c_align_size;
-    }
-    // n_seg * c_size * split_pipeline_num + n_seg * target.itemsize() *
-    //     split_target_num + weight_size + reserved_align_size <= nram_size
-    n_seg = (nram_size - weight_size - reserved_align_size) /
-            (split_pipeline_num * c_size + split_target_num * sizeof(int32_t));
-  }
-  auto seg_num = n_seg == 0 ? N : (N + n_seg - 1) / n_seg;
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  auto core_num = core_dim * cluster_num;
-  k_dim->x = *k_type;
-  k_dim->y =
-      seg_num > core_num ? cluster_num : (seg_num + core_dim - 1) / core_dim;
-  k_dim->z = 1;
-}
-// Policy Function for Backward
-static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
-  // set Union1 Job
-  *k_type = CNRT_FUNC_TYPE_UNION1;
-  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  k_dim->z = 1;
-}
-void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target,
-                                              Tensor weight, Tensor output,
-                                              const float gamma,
-                                              const float alpha) {
+void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
+                                    Tensor weight, Tensor output,
+                                    const float gamma, const float alpha) {
   // params check
   TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
               "But now gamma is ", gamma, ".");
@@ -123,103 +47,50 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target,
     return;
   }
-  // calculate task dimension
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
-  policyFuncForward(&k_dim, &k_type, input, target, weight);
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  // contiguous
+  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      input, input.suggest_memory_format());
+  // target only supports int32 on MLU
+  auto target_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      target.toType(at::kInt), target.suggest_memory_format());
+  auto weight_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      weight, weight.suggest_memory_format());
+  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      output, output.suggest_memory_format());
+  // set tensor descriptors
+  MluOpTensorDescriptor input_desc, target_desc, weight_desc, output_desc;
+  input_desc.set(input_contiguous);
+  target_desc.set(target_contiguous);
+  weight_desc.set(weight_contiguous);
+  output_desc.set(output_contiguous);
   // get ptr of tensors
-  auto input_impl = torch_mlu::getMluTensorImpl(input);
+  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
   auto input_ptr = input_impl->cnnlMalloc();
-  auto target_impl = torch_mlu::getMluTensorImpl(target);
+  auto target_impl = torch_mlu::getMluTensorImpl(target_contiguous);
   auto target_ptr = target_impl->cnnlMalloc();
-  auto weight_impl = torch_mlu::getMluTensorImpl(weight);
+  auto weight_impl = torch_mlu::getMluTensorImpl(weight_contiguous);
   auto weight_ptr = weight_impl->cnnlMalloc();
-  auto output_impl = torch_mlu::getMluTensorImpl(output);
+  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
   auto output_ptr = output_impl->cnnlMalloc();
-  // get dtype of input
-  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
-  CNLOG(INFO) << "Launch Kernel KernelFocalLossSigmoidForward<<<Union"
-              << k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y
-              << ", " << k_dim.z << ">>>";
-  // launch kernel
-  KernelFocalLossSigmoidForward(k_dim, k_type, queue, d_type, input_ptr,
-                                target_ptr, weight_ptr, input.size(0),
-                                input.size(1), alpha, gamma, output_ptr);
-}
-void getDealNAndThresholdC(const int compute_data_bytes,
-                           const int target_data_bytes, const int total_c,
-                           int *deal_n_ptr, int *threshold_c_ptr,
-                           const bool has_weight, const bool is_half) {
-  /* NRAM partition:
-   *
-   * |-----------------ping pong--------------------|
-   * |input | pt | alpha_t | temp | output | target | flt_min | gamma | weight|
-   *
-   * split_pipeline_num is 5: including input, pt, alpha_t, temp, output.
-   */
-  const int nram_split_num = 5;
-  const int nram_split_pingpong = 2;
-  const int max_nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  int32_t compute_align_size = NFU_ALIGN_SIZE;
-  if (is_half) {
-    compute_align_size += NFU_ALIGN_SIZE;
-  }
-  const int32_t compute_align_num = compute_align_size / compute_data_bytes;
-  // reserved_align_size: including input(ping pong), pt(ping pong),
-  //                      alpha_t(ping pong), temp(ping pong),
-  //                      output(ping pong), target(ping pong),
-  //                      flt_min and gamma.
-  const int reserved_align_size =
-      ((nram_split_num + 1) * nram_split_pingpong + 2) * compute_align_size;
-  int nram_pingpong_size = max_nram_size - reserved_align_size;
-  int compute_c = total_c;
-  int threshold_c = 0;
-  if (has_weight) {
-    // reserved space for weight to align
-    nram_pingpong_size -= NFU_ALIGN_SIZE;
-    // threshold_c * nram_split_pingpong * compute_data_bytes * nram_split_num +
-    //     nram_split_pingpong * target_data_bytes +
-    //     threshold_c * compute_data_bytes <= nram_pingpong_size
-    threshold_c =
-        (nram_pingpong_size - nram_split_pingpong * target_data_bytes) /
-        (compute_data_bytes * (nram_split_num * nram_split_pingpong + 1));
-    threshold_c = PAD_DOWN(threshold_c, compute_align_num);
-    int weight_space = PAD_UP(total_c * compute_data_bytes, NFU_ALIGN_SIZE);
-    // reserved space for weight
-    nram_pingpong_size -= weight_space;
-    compute_c = PAD_UP(total_c, compute_align_num);
-  } else {
-    // threshold_c * nram_split_pingpong * compute_data_bytes * nram_split_num +
-    //     nram_split_pingpong * target_data_bytes <= nram_pingpong_size
-    threshold_c =
-        (nram_pingpong_size / nram_split_pingpong - target_data_bytes) /
-        (nram_split_num * compute_data_bytes);
-  }
-  // deal_n * compute_c * nram_split_pingpong * compute_data_bytes *
-  //     nram_split_num + deal_n * nram_split_pingpong * target_data_bytes <=
-  //     nram_pingpong_size
-  *deal_n_ptr =
-      nram_pingpong_size /
-      ((nram_split_num * compute_c * compute_data_bytes + target_data_bytes) *
-       nram_split_pingpong);
-  *threshold_c_ptr = threshold_c;
+  // set preferred computation mode and reduction approach
+  mluOpComputationPreference_t prefer = MLUOP_COMPUTATION_FAST;
+  mluOpLossReduction_t reduction = MLUOP_LOSS_REDUCTION_NONE;
+  auto handle = mluOpGetCurrentHandle();
+  // launch kernel
+  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(
+      handle, prefer, reduction, input_desc.desc(), input_ptr,
+      target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
+      gamma, output_desc.desc(), output_ptr));
 }
-void SigmoidFocalLossBackwardMLUKernelLauncher(Tensor input, Tensor target,
-                                               Tensor weight, Tensor output,
-                                               const float gamma,
-                                               const float alpha) {
+void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
+                                     Tensor weight, Tensor output,
+                                     const float gamma, const float alpha) {
   // params check
   TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
               "But now gamma is ", gamma, ".");
@@ -246,77 +117,51 @@ void SigmoidFocalLossBackwardMLUKernelLauncher(Tensor input, Tensor target,
     CNLOG(INFO) << "weight is an empty tensor.";
   }
-  auto dim_c = input.size(1);
-  const int compute_data_bytes = sizeof(float);
-  // target supports only INT on the MLU device while it stays LONG on the
-  // host side, hence target.itemsize() / 2
-  const int target_data_bytes = target.scalar_type() == at::kLong
-                                    ? (target.itemsize() / 2)
-                                    : target.itemsize();
-  int deal_n = 0;
-  int threshold_c = 0;
-  bool is_half = false;
-  if (input.scalar_type() == at::kHalf) {
-    is_half = true;
-  }
-  // calculate deal_n and threshold_c
-  getDealNAndThresholdC(compute_data_bytes, target_data_bytes, dim_c, &deal_n,
-                        &threshold_c, has_weight, is_half);
-  // check C
-  TORCH_CHECK(threshold_c >= dim_c,
-              "input.size(1) should be in the range of [0, ", threshold_c,
-              "]. ", "But now input.size(1) is ", dim_c, ".");
   if (input.numel() == 0 || target.numel() == 0 || output.numel() == 0) {
     // return if zero-element
     return;
   }
-  // set task dimension
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  policyFuncBackward(&k_dim, &k_type);
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  // contiguous
+  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      input, input.suggest_memory_format());
+  // target only supports int32 on MLU
+  auto target_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      target.toType(at::kInt), target.suggest_memory_format());
+  auto weight_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      weight, weight.suggest_memory_format());
+  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      output, output.suggest_memory_format());
+  // set tensor descriptors
+  MluOpTensorDescriptor input_desc, target_desc, weight_desc, output_desc;
+  input_desc.set(input_contiguous);
+  target_desc.set(target_contiguous);
+  weight_desc.set(weight_contiguous);
+  output_desc.set(output_contiguous);
   // get ptr of tensors
-  auto input_impl = torch_mlu::getMluTensorImpl(input);
+  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
   auto input_ptr = input_impl->cnnlMalloc();
-  auto target_impl = torch_mlu::getMluTensorImpl(target);
+  auto target_impl = torch_mlu::getMluTensorImpl(target_contiguous);
   auto target_ptr = target_impl->cnnlMalloc();
-  auto weight_impl = torch_mlu::getMluTensorImpl(weight);
+  auto weight_impl = torch_mlu::getMluTensorImpl(weight_contiguous);
   auto weight_ptr = weight_impl->cnnlMalloc();
-  auto output_impl = torch_mlu::getMluTensorImpl(output);
+  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
   auto output_ptr = output_impl->cnnlMalloc();
-  // get dtype of input
-  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  auto dim_n = input.size(0);
-  CNLOG(INFO) << "Launch Kernel KernelFocalLossSigmoidBackward<<<Union"
-              << k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y
-              << ", " << k_dim.z << ">>>";
-  // launch kernel
-  KernelFocalLossSigmoidBackward(k_dim, k_type, queue, d_type, input_ptr,
-                                 target_ptr, weight_ptr, gamma, alpha, dim_n,
-                                 deal_n, dim_c, output_ptr);
-}
-void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight,
-                                    Tensor output, float gamma, float alpha) {
-  SigmoidFocalLossForwardMLUKernelLauncher(input, target, weight, output,
-                                           gamma, alpha);
-}
-void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight,
-                                     Tensor grad_input, float gamma,
-                                     float alpha) {
-  SigmoidFocalLossBackwardMLUKernelLauncher(input, target, weight, grad_input,
-                                            gamma, alpha);
-}
+  // set preferred computation mode and reduction approach;
+  // backward only supports MLUOP_COMPUTATION_HIGH_PRECISION
+  mluOpComputationPreference_t prefer = MLUOP_COMPUTATION_HIGH_PRECISION;
+  mluOpLossReduction_t reduction = MLUOP_LOSS_REDUCTION_NONE;
+  auto handle = mluOpGetCurrentHandle();
+  // launch kernel
+  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(
+      handle, prefer, reduction, input_desc.desc(), input_ptr,
+      target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
+      gamma, output_desc.desc(), output_ptr));
 }
 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
......
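Reviewer note, not part of the diff: the backward op produces d(FL)/dx per logit. Differentiating through p = sigmoid(x), where dp/dx = p * (1 - p), gives -alpha * (1 - p)^gamma * (1 - p - gamma * p * log(p)) for the positive slot and -(1 - alpha) * p^gamma * (gamma * (1 - p) * log(1 - p) - p) for negative slots. A CPU sketch under the same illustrative conventions as the forward reference near the top of this page (the _ref helper name and layout are assumptions, not from the commit):

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Illustrative gradient reference only; mirrors the forward sketch above.
std::vector<float> sigmoid_focal_loss_backward_ref(
    const std::vector<float> &input, const std::vector<int> &target,
    const std::vector<float> &weight, int N, int C, float alpha, float gamma) {
  std::vector<float> grad(static_cast<size_t>(N) * C, 0.f);
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const float p = 1.f / (1.f + std::exp(-input[n * C + c]));
      // d/dx of -alpha * (1 - p)^gamma * log(p), via dp/dx = p * (1 - p)
      const float term_p =
          std::pow(1.f - p, gamma) *
          (1.f - p - gamma * p * std::log(std::max(p, FLT_MIN)));
      // d/dx of -(1 - alpha) * p^gamma * log(1 - p)
      const float term_n =
          std::pow(p, gamma) *
          (gamma * (1.f - p) * std::log(std::max(1.f - p, FLT_MIN)) - p);
      float g = (target[n] == c) ? -alpha * term_p : -(1.f - alpha) * term_n;
      if (!weight.empty()) g *= weight[target[n]];
      grad[n * C + c] = g;
    }
  }
  return grad;
}

Note the asymmetry visible in the diff: the forward call may use MLUOP_COMPUTATION_FAST, while the backward call is pinned to MLUOP_COMPUTATION_HIGH_PRECISION, per the comment carried in the new code.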