Unverified commit 2611b990, authored by qipengh, committed by GitHub

[Refactor] Replace carafe op of MLU backend with mlu-ops (#2817)

parent 7ff7095c
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CARAFE_UTILS_HPP_
#define CARAFE_UTILS_HPP_
#define NRAM_ALIGN_SIZE 64
struct CarafeForwardParam {
int N; // batch size
int Hi; // input height
int Wi; // input width
int Ci; // input channels
int Ho; // output height
int Wo; // output width
int Cg; // channels per group
int kernel_size; // upsampling kernel size (K)
int group_size; // number of channel groups (G)
int scale_factor; // upsampling scale factor
int kernel_size_half; // kernel half size (K-1)/2
int kernel_size_sq; // square of kernel size
int dtype_size; // size of tensor data type
// Host arrays' geometry
int input_stride_g;
int input_stride_w;
int input_stride_h;
int input_stride_n;
int input_size;
int mask_stride_kh;
int mask_stride_g;
int mask_stride_w;
int mask_stride_h;
int mask_stride_n;
int mask_size;
int output_stride_g;
int output_stride_w;
int output_stride_h;
int output_stride_n;
int output_size;
// NRAM arrays' geometry
int input_nram_stride_g;
int input_nram_stride_w;
int input_nram_stride_h;
int input_nram_size;
int mask_nram_stride_kh;
int mask_nram_stride_g;
int mask_nram_stride_w;
int mask_nram_stride_h;
int mask_nram_size;
int output_nram_stride_g;
int output_nram_stride_w;
int output_nram_stride_h;
int output_nram_size;
// for address/compute alignment
int align_size_NRAM; // for addressing on NRAM
int align_size_NFU; // for NFU operation length
int block_Cg_NFU; // for bang_mul_const
int job_num; // total job number
};
struct CarafeForwardBlockDim {
int Ho; // block size of output height
int Wo; // block size of output width
int Kh; // block size of kernel height
int Kw; // block size of kernel width
int G; // block size of groups
int Cg; // block size of channels within a group
int Hi; // block size of input height
int Wi; // block size of input width
};
struct CarafeForwardGridDim {
int Ho; // number of blocks of output height
int Wo; // number of blocks of output width
int Kh; // number of blocks of kernel height
int Kw; // number of blocks of kernel width
int G; // number of blocks of groups
int Cg; // number of blocks of channels within a group
};
#endif // CARAFE_UTILS_HPP_
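// --- Illustrative worked example (not part of the original file) ---
// How the basic fields relate, for a float32 input of shape N = 2,
// Ci = 64, Hi = Wi = 16 with kernel_size = 5, group_size = 1,
// scale_factor = 2 (values chosen only for illustration); this mirrors
// the assignments in CARAFEForwardMLUKernelLauncher below:
//   Cg               = Ci / group_size           = 64
//   Ho               = Hi * scale_factor         = 32
//   Wo               = Wi * scale_factor         = 32
//   kernel_size_half = (kernel_size - 1) / 2     = 2
//   kernel_size_sq   = kernel_size * kernel_size = 25
//   dtype_size       = sizeof(float)             = 4
//   align_size_NRAM  = NRAM_ALIGN_SIZE / 4       = 16 elements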
@@ -45,148 +45,6 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) {
return a > b ? a : b;
}
/*!
 * @brief Loads data from global DRAM to NRAM with a 2D pattern.
 *
 * @param[out] dst
 *   Pointer to NRAM that stores dst data.
 * @param[in] src
 *   Pointer to global DRAM that stores src data.
 * @param[in] size
 *   The number of elements of type T in one segment of the lower dimension.
 * @param[in] dst_str
 *   The stride, in elements of type T, between consecutive segments in the
 *   lower dimension of dst.
 * @param[in] src_str
 *   The stride, in elements of type T, between consecutive segments in the
 *   lower dimension of src.
 * @param[in] seg_num
 *   The total count of data segments in the lower dimension.
 */
template <typename T>
__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str,
const int src_str, const int seg_num) {
if (dst_str == src_str && size == src_str) {
__memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM);
} else if ((size == src_str || src_str <= dst_str) &&
src_str * sizeof(T) <= 512) {
// gather segments smaller than 512 bytes to improve IO efficiency:
// load the whole strided src region with one contiguous copy into the
// tail of the dst buffer, then scatter it on NRAM to the dst stride.
T *tmp = (T *)dst + (dst_str - src_str) * seg_num;
__memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T),
GDRAM2NRAM);
if (dst_str != src_str) {
__memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
} else {
__memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
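// --- Illustrative sketch (not part of the original file) ---
// A plain host-side reference for what loadStr2D computes; the branches
// above are IO-optimized variants of this loop (size and the strides are
// element counts, converted to bytes at the __memcpy call sites):
//
//   for (int i = 0; i < seg_num; ++i)
//     memcpy(dst + i * dst_str, src + i * src_str, size * sizeof(T));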
/*!
 * @brief Loads data from global DRAM to NRAM with a 3D pattern.
 *
 * @param[out] dst
 *   Pointer to NRAM that stores dst data.
 * @param[in] src
 *   Pointer to global DRAM that stores src data.
 * @param[in] size
 *   The number of elements of type T in one segment of the lowest dimension.
 * @param[in] seg_num_in
 *   The total count of data segments in the lowest dimension.
 * @param[in] seg_num_out
 *   The total count of data segments in the middle dimension.
 * @param[in] dst_str_in
 *   The stride, in elements of type T, between consecutive segments in the
 *   lowest dimension of dst.
 * @param[in] dst_str_out
 *   The stride, in elements of type T, between consecutive segments in the
 *   middle dimension of dst.
 * @param[in] src_str_in
 *   The stride, in elements of type T, between consecutive segments in the
 *   lowest dimension of src.
 * @param[in] src_str_out
 *   The stride, in elements of type T, between consecutive segments in the
 *   middle dimension of src.
 */
template <typename T>
__mlu_func__ void loadStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*!
 * @brief Stores data from NRAM to global DRAM with a 2D pattern.
 *
 * @param[out] dst
 *   Pointer to global DRAM that stores dst data.
 * @param[in] src
 *   Pointer to NRAM that stores src data.
 * @param[in] size
 *   The number of elements of type T in one segment of the lower dimension.
 * @param[in] seg_num
 *   The total count of data segments in the lower dimension.
 * @param[in] dst_str
 *   The stride, in elements of type T, between consecutive segments in the
 *   lower dimension of dst.
 * @param[in] src_str
 *   The stride, in elements of type T, between consecutive segments in the
 *   lower dimension of src.
 */
template <typename T>
__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num,
const int dst_str, const int src_str) {
if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) {
// gather segments smaller than 512 bytes to improve IO efficiency:
// compact them in place on NRAM first (dst_str <= src_str, so each
// segment moves toward a lower address), then issue one contiguous
// NRAM2GDRAM copy below.
if (dst_str != src_str) {
__memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
__memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM);
} else {
__memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
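// --- Illustrative worked example (not part of the original file) ---
// The fast path of storeStr2D with size = 4, seg_num = 3, dst_str = 4,
// src_str = 6 (elements; values chosen only for illustration): the
// NRAM2NRAM pass compacts the segments at element offsets {0, 6, 12}
// down to {0, 4, 8}, after which a single contiguous NRAM2GDRAM copy of
// 4 * 3 = 12 elements stores the result.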
/*!
 * @brief Stores data from NRAM to global DRAM with a 3D pattern.
 *
 * @param[out] dst
 *   Pointer to global DRAM that stores dst data.
 * @param[in] src
 *   Pointer to NRAM that stores src data.
 * @param[in] size
 *   The number of elements of type T in one segment of the lowest dimension.
 * @param[in] seg_num_in
 *   The total count of data segments in the lowest dimension.
 * @param[in] seg_num_out
 *   The total count of data segments in the middle dimension.
 * @param[in] dst_str_in
 *   The stride, in elements of type T, between consecutive segments in the
 *   lowest dimension of dst.
 * @param[in] dst_str_out
 *   The stride, in elements of type T, between consecutive segments in the
 *   middle dimension of dst.
 * @param[in] src_str_in
 *   The stride, in elements of type T, between consecutive segments in the
 *   lowest dimension of src.
 * @param[in] src_str_out
 *   The stride, in elements of type T, between consecutive segments in the
 *   middle dimension of src.
 */
template <typename T>
__mlu_func__ void storeStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*!
* @brief Converts int32 to float32 data type.
*
......
@@ -9,200 +9,13 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "carafe_utils.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *mask,
const CarafeForwardParam &param,
const CarafeForwardBlockDim &block_dim,
const CarafeForwardGridDim &grid_dim, void *output);
void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t dtype,
const void *input, const void *mask,
const void *grad_output, void *grad_input,
void *grad_mask, const int n, const int hi,
const int wi, const int c, const int k_up,
const int group, const int scale);
// Get total NRAM usage and set strides of NRAM arrays.
static void getNramUsage(CarafeForwardParam *param,
CarafeForwardBlockDim *block_dim, int *nram_usage) {
// input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_G, blkDim_Cg]
block_dim->Hi = CEIL_DIV(block_dim->Ho, param->scale_factor) + 1;
block_dim->Wi = CEIL_DIV(block_dim->Wo, param->scale_factor) + 1;
param->input_nram_stride_g = PAD_UP(block_dim->Cg, param->align_size_NRAM);
param->input_nram_stride_w = param->input_nram_stride_g * block_dim->G;
param->input_nram_stride_h =
(block_dim->Wi + block_dim->Kw - 1) * param->input_nram_stride_w;
param->input_nram_size =
(block_dim->Hi + block_dim->Kh - 1) * param->input_nram_stride_h;
// mask_nram[blkDim_Ho, blkDim_Wo, blkDim_G, blkDim_Kh, blkDim_Kw]
param->mask_nram_stride_kh = block_dim->Kw;
param->mask_nram_stride_g = block_dim->Kh * param->mask_nram_stride_kh;
param->mask_nram_stride_w = block_dim->G * param->mask_nram_stride_g;
param->mask_nram_stride_h = block_dim->Wo * param->mask_nram_stride_w;
param->mask_nram_size =
PAD_UP(block_dim->Ho * param->mask_nram_stride_h, param->align_size_NRAM);
// output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
param->output_nram_stride_g = param->input_nram_stride_g;
param->output_nram_stride_w =
PAD_UP(param->input_nram_stride_w, param->align_size_NFU);
param->output_nram_stride_h = block_dim->Wo * param->output_nram_stride_w;
param->output_nram_size = block_dim->Ho * param->output_nram_stride_h;
// sum_array[blkDim_(G*Cg)]
// ensure the last mul_const on Cg does not exceed memory boundary
int sum_array_size_bang_mul_const =
(block_dim->G - 1) * param->input_nram_stride_g +
PAD_UP(param->input_nram_stride_g, param->align_size_NFU);
int sum_array_size =
std::max(param->output_nram_stride_w, sum_array_size_bang_mul_const);
*nram_usage = param->input_nram_size + param->mask_nram_size +
param->output_nram_size + sum_array_size;
}
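// --- Illustrative worked example (not part of the original file) ---
// NRAM usage for float32 (dtype_size = 4, so align_size_NRAM = 64 / 4 =
// 16 elements and, assuming NFU_ALIGN_SIZE is 128 bytes, align_size_NFU
// = 32 elements) with block dims Ho = Wo = 4, Kh = Kw = 5, G = 1,
// Cg = 64 and scale_factor = 2; values chosen only for illustration:
//   Hi = Wi = CEIL_DIV(4, 2) + 1 = 3
//   input_nram_size  = (3 + 5 - 1) * (3 + 5 - 1) * PAD_UP(64, 16) = 3136
//   mask_nram_size   = PAD_UP(4 * 4 * 1 * 5 * 5, 16)              = 400
//   output_nram_size = 4 * 4 * PAD_UP(64, 32)                     = 1024
//   sum_array_size   = max(64, 0 + PAD_UP(64, 32))                = 64
//   total = 3136 + 400 + 1024 + 64 = 4624 elements (about 18 KB)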
// Policy Function for Forward
static void genPolicyForward(CarafeForwardParam *param,
CarafeForwardBlockDim *block_dim,
CarafeForwardGridDim *grid_dim, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
// device info
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
auto core_num = core_dim * cluster_num;
// maximum NRAM size, in number of elements of <dtype>
auto max_nram_size =
torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) / param->dtype_size;
// determine grid and block dimensions
// set initial values for block_dim and grid_dim
block_dim->Ho = param->Ho;
block_dim->Wo = param->Wo;
block_dim->Kh = param->kernel_size;
block_dim->Kw = param->kernel_size;
block_dim->G = param->group_size;
block_dim->Cg = param->Cg;
grid_dim->Ho = 1;
grid_dim->Wo = 1;
grid_dim->Kh = 1;
grid_dim->Kw = 1;
grid_dim->G = 1;
grid_dim->Cg = 1;
// shrink the block size until the data fits in NRAM.
int nram_usage = 0;
while (true) {
getNramUsage(param, block_dim, &nram_usage);
if (nram_usage > max_nram_size) {
// decrease Ho
// decrease block_Ho and block_Wo evenly
// so that the block is close to a square.
if (block_dim->Ho > 1 && block_dim->Ho >= block_dim->Wo) {
grid_dim->Ho += 1;
block_dim->Ho = CEIL_DIV(param->Ho, grid_dim->Ho);
} else if (block_dim->Wo > 1 && block_dim->Wo > block_dim->Ho) {
// decrease Wo
grid_dim->Wo += 1;
block_dim->Wo = CEIL_DIV(param->Wo, grid_dim->Wo);
} else if (block_dim->Kh > 1) {
// decrease Kh
grid_dim->Kh += 1;
block_dim->Kh = CEIL_DIV(param->kernel_size, grid_dim->Kh);
// reset Ho, Wo to maximize NRAM usage
grid_dim->Ho = 1;
block_dim->Ho = param->Ho;
grid_dim->Wo = 1;
block_dim->Wo = param->Wo;
} else if (block_dim->Kw > 1) {
// decrease Kw
grid_dim->Kw += 1;
block_dim->Kw = CEIL_DIV(param->kernel_size, grid_dim->Kw);
// reset Kh
grid_dim->Kh = 1;
block_dim->Kh = param->kernel_size;
} else if (block_dim->G > 1) {
// decrease G
grid_dim->G += 1;
block_dim->G = CEIL_DIV(param->group_size, grid_dim->G);
// reset Kw
grid_dim->Kw = 1;
block_dim->Kw = param->kernel_size;
} else if (block_dim->Cg > 1) {
// decrease block_Cg last, since C is the contiguous dimension
// (the input layout is NHWC) and a large C improves
// IO and compute efficiency.
grid_dim->Cg += 1;
block_dim->Cg = CEIL_DIV(param->Cg, grid_dim->Cg);
// reset G
grid_dim->G = 1;
block_dim->G = param->group_size;
} else {
// the block volume is already one and cannot shrink further;
// this situation should not occur.
break;
}
} else {
break;
}
}
// define parameters depending on block_dim, grid_dim
param->block_Cg_NFU = PAD_UP(block_dim->Cg, param->align_size_NFU);
// define host arrays' strides
// input[N,H,W,G,Cg]
param->input_stride_g = param->Cg;
param->input_stride_w = param->Ci;
param->input_stride_h = param->Wi * param->input_stride_w;
param->input_stride_n = param->Hi * param->input_stride_h;
// mask[N,Ho,Wo,G,Kh,Kw]
param->mask_stride_kh = param->kernel_size;
param->mask_stride_g = param->kernel_size * param->mask_stride_kh;
param->mask_stride_w = param->group_size * param->mask_stride_g;
param->mask_stride_h = param->Wo * param->mask_stride_w;
param->mask_stride_n = param->Ho * param->mask_stride_h;
// output[N,Ho,Wo,G,Cg]
param->output_stride_g = param->Cg;
param->output_stride_w = param->Ci;
param->output_stride_h = param->Wo * param->output_stride_w;
param->output_stride_n = param->Ho * param->output_stride_h;
param->job_num =
param->N * grid_dim->Ho * grid_dim->Wo * grid_dim->G * grid_dim->Cg;
// determine task type and dims
*k_type = CNRT_FUNC_TYPE_BLOCK;
k_dim->x = std::min(param->job_num, static_cast<int>(core_num));
k_dim->y = 1;
k_dim->z = 1;
}
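// --- Illustrative trace (not part of the original file) ---
// The shrink loop above splits dimensions in a fixed order until the
// block fits in NRAM: Ho and Wo first (kept near-square), then Kh, Kw,
// G, and Cg last (C is contiguous in NHWC, so large C blocks keep IO
// and compute efficient). For example, with Ho = Wo = 64 and an NRAM
// budget that only fits half of the output rows, the first iteration
// sets grid_dim->Ho = 2 and block_dim->Ho = CEIL_DIV(64, 2) = 32, and
// the loop exits once getNramUsage reports a fit; values chosen only
// for illustration.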
#include "mlu_common_helper.h"
void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
Tensor rinput, Tensor routput, Tensor rmask,
Tensor output, const int kernel_size,
const int group_size,
const int scale_factor) {
const int batch_size = output.size(0);
const int channels = output.size(1);
const int ho = output.size(2);
const int wo = output.size(3);
// check tensor data type
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
@@ -221,37 +34,10 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
// return fast on zero-element tensor
if (output.numel() == 0) {
output = at::zeros({batch_size, channels, ho, wo}, output.options());
output = at::zeros(output.sizes().vec(), output.options());
return;
}
// set param
CarafeForwardParam param;
param.N = input.size(0);
param.Ci = input.size(1);
param.Hi = input.size(2);
param.Wi = input.size(3);
param.kernel_size = kernel_size;
param.group_size = group_size;
param.scale_factor = scale_factor;
param.Cg = param.Ci / group_size;
param.dtype_size = input.itemsize();
param.align_size_NRAM = NRAM_ALIGN_SIZE / param.dtype_size;
param.align_size_NFU = NFU_ALIGN_SIZE / param.dtype_size;
param.kernel_size_sq = param.kernel_size * param.kernel_size;
param.kernel_size_half = (param.kernel_size - 1) / 2;
param.Ho = param.Hi * param.scale_factor;
param.Wo = param.Wi * param.scale_factor;
// generate policy
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
CarafeForwardBlockDim block_dim;
CarafeForwardGridDim grid_dim;
genPolicyForward(&param, &block_dim, &grid_dim, &k_dim, &k_type);
// convert NCHW to NHWC
auto memory_format_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
@@ -268,6 +54,12 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
auto routput_ =
torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc);
// set tensor descriptor
MluOpTensorDescriptor input_desc, mask_desc, output_desc;
input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
auto input_ptr = input_impl->cnnlMalloc();
@@ -276,45 +68,29 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
auto output_impl = torch_mlu::getMluTensorImpl(routput_);
auto output_ptr = output_impl->cnnlMalloc();
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get dtype of input
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
// set op descriptor
auto handle = mluOpGetCurrentHandle();
mluOpCarafeDescriptor_t carafe_desc;
mluOpCreateCarafeDescriptor(&carafe_desc);
mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
scale_factor);
// launch kernel
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
CNLOG(INFO) << "Launch Kernel KernelCarafeForward<<<Union"
<< k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
<< k_dim.z << ">>>";
KernelCarafeForward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, param,
block_dim, grid_dim, output_ptr);
mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
mask_desc.desc(), mask_ptr, output_desc.desc(),
output_ptr);
// destroy op descriptor
mluOpDestroyCarafeDescriptor(carafe_desc);
// copy output from NHWC back into NCHW
rinput.copy_(rinput_);
output.copy_(routput_);
}
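// --- Illustrative sketch (not part of the original file) ---
// The mlu-ops descriptor lifecycle shared by both launchers, condensed.
// The status handling is an assumption (mlu-ops calls are presumed to
// return a status comparable against MLUOP_STATUS_SUCCESS); the original
// code ignores the return values:
//
//   mluOpCarafeDescriptor_t desc;
//   mluOpCreateCarafeDescriptor(&desc);                    // create
//   mluOpSetCarafeDescriptor(desc, dims, kernel_size,
//                            group_size, scale_factor);    // configure
//   ... mluOpCarafeForward(...) or mluOpCarafeBackward(...) ...
//   mluOpDestroyCarafeDescriptor(desc);                    // destroy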
// Policy Function for Backward
static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
// set Union1 Job
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim->z = 1;
}
void CARAFEBackwardMLUKernelLauncher(
const Tensor grad_output, const Tensor rinput, const Tensor mask,
Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input,
Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask,
const int kernel_size, const int group_size, const int scale_factor) {
const int batch_size = rinput.size(0);
const int channels = rinput.size(1);
const int hi = rinput.size(2);
const int wi = rinput.size(3);
// data type check
TORCH_CHECK(grad_output.scalar_type() == at::kFloat ||
grad_output.scalar_type() == at::kHalf,
@@ -331,11 +107,6 @@ void CARAFEBackwardMLUKernelLauncher(
TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ",
kernel_size);
// set task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncBackward(&k_dim, &k_type);
// convert NCHW to NHWC
auto memory_format_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim());
@@ -363,8 +134,15 @@ void CARAFEBackwardMLUKernelLauncher(
auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_mask, memory_format_grad_mask_nhwc);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// set tensor descriptor
MluOpTensorDescriptor input_desc, mask_desc;
input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
MluOpTensorDescriptor grad_output_desc, grad_input_desc, grad_mask_desc;
grad_output_desc.set_with_layout(rgrad_output_, MLUOP_LAYOUT_NHWC);
grad_input_desc.set_with_layout(rgrad_input_, MLUOP_LAYOUT_NHWC);
grad_mask_desc.set_with_layout(rgrad_mask_, MLUOP_LAYOUT_NHWC);
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
@@ -378,19 +156,19 @@ void CARAFEBackwardMLUKernelLauncher(
auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_);
auto grad_mask_ptr = grad_mask_impl->cnnlMalloc();
// get dtype of grad_output
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(grad_output.dtype());
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
CNLOG(INFO) << "Launch Kernel KernelCarafeBackward<<<Union"
<< k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
<< k_dim.z << ">>>";
// set op descriptor
auto handle = mluOpGetCurrentHandle();
mluOpCarafeDescriptor_t carafe_desc;
mluOpCreateCarafeDescriptor(&carafe_desc);
mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
group_size, scale_factor);
// launch kernel
KernelCarafeBackward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr,
grad_output_ptr, grad_input_ptr, grad_mask_ptr,
batch_size, hi, wi, channels, kernel_size, group_size,
scale_factor);
mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
grad_mask_desc.desc(), grad_mask_ptr);
// destroy op descriptor
mluOpDestroyCarafeDescriptor(carafe_desc);
// copy output from NHWC back into NCHW
grad_input.copy_(rgrad_input_);
......