Unverified commit 2611b990, authored by qipengh and committed by GitHub

[Refactor] Replace carafe op of MLU backend with mlu-ops (#2817)
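With this change both launchers hand tiling and kernel dispatch to the mlu-ops library instead of the hand-written policy functions below. A minimal sketch of the new forward call sequence, assuming the MluOpTensorDescriptor helper and mluOpGetCurrentHandle() provided by mlu_common_helper.h, and NHWC-contiguous tensors prepared as in the launcher:

    // illustrative call sequence only; see the full launcher in the diff below
    MluOpTensorDescriptor input_desc, mask_desc, output_desc;
    input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
    mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
    output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);
    auto handle = mluOpGetCurrentHandle();
    mluOpCarafeDescriptor_t carafe_desc;
    mluOpCreateCarafeDescriptor(&carafe_desc);
    mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
                             scale_factor);
    mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
                       mask_desc.desc(), mask_ptr, output_desc.desc(),
                       output_ptr);
    mluOpDestroyCarafeDescriptor(carafe_desc);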

parent 7ff7095c
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CARAFE_UTILS_HPP_
#define CARAFE_UTILS_HPP_
#define NRAM_ALIGN_SIZE 64
struct CarafeForwardParam {
int N; // batch size
int Hi; // input height
int Wi; // input width
int Ci; // input channels
int Ho; // output height
int Wo; // output width
int Cg; // channels per group
int kernel_size; // upsampling kernel size (K)
int group_size; // number of channel groups (G)
int scale_factor; // upsampling scale factor
int kernel_size_half; // kernel half size (K-1)/2
int kernel_size_sq; // square of kernel size
int dtype_size; // size of tensor data type
// Host arrays' geometry
int input_stride_g;
int input_stride_w;
int input_stride_h;
int input_stride_n;
int input_size;
int mask_stride_kh;
int mask_stride_g;
int mask_stride_w;
int mask_stride_h;
int mask_stride_n;
int mask_size;
int output_stride_g;
int output_stride_w;
int output_stride_h;
int output_stride_n;
int output_size;
// NRAM arrays' geometry
int input_nram_stride_g;
int input_nram_stride_w;
int input_nram_stride_h;
int input_nram_size;
int mask_nram_stride_kh;
int mask_nram_stride_g;
int mask_nram_stride_w;
int mask_nram_stride_h;
int mask_nram_size;
int output_nram_stride_g;
int output_nram_stride_w;
int output_nram_stride_h;
int output_nram_size;
// for address/compute alignment
int align_size_NRAM; // for addressing on NRAM
int align_size_NFU; // for NFU operation length
int block_Cg_NFU; // for bang_mul_const
int job_num; // total job number
};
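// A minimal sketch (hypothetical helper, not part of the original header):
// how the shape-derived fields are typically filled from an NHWC input of
// shape [N, Hi, Wi, Ci], mirroring the forward launcher.
static inline void initCarafeShapeParam(CarafeForwardParam &p, int N, int Hi,
                                        int Wi, int Ci, int kernel_size,
                                        int group_size, int scale_factor,
                                        int dtype_size) {
  p.N = N;
  p.Hi = Hi;
  p.Wi = Wi;
  p.Ci = Ci;
  p.kernel_size = kernel_size;
  p.group_size = group_size;
  p.scale_factor = scale_factor;
  p.dtype_size = dtype_size;
  p.Cg = Ci / group_size;        // channels per group
  p.Ho = Hi * scale_factor;      // upsampled output height
  p.Wo = Wi * scale_factor;      // upsampled output width
  p.kernel_size_sq = kernel_size * kernel_size;
  p.kernel_size_half = (kernel_size - 1) / 2;
  p.align_size_NRAM = NRAM_ALIGN_SIZE / dtype_size;  // elements per 64 bytes
}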
struct CarafeForwardBlockDim {
int Ho; // block size of output height
int Wo; // block size of output width
int Kh; // block size of kernel height
int Kw; // block size of kernel width
int G; // block size of groups
int Cg; // block size of channels within a group
int Hi; // block size of input height
int Wi; // block size of input width
};
struct CarafeForwardGridDim {
int Ho; // number of blocks of output height
int Wo;
int Kh;
int Kw;
int G;
int Cg;
};
#endif // CARAFE_UTILS_HPP_
@@ -45,148 +45,6 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) {
return a > b ? a : b;
}
/*!
* @brief Loads data from global DRAM to NRAM in a 2D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The number of elements in each segment of the lower dimension.
* @param[in] dst_str
* The stride in elements between consecutive segments of dst.
* @param[in] src_str
* The stride in elements between consecutive segments of src.
* @param[in] seg_num
* The total number of segments to copy.
*/
template <typename T>
__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str,
const int src_str, const int seg_num) {
if (dst_str == src_str && size == src_str) {
__memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM);
} else if ((size == src_str || src_str <= dst_str) &&
src_str * sizeof(T) <= 512) {
// gather data less than 512 bytes to improve IO efficiency
T *tmp = (T *)dst + (dst_str - src_str) * seg_num;
__memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T),
GDRAM2NRAM);
if (dst_str != src_str) {
__memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
} else {
__memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
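// Usage sketch (hypothetical buffers and shapes): copy `rows` pixels of a
// Cg-channel NHWC feature map from GDRAM into a densely packed NRAM tile.
// All sizes and strides are element counts, not bytes.
template <typename T>
__mlu_func__ void loadInputTileExample(T *input_nram, T *input_gdram,
                                       const int Cg, const int Ci,
                                       const int rows) {
  // source rows are Ci elements apart in GDRAM; destination rows are packed
  // back to back, Cg elements apart, in NRAM
  loadStr2D(input_nram, input_gdram, /*size=*/Cg, /*dst_str=*/Cg,
            /*src_str=*/Ci, /*seg_num=*/rows);
}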
/*!
* @brief Loads data from global DRAM to NRAM in a 3D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The number of elements in each segment of the lowest dimension.
* @param[in] seg_num_in
* The total number of segments in the lowest dimension.
* @param[in] seg_num_out
* The total number of segments in the middle dimension.
* @param[in] dst_str_in
* The stride in elements between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The stride in elements between segments in the middle dimension of dst.
* @param[in] src_str_in
* The stride in elements between segments in the lowest dimension of src.
* @param[in] src_str_out
* The stride in elements between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void loadStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*!
* @brief Stores data from NRAM to global DRAM in a 2D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The number of elements in each segment of the lower dimension.
* @param[in] seg_num
* The total number of segments to copy.
* @param[in] dst_str
* The stride in elements between consecutive segments of dst.
* @param[in] src_str
* The stride in elements between consecutive segments of src.
*/
template <typename T>
__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num,
const int dst_str, const int src_str) {
if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) {
// gather data less than 512 bytes to improve IO efficiency
if (dst_str != src_str) {
__memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
__memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM);
} else {
__memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
/*!
* @brief Stores data from NRAM to global DRAM in a 3D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The number of elements in each segment of the lowest dimension.
* @param[in] seg_num_in
* The total number of segments in the lowest dimension.
* @param[in] seg_num_out
* The total number of segments in the middle dimension.
* @param[in] dst_str_in
* The stride in elements between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The stride in elements between segments in the middle dimension of dst.
* @param[in] src_str_in
* The stride in elements between segments in the lowest dimension of src.
* @param[in] src_str_out
* The stride in elements between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void storeStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
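// Usage sketch (hypothetical buffers and shapes): round-trip an [h, w, c]
// tile of an NHWC tensor with full width W and channel count C between GDRAM
// and a densely packed NRAM buffer, pairing the 3D load and store above.
template <typename T>
__mlu_func__ void tileRoundTripExample(T *nram_buf, T *gdram_tensor,
                                       const int h, const int w, const int c,
                                       const int W, const int C) {
  // GDRAM rows are W * C elements apart and pixels C apart; the NRAM tile is
  // packed with pixel stride c and row stride w * c
  loadStr3D(nram_buf, gdram_tensor, /*size=*/c, /*seg_num_in=*/w,
            /*seg_num_out=*/h, /*dst_str_in=*/c, /*dst_str_out=*/w * c,
            /*src_str_in=*/C, /*src_str_out=*/W * C);
  storeStr3D(gdram_tensor, nram_buf, /*size=*/c, /*seg_num_in=*/w,
             /*seg_num_out=*/h, /*dst_str_in=*/C, /*dst_str_out=*/W * C,
             /*src_str_in=*/c, /*src_str_out=*/w * c);
}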
/*!
* @brief Converts int32 to float32 data type.
*
...
@@ -9,200 +9,13 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
-#include "carafe_utils.hpp"
+#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *mask,
const CarafeForwardParam &param,
const CarafeForwardBlockDim &block_dim,
const CarafeForwardGridDim &grid_dim, void *output);
void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t dtype,
const void *input, const void *mask,
const void *grad_output, void *grad_input,
void *grad_mask, const int n, const int hi,
const int wi, const int c, const int k_up,
const int group, const int scale);
// Get total NRAM usage and set strides of NRAM arrays.
static void getNramUsage(CarafeForwardParam *param,
CarafeForwardBlockDim *block_dim, int *nram_usage) {
// input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_G, blkDim_Cg]
block_dim->Hi = CEIL_DIV(block_dim->Ho, param->scale_factor) + 1;
block_dim->Wi = CEIL_DIV(block_dim->Wo, param->scale_factor) + 1;
param->input_nram_stride_g = PAD_UP(block_dim->Cg, param->align_size_NRAM);
param->input_nram_stride_w = param->input_nram_stride_g * block_dim->G;
param->input_nram_stride_h =
(block_dim->Wi + block_dim->Kw - 1) * param->input_nram_stride_w;
param->input_nram_size =
(block_dim->Hi + block_dim->Kh - 1) * param->input_nram_stride_h;
// mask_nram[blkDim_Ho, blkDim_Wo, blkDim_G, blkDim_Kh, blkDim_Kw]
param->mask_nram_stride_kh = block_dim->Kw;
param->mask_nram_stride_g = block_dim->Kh * param->mask_nram_stride_kh;
param->mask_nram_stride_w = block_dim->G * param->mask_nram_stride_g;
param->mask_nram_stride_h = block_dim->Wo * param->mask_nram_stride_w;
param->mask_nram_size =
PAD_UP(block_dim->Ho * param->mask_nram_stride_h, param->align_size_NRAM);
// output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
param->output_nram_stride_g = param->input_nram_stride_g;
param->output_nram_stride_w =
PAD_UP(param->input_nram_stride_w, param->align_size_NFU);
param->output_nram_stride_h = block_dim->Wo * param->output_nram_stride_w;
param->output_nram_size = block_dim->Ho * param->output_nram_stride_h;
// sum_array[blkDim_(G*Cg)]
// ensure the last mul_const on Cg does not exceed memory boundary
int sum_array_size_bang_mul_const =
(block_dim->G - 1) * param->input_nram_stride_g +
PAD_UP(param->input_nram_stride_g, param->align_size_NFU);
int sum_array_size =
std::max(param->output_nram_stride_w, sum_array_size_bang_mul_const);
*nram_usage = param->input_nram_size + param->mask_nram_size +
param->output_nram_size + sum_array_size;
}
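// A minimal usage sketch (hypothetical helper, not part of the original
// file): test whether a candidate tile fits into NRAM, which is the same
// check genPolicyForward applies inside its shrinking loop below.
static bool blockFitsInNram(CarafeForwardParam *param,
                            CarafeForwardBlockDim *block_dim,
                            const int max_nram_size) {
  int nram_usage = 0;  // counted in elements of the tensor dtype
  getNramUsage(param, block_dim, &nram_usage);
  return nram_usage <= max_nram_size;
}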
// Policy Function for Forward
static void genPolicyForward(CarafeForwardParam *param,
CarafeForwardBlockDim *block_dim,
CarafeForwardGridDim *grid_dim, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
// device info
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
auto core_num = core_dim * cluster_num;
// maximum NRAM size as the number of <dtype>
auto max_nram_size =
torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) / param->dtype_size;
// determine grid and block dimensions
// set initial values for block_dim and grid_dim
block_dim->Ho = param->Ho;
block_dim->Wo = param->Wo;
block_dim->Kh = param->kernel_size;
block_dim->Kw = param->kernel_size;
block_dim->G = param->group_size;
block_dim->Cg = param->Cg;
grid_dim->Ho = 1;
grid_dim->Wo = 1;
grid_dim->Kh = 1;
grid_dim->Kw = 1;
grid_dim->G = 1;
grid_dim->Cg = 1;
// decrease the block size to fit in the NRAM.
int nram_usage = 0;
while (true) {
getNramUsage(param, block_dim, &nram_usage);
if (nram_usage > max_nram_size) {
// decrease Ho
// decrease block_Ho and block_Wo evenly
// so that the block is close to a square.
if (block_dim->Ho > 1 && block_dim->Ho >= block_dim->Wo) {
grid_dim->Ho += 1;
block_dim->Ho = CEIL_DIV(param->Ho, grid_dim->Ho);
} else if (block_dim->Wo > 1 && block_dim->Wo > block_dim->Ho) {
// decrease Wo
grid_dim->Wo += 1;
block_dim->Wo = CEIL_DIV(param->Wo, grid_dim->Wo);
} else if (block_dim->Kh > 1) {
// decrease Kh
grid_dim->Kh += 1;
block_dim->Kh = CEIL_DIV(param->kernel_size, grid_dim->Kh);
// reset Hi, Wi to maximize NRAM usage
grid_dim->Ho = 1;
block_dim->Ho = param->Ho;
grid_dim->Wo = 1;
block_dim->Wo = param->Wo;
} else if (block_dim->Kw > 1) {
// decrease Kw
grid_dim->Kw += 1;
block_dim->Kw = CEIL_DIV(param->kernel_size, grid_dim->Kw);
// reset Kh
grid_dim->Kh = 1;
block_dim->Kh = param->kernel_size;
} else if (block_dim->G > 1) {
// decrease G
grid_dim->G += 1;
block_dim->G = CEIL_DIV(param->group_size, grid_dim->G);
// reset Kw
grid_dim->Kw = 1;
block_dim->Kw = param->kernel_size;
} else if (block_dim->Cg > 1) {
// decrease block_Cg
// This is done in the last since c is the continuous dim
// (input layout is NHWC) and large c can improve
// IO & compute efficiency.
grid_dim->Cg += 1;
block_dim->Cg = CEIL_DIV(param->Cg, grid_dim->Cg);
// reset G
grid_dim->G = 1;
block_dim->G = param->group_size;
} else {
// the block volume is one now, cannot decrease the block size anymore!
// this situation should not occur.
break;
}
} else {
break;
}
}
// define parameters depending on block_dim, grid_dim
param->block_Cg_NFU = PAD_UP(block_dim->Cg, param->align_size_NFU);
// define host arrays' strides
// input[N,H,W,G,Cg]
param->input_stride_g = param->Cg;
param->input_stride_w = param->Ci;
param->input_stride_h = param->Wi * param->input_stride_w;
param->input_stride_n = param->Hi * param->input_stride_h;
// mask[N,Ho,Wo,G,Kh,Kw]
param->mask_stride_kh = param->kernel_size;
param->mask_stride_g = param->kernel_size * param->mask_stride_kh;
param->mask_stride_w = param->group_size * param->mask_stride_g;
param->mask_stride_h = param->Wo * param->mask_stride_w;
param->mask_stride_n = param->Ho * param->mask_stride_h;
// output[N,Ho,Wo,G,Cg]
param->output_stride_g = param->Cg;
param->output_stride_w = param->Ci;
param->output_stride_h = param->Wo * param->output_stride_w;
param->output_stride_n = param->Ho * param->output_stride_h;
param->job_num =
param->N * grid_dim->Ho * grid_dim->Wo * grid_dim->G * grid_dim->Cg;
// determine task type and dims
*k_type = CNRT_FUNC_TYPE_BLOCK;
k_dim->x = std::min(param->job_num, static_cast<int>(core_num));
k_dim->y = 1;
k_dim->z = 1;
}
void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
Tensor rinput, Tensor routput, Tensor rmask,
Tensor output, const int kernel_size,
const int group_size,
const int scale_factor) {
const int batch_size = output.size(0);
const int channels = output.size(1);
const int ho = output.size(2);
const int wo = output.size(3);
// check tensor data type
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
@@ -221,37 +34,10 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
// return fast on zero-element tensor
if (output.numel() == 0) {
-  output = at::zeros({batch_size, channels, ho, wo}, output.options());
+  output = at::zeros(output.sizes().vec(), output.options());
return;
}
// set param
CarafeForwardParam param;
param.N = input.size(0);
param.Ci = input.size(1);
param.Hi = input.size(2);
param.Wi = input.size(3);
param.kernel_size = kernel_size;
param.group_size = group_size;
param.scale_factor = scale_factor;
param.Cg = param.Ci / group_size;
param.dtype_size = input.itemsize();
param.align_size_NRAM = NRAM_ALIGN_SIZE / param.dtype_size;
param.align_size_NFU = NFU_ALIGN_SIZE / param.dtype_size;
param.kernel_size_sq = param.kernel_size * param.kernel_size;
param.kernel_size_half = (param.kernel_size - 1) / 2;
param.Ho = param.Hi * param.scale_factor;
param.Wo = param.Wi * param.scale_factor;
// generate policy
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
CarafeForwardBlockDim block_dim;
CarafeForwardGridDim grid_dim;
genPolicyForward(&param, &block_dim, &grid_dim, &k_dim, &k_type);
// convert NCHW to NHWC
auto memory_format_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
@@ -268,6 +54,12 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
auto routput_ =
torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc);
// set tensor descriptor
MluOpTensorDescriptor input_desc, mask_desc, output_desc;
input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
auto input_ptr = input_impl->cnnlMalloc();
@@ -276,45 +68,29 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
auto output_impl = torch_mlu::getMluTensorImpl(routput_);
auto output_ptr = output_impl->cnnlMalloc();
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-  // get dtype of input
-  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
-  // launch kernel
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  CNLOG(INFO) << "Launch Kernel KernelCarafeForward<<<Union"
-              << k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
-              << k_dim.z << ">>>";
-  KernelCarafeForward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, param,
-                      block_dim, grid_dim, output_ptr);
+  // set op descriptor
+  auto handle = mluOpGetCurrentHandle();
+  mluOpCarafeDescriptor_t carafe_desc;
+  mluOpCreateCarafeDescriptor(&carafe_desc);
+  mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
+                           scale_factor);
+  // launch kernel
+  mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
+                     mask_desc.desc(), mask_ptr, output_desc.desc(),
+                     output_ptr);
+  // destroy op descriptor
+  mluOpDestroyCarafeDescriptor(carafe_desc);
// copy output from NHWC back into NCHW
rinput.copy_(rinput_);
output.copy_(routput_);
}
// Policy Function for Backward
static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
// set Union1 Job
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim->z = 1;
}
void CARAFEBackwardMLUKernelLauncher(
const Tensor grad_output, const Tensor rinput, const Tensor mask,
Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input,
Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask,
const int kernel_size, const int group_size, const int scale_factor) {
const int batch_size = rinput.size(0);
const int channels = rinput.size(1);
const int hi = rinput.size(2);
const int wi = rinput.size(3);
// data type check
TORCH_CHECK(grad_output.scalar_type() == at::kFloat ||
grad_output.scalar_type() == at::kHalf,
@@ -331,11 +107,6 @@ void CARAFEBackwardMLUKernelLauncher(
TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ",
kernel_size);
// set task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncBackward(&k_dim, &k_type);
// convert NCHW to NHWC
auto memory_format_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim());
@@ -363,8 +134,15 @@ void CARAFEBackwardMLUKernelLauncher(
auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_mask, memory_format_grad_mask_nhwc);
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  // set tensor descriptor
+  MluOpTensorDescriptor input_desc, mask_desc;
+  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
+  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
+  MluOpTensorDescriptor grad_output_desc, grad_input_desc, grad_mask_desc;
+  grad_output_desc.set_with_layout(rgrad_output_, MLUOP_LAYOUT_NHWC);
+  grad_input_desc.set_with_layout(rgrad_input_, MLUOP_LAYOUT_NHWC);
+  grad_mask_desc.set_with_layout(rgrad_mask_, MLUOP_LAYOUT_NHWC);
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
@@ -378,19 +156,19 @@ void CARAFEBackwardMLUKernelLauncher(
auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_);
auto grad_mask_ptr = grad_mask_impl->cnnlMalloc();
-  // get dtype of grad_output
-  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(grad_output.dtype());
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  CNLOG(INFO) << "Launch Kernel KernelCarafeBackward<<<Union"
-              << k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
-              << k_dim.z << ">>>";
-  // launch kernel
-  KernelCarafeBackward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr,
-                       grad_output_ptr, grad_input_ptr, grad_mask_ptr,
-                       batch_size, hi, wi, channels, kernel_size, group_size,
-                       scale_factor);
+  // set op descriptor
+  auto handle = mluOpGetCurrentHandle();
+  mluOpCarafeDescriptor_t carafe_desc;
+  mluOpCreateCarafeDescriptor(&carafe_desc);
+  mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
+                           group_size, scale_factor);
+  // launch kernel
+  mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
+                      mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
+                      grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
+                      grad_mask_desc.desc(), grad_mask_ptr);
+  // destroy op descriptor
+  mluOpDestroyCarafeDescriptor(carafe_desc);
// copy output from NHWC back into NCHW
grad_input.copy_(rgrad_input_);
...