OpenDAS / MMCV / Commits

Unverified commit 2611b990, authored Jun 01, 2023 by qipengh, committed by GitHub on Jun 01, 2023

[Refactor] Replace carafe op of MLU backend with mlu-ops (#2817)

Parent: 7ff7095c
Changes: 4

Showing 4 changed files with 40 additions and 1051 deletions (+40 −1051)
mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu   +0 −552
mmcv/ops/csrc/common/mlu/carafe_utils.hpp        +0 −95
mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp   +0 −142
mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp         +40 −262
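Note: in short, this commit deletes the hand-written BANG kernels for CARAFE (carafe_mlu_kernel.mlu) together with their tiling and IO helpers (carafe_utils.hpp and the loadStr*/storeStr* helpers in common_mlu_helper.hpp), and rewrites carafe_mlu.cpp so that the PyTorch launchers call the mlu-ops library instead. A condensed sketch of the new forward call sequence, assembled from the added lines in the diff below (tensor preparation, checks and error handling omitted; rinput_/rmask_/routput_ and the *_ptr variables are the NHWC-contiguous tensors and device pointers the launcher prepares):

// Condensed sketch of the new forward path in carafe_mlu.cpp (all calls below
// appear in the diff; surrounding tensor preparation and checks are omitted).
MluOpTensorDescriptor input_desc, mask_desc, output_desc;
input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);

auto handle = mluOpGetCurrentHandle();
mluOpCarafeDescriptor_t carafe_desc;
mluOpCreateCarafeDescriptor(&carafe_desc);
mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
                         scale_factor);
mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
                   mask_desc.desc(), mask_ptr, output_desc.desc(), output_ptr);
mluOpDestroyCarafeDescriptor(carafe_desc);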
mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu   deleted 100644 → 0
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "carafe_utils.hpp"
#include "common_mlu_helper.hpp"
#define INDEX3(n, h, w, c, strN, strH, strW) \
(strN) * (n) + (strH) * (h) + (strW) * (w) + (c)
#define NRAM_BLOCK PAD_DOWN(MAX_NRAM_SIZE / 5, NRAM_ALIGN_SIZE)
__nram__ char nram_buf[MAX_NRAM_SIZE];
namespace forward {
struct BlockId {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
};
// start indices of block
struct BlockStart {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
int C;
};
struct BlockEnd {
int Ho;
int Wo;
int Kh;
int Kw;
int Hi;
int Wi;
};
struct BlockSize {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
};
template <typename T>
__mlu_func__ void carafeForwardBLOCK(T *input, T *mask,
const CarafeForwardParam param,
const CarafeForwardBlockDim block_dim,
const CarafeForwardGridDim grid_dim,
T *output) {
// data block info
BlockId blkId;
BlockStart blkStart;
BlockEnd blkEnd;
BlockSize blkSize;
// set pointers on NRAM arrays
// input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_(G*Cg)]
T *input_nram = (T *)nram_buf;
// mask_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Kh*Kw)]
T *mask_nram = input_nram + param.input_nram_size;
// output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
T *output_nram = mask_nram + param.mask_nram_size;
// sum_array[blkDim_(G*Cg)]
T *sum_array = output_nram + param.output_nram_size;
/* ===== loop over N, grid_dim(Ho,Wo,G,Cg)
* iterations are distributed over computing cores
*/
for (int loop_index = taskId; loop_index < param.job_num;
loop_index += taskDim) {
// block idx
blkId.Cg = loop_index;
blkId.G = blkId.Cg / grid_dim.Cg;
blkId.Wo = blkId.G / grid_dim.G;
blkId.Ho = blkId.Wo / grid_dim.Wo;
int sample_idx = blkId.Ho / grid_dim.Ho;
blkId.Cg %= grid_dim.Cg;
blkId.G %= grid_dim.G;
blkId.Wo %= grid_dim.Wo;
blkId.Ho %= grid_dim.Ho;
// block starting indices
blkStart.Ho = blkId.Ho * block_dim.Ho;
blkStart.Wo = blkId.Wo * block_dim.Wo;
blkStart.G = blkId.G * block_dim.G;
blkStart.Cg = blkId.Cg * block_dim.Cg;
blkStart.C = blkStart.G * param.Cg + blkStart.Cg;
// block size
blkSize.Ho = block_dim.Ho;
blkSize.Wo = block_dim.Wo;
blkSize.G = block_dim.G;
blkSize.Cg = block_dim.Cg;
// take care of blocks near the end of each dimension
if (blkId.Ho == (grid_dim.Ho - 1)) {
blkSize.Ho = param.Ho - (grid_dim.Ho - 1) * block_dim.Ho;
}
if (blkId.Wo == (grid_dim.Wo - 1)) {
blkSize.Wo = param.Wo - (grid_dim.Wo - 1) * block_dim.Wo;
}
if (blkId.G == (grid_dim.G - 1)) {
blkSize.G = param.group_size - (grid_dim.G - 1) * block_dim.G;
}
if (blkId.Cg == (grid_dim.Cg - 1)) {
blkSize.Cg = param.Cg - (grid_dim.Cg - 1) * block_dim.Cg;
}
// block end indices
blkEnd.Ho = blkStart.Ho + blkSize.Ho - 1;
blkEnd.Wo = blkStart.Wo + blkSize.Wo - 1;
// set output_nram to zero
__bang_write_value(output_nram, param.output_nram_size, T(0));
// loop blocks of kernel window: grid_dim.(Kh, Kw)
for (blkId.Kh = 0; blkId.Kh < grid_dim.Kh; ++blkId.Kh) {
blkStart.Kh = blkId.Kh * block_dim.Kh;
blkSize.Kh = block_dim.Kh;
if (blkId.Kh == (grid_dim.Kh - 1)) {
blkSize.Kh = param.kernel_size - (grid_dim.Kh - 1) * block_dim.Kh;
}
blkEnd.Kh = blkStart.Kh + blkSize.Kh - 1;
blkStart.Hi = blkStart.Ho / param.scale_factor - param.kernel_size_half +
blkStart.Kh;
blkEnd.Hi =
blkEnd.Ho / param.scale_factor - param.kernel_size_half + blkEnd.Kh;
blkSize.Hi = blkEnd.Hi - blkStart.Hi + 1;
for (blkId.Kw = 0; blkId.Kw < grid_dim.Kw; ++blkId.Kw) {
blkStart.Kw = blkId.Kw * block_dim.Kw;
blkSize.Kw = block_dim.Kw;
if (blkId.Kw == (grid_dim.Kw - 1)) {
blkSize.Kw = param.kernel_size - (grid_dim.Kw - 1) * block_dim.Kw;
}
blkEnd.Kw = blkStart.Kw + blkSize.Kw - 1;
blkStart.Wi = blkStart.Wo / param.scale_factor -
param.kernel_size_half + blkStart.Kw;
blkEnd.Wi =
blkEnd.Wo / param.scale_factor - param.kernel_size_half + blkEnd.Kw;
blkSize.Wi = blkEnd.Wi - blkStart.Wi + 1;
// load input block from gdram2nram
//
// input_nram[ | input[ sample_idx,
// 0:blkSize.Hi-1, | blkStart.Hi + 0:blkSize.Hi-1,
// 0:blkSize.Wi-1, | blkStart.Wi + 0:blkSize.Wi-1,
// 0:blkSize.G-1 | blkStart.G + 0:blkSize.G-1
// 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1]
//
// To skip out of bound indices:
//
// input_nram[
// hi_start_local:hi_end_local,
// wi_start_local:wi_end_local, ...]
// = input[n,
// hi_start_global:hi_end_global,
// wi_start_global:wi_end_global, ...]
//
int hi_start_local = 0;
int hi_start_global = blkStart.Hi;
if (blkStart.Hi < 0) {
hi_start_local = -blkStart.Hi;
hi_start_global = 0;
}
int wi_start_local = 0;
int wi_start_global = blkStart.Wi;
if (blkStart.Wi < 0) {
wi_start_local = -blkStart.Wi;
wi_start_global = 0;
}
int hi_end_local = blkSize.Hi - 1;
int hi_end_global = blkEnd.Hi;
if (blkEnd.Hi > param.Hi - 1) {
hi_end_global = param.Hi - 1;
hi_end_local -= blkEnd.Hi - hi_end_global;
}
int wi_end_local = blkSize.Wi - 1;
int wi_end_global = blkEnd.Wi;
if (blkEnd.Wi > param.Wi - 1) {
wi_end_global = param.Wi - 1;
wi_end_local -= blkEnd.Wi - wi_end_global;
}
int dst_offset = param.input_nram_stride_h * hi_start_local +
param.input_nram_stride_w * wi_start_local;
T *dst = input_nram + dst_offset;
int src_offset = INDEX3(sample_idx, hi_start_global, wi_start_global,
blkStart.C, param.input_stride_n,
param.input_stride_h, param.input_stride_w);
T *src = input + src_offset;
int input_seg_num_h = hi_end_local - hi_start_local + 1;
int input_seg_num_w = wi_end_local - wi_start_local + 1;
for (int i = 0; i < input_seg_num_h; ++i) {
loadStr3D(dst, src, blkSize.Cg, blkSize.G, input_seg_num_w,
param.input_nram_stride_g, param.input_nram_stride_w,
param.input_stride_g, param.input_stride_w);
dst += param.input_nram_stride_h;
src += param.input_stride_h;
}
/* load mask block from gdram2nram
*
* mask_nram[ | mask[sample_idx,
* 0:blkSize.Ho-1 , | blkStart.Ho + 0:blkSize.Ho-1,
* 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1,
* 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1,
* 0:blkSize.Kh-1, | blkStart.Kh + 0:blkSize.Kh-1,
* 0:blkSize.Kw-1] | blkStart.Kw + 0:blkSize.Kw-1]
*/
src_offset = INDEX3(blkStart.Wo, blkStart.G, blkStart.Kh, blkStart.Kw,
param.mask_stride_w, param.mask_stride_g,
param.mask_stride_kh);
src_offset += sample_idx * param.mask_stride_n +
blkStart.Ho * param.mask_stride_h;
for (int ho = 0; ho < blkSize.Ho; ++ho) {
dst = mask_nram + ho * param.mask_nram_stride_h;
src = mask + src_offset + ho * param.mask_stride_h;
for (int wo = 0; wo < blkSize.Wo; ++wo) {
loadStr3D(dst, src, blkSize.Kw, blkSize.Kh, blkSize.G,
param.mask_nram_stride_kh, param.mask_nram_stride_g,
param.mask_stride_kh, param.mask_stride_g);
dst += param.mask_nram_stride_w;
src += param.mask_stride_w;
}
}
// loop each pixel of the output block
for (int ho = 0; ho < blkSize.Ho; ++ho) {
int kernel_hi_start_global = (blkStart.Ho + ho) / param.scale_factor -
param.kernel_size_half + blkStart.Kh;
int kernel_hi_start_local = kernel_hi_start_global - blkStart.Hi;
// int kernel_hi_end_global = kernel_hi_start_global + blkSize.Kh - 1;
// int kernel_hi_end_local = kernel_hi_end_global - blkStart.Hi;
// exclude out of bound indices which should be ignored
int kh_min = hi_start_local - kernel_hi_start_local > 0
? hi_start_local - kernel_hi_start_local
: 0;
int kh_max = hi_end_local - kernel_hi_start_local < blkSize.Kh - 1
? hi_end_local - kernel_hi_start_local
: blkSize.Kh - 1;
for (int wo = 0; wo < blkSize.Wo; ++wo) {
int kernel_wi_start_global =
(blkStart.Wo + wo) / param.scale_factor -
param.kernel_size_half + blkStart.Kw;
int kernel_wi_start_local = kernel_wi_start_global - blkStart.Wi;
// exclude out of bound indices which should be ignored
int kw_min = wi_start_local - kernel_wi_start_local > 0
? wi_start_local - kernel_wi_start_local
: 0;
int kw_max = wi_end_local - kernel_wi_start_local < blkSize.Kw - 1
? wi_end_local - kernel_wi_start_local
: blkSize.Kw - 1;
// output_nram[ho, wo, g, c] = sum(mask_nram[ho, wo, g, kh, kw]
// * input_nram[hi+kh, wi+kw, g, c],
// for (kh,kw) in [0:blkSize.Kw-1] x [0:blkSize.Kh-1])
//
// sum(mask_nram[ho, wo, g, kh, kw]
// * input_nram[hi+kh, wi+kw, g, c], (kh,kw))
//
T *mask_array = mask_nram + param.mask_nram_stride_h * ho +
param.mask_nram_stride_w * wo;
for (int kh = kh_min; kh <= kh_max; ++kh) {
for (int kw = kw_min; kw <= kw_max; ++kw) {
T *src =
input_nram +
param.input_nram_stride_h * (kernel_hi_start_local + kh) +
param.input_nram_stride_w * (kernel_wi_start_local + kw);
int mask_index = param.mask_nram_stride_kh * kh + kw;
// multiply mask weight with channels for each channel group
T *sum = sum_array;
for (int g = 0; g < blkSize.G; ++g) {
__bang_mul_scalar(sum, src, mask_array[mask_index],
param.block_Cg_NFU);
//
// NOTE: Since block_Cg_NFU >= block_Cg_stride,
// overlapped writing may occur on sum_array.
// So this loop must be executed in order to
// avoid data contamination, as shown below.
//
// |-----block_Cg_NFU---------|
// xxxxxxxxxxxxxxxxxxxxyyyzzzzz------------
// |---block_Cg_stride---|^^^^^will be overwritten
// in the next iteration.
//
// x: actual data used, y: not used, z: overwritten
//
sum += param.input_nram_stride_g;
src += param.input_nram_stride_g;
mask_index += param.mask_nram_stride_g;
} // loop blk_G
// add array[blk_G * blk_C] to output_nram
dst = output_nram + param.output_nram_stride_h * ho +
param.output_nram_stride_w * wo;
__bang_add(dst, dst, sum_array, param.output_nram_stride_w);
} // end loop blk_Kw
} // end loop blk_Kh
} // end loop blk_Wo
} // end loop blk_Ho
} // end loop grid_dim.Kw
} // end loop grid_dim.Kh
/* write output from nram2gdram
*
* output_nram[ | output[sample_idx,
* 0:blkSize.Ho-1, | blkStart.Ho + 0:blkSize.Ho-1,
* 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1,
* 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1,
* 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1]
*/
int dst_offset = INDEX3(sample_idx, blkStart.Ho, blkStart.Wo, blkStart.C,
param.output_stride_n, param.output_stride_h,
param.output_stride_w);
T *dst = output + dst_offset;
T *src = output_nram;
for (int i = 0; i < blkSize.Ho; ++i) {
storeStr3D(dst, src, blkSize.Cg, blkSize.G, blkSize.Wo,
param.output_stride_g, param.output_stride_w,
param.output_nram_stride_g, param.output_nram_stride_w);
dst += param.output_stride_h;
src += param.output_nram_stride_h;
}
} // end loop N, grid_dim.(Hi,Wi,G,Cg)
}
template <typename T>
__mlu_global__ void MLUBLOCKKernelCarafeForward(
const void *input, const void *mask, const CarafeForwardParam param,
const CarafeForwardBlockDim block_dim, const CarafeForwardGridDim grid_dim,
void *output) {
carafeForwardBLOCK((T *)input, (T *)mask, param, block_dim, grid_dim,
(T *)output);
}
} // namespace forward
namespace backward {
template <typename T>
__mlu_func__ void CarafeCompute(T *input, T *mask, T *grad_output,
T *grad_input, T *grad_mask, const int n,
const int hi, const int wi, const int c,
const int k_up, const int group,
const int scale) {
char *input_buff = nram_buf;
char *mask_buff = input_buff + NRAM_BLOCK;
char *grad_input_buff = mask_buff + NRAM_BLOCK;
char *grad_output_buff = grad_input_buff + NRAM_BLOCK;
char *grad_mask_buff = grad_output_buff + NRAM_BLOCK;
int wo = wi * scale;
int ho = hi * scale;
int out_num = n * ho * wo * group;
int group_size = c / group;
int repeat = out_num / taskDim + (int)(taskId < out_num % taskDim);
int num_align = PAD_DOWN(NRAM_BLOCK / sizeof(T), NFU_ALIGN_SIZE / sizeof(T));
int num_per_loop = group_size / num_align;
int rem_for_loop = group_size % num_align;
int rem_for_loop_align = PAD_UP(rem_for_loop, NFU_ALIGN_SIZE / sizeof(T));
for (int k = 0; k < repeat; k++) {
int iter = k * taskDim + taskId;
int group_k = iter % group;
int w_k = (iter / group) % wo;
int h_k = (iter / wo / group) % ho;
int n_k = (iter / ho / wo / group) % n;
int h_i = h_k / scale;
int w_i = w_k / scale;
int start_h = h_i - ((k_up - 1) / 2);
int end_h = h_i + ((k_up - 1) / 2) + 1;
int start_w = w_i - ((k_up - 1) / 2);
int end_w = w_i + ((k_up - 1) / 2) + 1;
T *base_mask = (T *)mask + n_k * ho * wo * group * k_up * k_up +
h_k * wo * group * k_up * k_up + w_k * group * k_up * k_up +
group_k * k_up * k_up;
T *base_grad_mask = (T *)grad_mask + n_k * ho * wo * group * k_up * k_up +
h_k * wo * group * k_up * k_up +
w_k * group * k_up * k_up + group_k * k_up * k_up;
__bang_write_zero((T *)grad_input_buff, NRAM_BLOCK / sizeof(T));
__bang_write_zero((T *)grad_mask_buff, NRAM_BLOCK / sizeof(T));
__bang_write_zero((T *)grad_output_buff, NRAM_BLOCK / sizeof(T));
__memcpy((T *)mask_buff, (T *)base_mask, k_up * k_up * sizeof(T),
GDRAM2NRAM);
for (int i = 0; i < num_per_loop; i++) {
__bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T));
T *base_grad_output = (T *)grad_output + n_k * ho * wo * c +
h_k * wo * c + w_k * c + group_k * group_size +
i * num_align;
__memcpy((T *)grad_output_buff, (T *)base_grad_output,
num_align * sizeof(T), GDRAM2NRAM);
for (int ih = start_h; ih < end_h; ih++) {
for (int iw = start_w; iw < end_w; iw++) {
if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) {
continue;
}
int mask_ih = ih - h_i + (k_up - 1) / 2;
int mask_iw = iw - w_i + (k_up - 1) / 2;
int mask_index = mask_ih * k_up + mask_iw;
int input_index = n_k * hi * wi * c + ih * wi * c + iw * c +
group_k * group_size + i * num_align;
T *base_input = (T *)input + input_index;
T *base_grad_input = (T *)grad_input + input_index;
__memcpy((T *)input_buff, (T *)base_input, num_align * sizeof(T),
GDRAM2NRAM);
__bang_mul_scalar((T *)grad_input_buff, (T *)grad_output_buff,
((T *)mask_buff)[mask_index], num_align);
__bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input,
(T *)grad_input_buff, num_align);
__bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff,
num_align);
__bang_sumpool((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T),
num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1,
num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1);
__bang_reduce_sum((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T));
((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0];
}
}
}
if (rem_for_loop) {
__bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T));
T *base_grad_output = (T *)grad_output + n_k * ho * wo * c +
h_k * wo * c + w_k * c + group_k * group_size +
num_per_loop * num_align;
__memcpy((T *)grad_output_buff, (T *)base_grad_output,
rem_for_loop * sizeof(T), GDRAM2NRAM);
for (int ih = start_h; ih < end_h; ih++) {
for (int iw = start_w; iw < end_w; iw++) {
if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) {
continue;
}
int mask_ih = ih - h_i + (k_up - 1) / 2;
int mask_iw = iw - w_i + (k_up - 1) / 2;
int mask_index = mask_ih * k_up + mask_iw;
int input_index = n_k * hi * wi * c + ih * wi * c + iw * c +
group_k * group_size + num_per_loop * num_align;
T *base_input = (T *)input + input_index;
T *base_grad_input = (T *)grad_input + input_index;
__memcpy((T *)input_buff, (T *)base_input, rem_for_loop * sizeof(T),
GDRAM2NRAM);
__bang_mul_scalar((T *)grad_input_buff, (T *)grad_output_buff,
((T *)mask_buff)[mask_index], rem_for_loop_align);
__bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input,
(T *)grad_input_buff, rem_for_loop);
__bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff,
rem_for_loop_align);
__bang_sumpool(
(T *)input_buff, (T *)input_buff, NFU_ALIGN_SIZE / sizeof(T),
rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1,
rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1);
__bang_reduce_sum((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T));
((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0];
}
}
}
__memcpy((T *)base_grad_mask, (T *)grad_mask_buff, k_up * k_up * sizeof(T),
NRAM2GDRAM);
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelCarafeBackward(
const void *input, const void *mask, const void *grad_output,
void *grad_input, void *grad_mask, const int n, const int hi, const int wi,
const int c, const int k_up, const int group, const int scale) {
CarafeCompute((T *)input, (T *)mask, (T *)grad_output, (T *)grad_input,
(T *)grad_mask, n, hi, wi, c, k_up, group, scale);
}
} // namespace backward
void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *mask,
const CarafeForwardParam ¶m,
const CarafeForwardBlockDim &block_dim,
const CarafeForwardGridDim &grid_dim, void *output) {
if (d_type == CNRT_FLOAT16) {
forward::MLUBLOCKKernelCarafeForward<half><<<k_dim, k_type, queue>>>(
input, mask, param, block_dim, grid_dim, output);
} else {
forward::MLUBLOCKKernelCarafeForward<float><<<k_dim, k_type, queue>>>(
input, mask, param, block_dim, grid_dim, output);
}
}
void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t dtype,
const void *input, const void *mask,
const void *grad_output, void *grad_input,
void *grad_mask, const int n, const int hi,
const int wi, const int c, const int k_up,
const int group, const int scale) {
if (dtype == CNRT_FLOAT16) {
backward::MLUUnion1KernelCarafeBackward<half><<<k_dim, k_type, queue>>>(
input, mask, grad_output, grad_input, grad_mask, n, hi, wi, c, k_up,
group, scale);
} else {
backward::MLUUnion1KernelCarafeBackward<float><<<k_dim, k_type, queue>>>(
input, mask, grad_output, grad_input, grad_mask, n, hi, wi, c, k_up,
group, scale);
}
}
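Note: for reference, the computation the deleted forward kernel performs (see the comment above output_nram[ho, wo, g, c]) is a mask-weighted sum over a K x K input window around each output pixel's source location, done per channel group. A plain CPU sketch of that formula in NHWC layout, distilled from the kernel's comments and indexing; an illustration only, not an MMCV API, with out-of-bound taps skipped (treated as zero) as in the kernel:

// Plain CPU sketch of the CARAFE forward formula the deleted kernel computes.
#include <vector>

void carafe_forward_reference(const std::vector<float> &input,  // [N,Hi,Wi,G*Cg]
                              const std::vector<float> &mask,   // [N,Ho,Wo,G*K*K]
                              std::vector<float> &output,       // [N,Ho,Wo,G*Cg]
                              int N, int Hi, int Wi, int G, int Cg, int K,
                              int scale) {
  const int Ho = Hi * scale, Wo = Wi * scale, r = (K - 1) / 2, C = G * Cg;
  for (int n = 0; n < N; ++n)
    for (int ho = 0; ho < Ho; ++ho)
      for (int wo = 0; wo < Wo; ++wo)
        for (int g = 0; g < G; ++g)
          for (int c = 0; c < Cg; ++c) {
            float acc = 0.f;
            for (int kh = 0; kh < K; ++kh)
              for (int kw = 0; kw < K; ++kw) {
                // source pixel of (ho, wo) plus the kernel offset
                int hi = ho / scale - r + kh;
                int wi = wo / scale - r + kw;
                if (hi < 0 || hi >= Hi || wi < 0 || wi >= Wi) continue;
                acc += mask[((n * Ho + ho) * Wo + wo) * G * K * K +
                            (g * K + kh) * K + kw] *
                       input[((n * Hi + hi) * Wi + wi) * C + g * Cg + c];
              }
            output[((n * Ho + ho) * Wo + wo) * C + g * Cg + c] = acc;
          }
}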
mmcv/ops/csrc/common/mlu/carafe_utils.hpp   deleted 100644 → 0
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CARAFE_UTILS_HPP_
#define CARAFE_UTILS_HPP_
#define NRAM_ALIGN_SIZE 64
struct CarafeForwardParam {
  int N;   // batch size
  int Hi;  // input height
  int Wi;  // input width
  int Ci;  // input channels
  int Ho;  // output height
  int Wo;  // output width
  int Cg;  // channels per group
  int kernel_size;       // kernel_size
  int group_size;        // group_size
  int scale_factor;      // scale_factor
  int kernel_size_half;  // kernel half size (K-1)/2
  int kernel_size_sq;    // square of kernel size
  int dtype_size;        // size of tensor data type

  // Host arrays' geometry
  int input_stride_g;
  int input_stride_w;
  int input_stride_h;
  int input_stride_n;
  int input_size;
  int mask_stride_kh;
  int mask_stride_g;
  int mask_stride_w;
  int mask_stride_h;
  int mask_stride_n;
  int mask_size;
  int output_stride_g;
  int output_stride_w;
  int output_stride_h;
  int output_stride_n;
  int output_size;

  // NRAM arrays' geometry
  int input_nram_stride_g;
  int input_nram_stride_w;
  int input_nram_stride_h;
  int input_nram_size;
  int mask_nram_stride_kh;
  int mask_nram_stride_g;
  int mask_nram_stride_w;
  int mask_nram_stride_h;
  int mask_nram_size;
  int output_nram_stride_g;
  int output_nram_stride_w;
  int output_nram_stride_h;
  int output_nram_size;

  // for address/compute alignment
  int align_size_NRAM;  // for addressing on NRAM
  int align_size_NFU;   // for NFU operation length
  int block_Cg_NFU;     // for bang_mul_const

  int job_num;  // total job number
};

struct CarafeForwardBlockDim {
  int Ho;  // block size of output height
  int Wo;  // block size of output width
  int Kh;  // block size of kernel height
  int Kw;  // block size of kernel width
  int G;   // block size of groups
  int Cg;  // block size of channels within a group
  int Hi;  // block size of input height
  int Wi;  // block size of input width
};

struct CarafeForwardGridDim {
  int Ho;  // number of blocks of output height
  int Wo;
  int Kh;
  int Kw;
  int G;
  int Cg;
};
#endif // CARAFE_UTILS_HPP_
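Note: NRAM_ALIGN_SIZE above is a byte count; the (now deleted) launcher in carafe_mlu.cpp converted it and NFU_ALIGN_SIZE into element counts per data type (align_size_NRAM = NRAM_ALIGN_SIZE / dtype_size, align_size_NFU = NFU_ALIGN_SIZE / dtype_size). A small worked sketch, assuming NFU_ALIGN_SIZE is 128 bytes (its value comes from the MLU helper headers, not from this file):

// Worked example of the element-count alignments stored in CarafeForwardParam.
#include <cstdio>

int main() {
  const int kNramAlignBytes = 64;    // NRAM_ALIGN_SIZE above
  const int kNfuAlignBytes = 128;    // assumed NFU_ALIGN_SIZE
  const int dtype_sizes[] = {4, 2};  // float32, half
  for (int dtype_size : dtype_sizes) {
    std::printf("dtype_size=%d  align_size_NRAM=%d  align_size_NFU=%d\n",
                dtype_size, kNramAlignBytes / dtype_size,
                kNfuAlignBytes / dtype_size);
  }
  return 0;
}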
mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp
...
@@ -45,148 +45,6 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}
/*!
* @brief loads data from global DRAM to NRAM with 2D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The byte size of segment in the lower dimension.
* @param[in] dst_str
* The data stride in bytes between segments in the lower dimension of dst.
* @param[in] src_str
* The data stride in bytes between segments in the lower dimension of src.
* @param[in] seg_num
* The total count of data segments in the lower dimension.
*/
template <typename T>
__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str,
                            const int src_str, const int seg_num) {
  if (dst_str == src_str && size == src_str) {
    __memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM);
  } else if ((size == src_str || src_str <= dst_str) &&
             src_str * sizeof(T) <= 512) {
    // gather data less than 512Bytes to improve IO efficiency
    T *tmp = (T *)dst + (dst_str - src_str) * seg_num;
    __memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T),
             GDRAM2NRAM);
    if (dst_str != src_str) {
      __memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
               src_str * sizeof(T), seg_num - 1);
    }
  } else {
    __memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T),
             src_str * sizeof(T), seg_num - 1);
  }
}
/*!
* @brief loads data from global DRAM to NRAM with 3D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The byte size of segment in the lowest dimension.
* @param[in] seg_num_in
* The total count of data segments in the lowest dimension.
* @param[in] seg_num_out
* The total count of data segments in the middle dimension.
* @param[in] dst_str_in
* The data stride in bytes between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The data stride in bytes between segments in the middle dimension of dst.
* @param[in] src_str_in
* The data stride in bytes between segments in the lowest dimension of src.
* @param[in] src_str_out
* The data stride in bytes between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void loadStr3D(T *dst, T *src, const int size,
                            const int seg_num_in, const int seg_num_out,
                            const int dst_str_in, const int dst_str_out,
                            const int src_str_in, const int src_str_out) {
  T *tmp_dst = dst;
  T *tmp_src = src;
  for (int i = 0; i < seg_num_out; ++i) {
    loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in);
    tmp_src += src_str_out;
    tmp_dst += dst_str_out;
  }
}
/*!
* @brief stores data from NRAM to global DRAM with 2D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The byte size of segment in the lower dimension.
* @param[in] dst_str
* The data stride in bytes between segments in the lower dimension of dst.
* @param[in] src_str
* The data stride in bytes between segments in the lower dimension of src.
* @param[in] seg_num
* The total count of data segments in the lower dimension.
*/
template <typename T>
__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num,
                             const int dst_str, const int src_str) {
  if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) {
    // gather data less than 512Bytes to improve IO efficiency
    if (dst_str != src_str) {
      __memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
               src_str * sizeof(T), seg_num - 1);
    }
    __memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM);
  } else {
    __memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T),
             src_str * sizeof(T), seg_num - 1);
  }
}
/*!
* @brief stores data from NRAM to global DRAM with 3D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The byte size of segment in the lowest dimension.
* @param[in] seg_num_in
* The total count of data segments in the lowest dimension.
* @param[in] seg_num_out
* The total count of data segments in the middle dimension.
* @param[in] dst_str_in
* The data stride in bytes between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The data stride in bytes between segments in the middle dimension of dst.
* @param[in] src_str_in
* The data stride in bytes between segments in the lowest dimension of src.
* @param[in] src_str_out
* The data stride in bytes between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void storeStr3D(T *dst, T *src, const int size,
                             const int seg_num_in, const int seg_num_out,
                             const int dst_str_in, const int dst_str_out,
                             const int src_str_in, const int src_str_out) {
  T *tmp_dst = dst;
  T *tmp_src = src;
  for (int i = 0; i < seg_num_out; ++i) {
    storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in);
    tmp_src += src_str_out;
    tmp_dst += dst_str_out;
  }
}
/*!
 * @brief Converts int32 to float32 data type.
 *
...
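Note: the removed loadStr2D/loadStr3D and storeStr2D/storeStr3D helpers move strided segments between GDRAM and NRAM; their fast paths batch segments smaller than 512 bytes into a single transfer to improve IO efficiency. A host-side reference model of the basic 2D pattern might look like this (a hypothetical helper for reasoning about the size/stride parameters, which the deleted code treats as element counts since it multiplies them by sizeof(T)):

// Hypothetical host-side reference for the 2D strided copy done by
// loadStr2D / storeStr2D: seg_num segments of `size` elements, with
// consecutive segments src_str / dst_str elements apart in src / dst.
#include <cstring>

template <typename T>
void copyStr2DReference(T *dst, const T *src, const int size,
                        const int dst_str, const int src_str,
                        const int seg_num) {
  for (int i = 0; i < seg_num; ++i) {
    std::memcpy(dst + i * dst_str, src + i * src_str, size * sizeof(T));
  }
}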
mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
...
@@ -9,200 +9,13 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include "carafe_utils.hpp"
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
+#include "mlu_common_helper.h"
-void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                         cnrtQueue_t queue, const cnrtDataType_t d_type,
-                         const void *input, const void *mask,
-                         const CarafeForwardParam &param,
-                         const CarafeForwardBlockDim &block_dim,
-                         const CarafeForwardGridDim &grid_dim, void *output);
-void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                          cnrtQueue_t queue, cnrtDataType_t dtype,
-                          const void *input, const void *mask,
-                          const void *grad_output, void *grad_input,
-                          void *grad_mask, const int n, const int hi,
-                          const int wi, const int c, const int k_up,
-                          const int group, const int scale);
-// Get total NRAM usage and set strides of NRAM arrays.
-static void getNramUsage(CarafeForwardParam *param,
-                         CarafeForwardBlockDim *block_dim, int *nram_usage) {
-  // input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_G, blkDim_Cg]
-  block_dim->Hi = CEIL_DIV(block_dim->Ho, param->scale_factor) + 1;
-  block_dim->Wi = CEIL_DIV(block_dim->Wo, param->scale_factor) + 1;
-  param->input_nram_stride_g = PAD_UP(block_dim->Cg, param->align_size_NRAM);
-  param->input_nram_stride_w = param->input_nram_stride_g * block_dim->G;
-  param->input_nram_stride_h =
-      (block_dim->Wi + block_dim->Kw - 1) * param->input_nram_stride_w;
-  param->input_nram_size =
-      (block_dim->Hi + block_dim->Kh - 1) * param->input_nram_stride_h;
-  // mask_nram[blkDim_Ho, blkDim_Wo, blkDim_G, blkDim_Kh, blkDim_Kw]
-  param->mask_nram_stride_kh = block_dim->Kw;
-  param->mask_nram_stride_g = block_dim->Kh * param->mask_nram_stride_kh;
-  param->mask_nram_stride_w = block_dim->G * param->mask_nram_stride_g;
-  param->mask_nram_stride_h = block_dim->Wo * param->mask_nram_stride_w;
-  param->mask_nram_size =
-      PAD_UP(block_dim->Ho * param->mask_nram_stride_h, param->align_size_NRAM);
-  // output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
-  param->output_nram_stride_g = param->input_nram_stride_g;
-  param->output_nram_stride_w =
-      PAD_UP(param->input_nram_stride_w, param->align_size_NFU);
-  param->output_nram_stride_h = block_dim->Wo * param->output_nram_stride_w;
-  param->output_nram_size = block_dim->Ho * param->output_nram_stride_h;
-  // sum_array[blkDim_(G*Cg)]
-  // ensure the last mul_const on Cg does not exceed memory boundary
-  int sum_array_size_bang_mul_const =
-      (block_dim->G - 1) * param->input_nram_stride_g +
-      PAD_UP(param->input_nram_stride_g, param->align_size_NFU);
-  int sum_array_size =
-      std::max(param->output_nram_stride_w, sum_array_size_bang_mul_const);
-  *nram_usage = param->input_nram_size + param->mask_nram_size +
-                param->output_nram_size + sum_array_size;
-}
-// Policy Function for Forward
-static void genPolicyForward(CarafeForwardParam *param,
-                             CarafeForwardBlockDim *block_dim,
-                             CarafeForwardGridDim *grid_dim, cnrtDim3_t *k_dim,
-                             cnrtFunctionType_t *k_type) {
-  // device info
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  auto core_num = core_dim * cluster_num;
-  // maximum NRAM size as the number of <dtype>
-  auto max_nram_size =
-      torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) / param->dtype_size;
-  // determine grid and block dimensions
-  // set initial values for block_dim and grid_dim
-  block_dim->Ho = param->Ho;
-  block_dim->Wo = param->Wo;
-  block_dim->Kh = param->kernel_size;
-  block_dim->Kw = param->kernel_size;
-  block_dim->G = param->group_size;
-  block_dim->Cg = param->Cg;
-  grid_dim->Ho = 1;
-  grid_dim->Wo = 1;
-  grid_dim->Kh = 1;
-  grid_dim->Kw = 1;
-  grid_dim->G = 1;
-  grid_dim->Cg = 1;
-  // decrease the block size to fit in the NRAM.
-  int nram_usage = 0;
-  while (true) {
-    getNramUsage(param, block_dim, &nram_usage);
-    if (nram_usage > max_nram_size) {
-      // decrease Ho
-      // decrease block_Ho and block_Wo evenly
-      // so that the block is close to a square.
-      if (block_dim->Ho > 1 && block_dim->Ho >= block_dim->Wo) {
-        grid_dim->Ho += 1;
-        block_dim->Ho = CEIL_DIV(param->Ho, grid_dim->Ho);
-      } else if (block_dim->Wo > 1 && block_dim->Wo > block_dim->Ho) {
-        // decrease Wo
-        grid_dim->Wo += 1;
-        block_dim->Wo = CEIL_DIV(param->Wo, grid_dim->Wo);
-      } else if (block_dim->Kh > 1) {
-        // decrease Kh
-        grid_dim->Kh += 1;
-        block_dim->Kh = CEIL_DIV(param->kernel_size, grid_dim->Kh);
-        // reset Hi, Wi to maximize NRAM usage
-        grid_dim->Ho = 1;
-        block_dim->Ho = param->Ho;
-        grid_dim->Wo = 1;
-        block_dim->Wo = param->Wo;
-      } else if (block_dim->Kw > 1) {
-        // decrease Kw
-        grid_dim->Kw += 1;
-        block_dim->Kw = CEIL_DIV(param->kernel_size, grid_dim->Kw);
-        // reset Kh
-        grid_dim->Kh = 1;
-        block_dim->Kh = param->kernel_size;
-      } else if (block_dim->G > 1) {
-        // decrease G
-        grid_dim->G += 1;
-        block_dim->G = CEIL_DIV(param->group_size, grid_dim->G);
-        // reset Kw
-        grid_dim->Kw = 1;
-        block_dim->Kw = param->kernel_size;
-      } else if (block_dim->Cg > 1) {
-        // decrease block_Cg
-        // This is done in the last since c is the continuous dim
-        // (input layout is NHWC) and large c can improve
-        // IO & compute efficiency.
-        grid_dim->Cg += 1;
-        block_dim->Cg = CEIL_DIV(param->Cg, grid_dim->Cg);
-        // reset G
-        grid_dim->G = 1;
-        block_dim->G = param->group_size;
-      } else {
-        // the block volume is one now, cannot decrease the block size anymore!
-        // this situation should not occur.
-        break;
-      }
-    } else {
-      break;
-    }
-  }
-  // define parameters depending on block_dim, grid_dim
-  param->block_Cg_NFU = PAD_UP(block_dim->Cg, param->align_size_NFU);
-  // define host arrays' strides
-  // input[N,H,W,G,Cg]
-  param->input_stride_g = param->Cg;
-  param->input_stride_w = param->Ci;
-  param->input_stride_h = param->Wi * param->input_stride_w;
-  param->input_stride_n = param->Hi * param->input_stride_h;
-  // mask[N,Ho,Wo,G,Kh,Kw]
-  param->mask_stride_kh = param->kernel_size;
-  param->mask_stride_g = param->kernel_size * param->mask_stride_kh;
-  param->mask_stride_w = param->group_size * param->mask_stride_g;
-  param->mask_stride_h = param->Wo * param->mask_stride_w;
-  param->mask_stride_n = param->Ho * param->mask_stride_h;
-  // output[N,Ho,Wo,G,Cg]
-  param->output_stride_g = param->Cg;
-  param->output_stride_w = param->Ci;
-  param->output_stride_h = param->Wo * param->output_stride_w;
-  param->output_stride_n = param->Ho * param->output_stride_h;
-  param->job_num =
-      param->N * grid_dim->Ho * grid_dim->Wo * grid_dim->G * grid_dim->Cg;
-  // determine task type and dims
-  *k_type = CNRT_FUNC_TYPE_BLOCK;
-  k_dim->x = std::min(param->job_num, static_cast<int>(core_num));
-  k_dim->y = 1;
-  k_dim->z = 1;
-}
 void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
                                     Tensor rinput, Tensor routput, Tensor rmask,
                                     Tensor output, const int kernel_size,
                                     const int group_size,
                                     const int scale_factor) {
-  const int batch_size = output.size(0);
-  const int channels = output.size(1);
-  const int ho = output.size(2);
-  const int wo = output.size(3);
   // check tensor data type
   TORCH_CHECK(
       input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
...
@@ -221,37 +34,10 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
   // return fast on zero-element tensor
   if (output.numel() == 0) {
-    output = at::zeros({batch_size, channels, ho, wo}, output.options());
+    output = at::zeros(output.sizes().vec(), output.options());
     return;
   }
-  // set param
-  CarafeForwardParam param;
-  param.N = input.size(0);
-  param.Ci = input.size(1);
-  param.Hi = input.size(2);
-  param.Wi = input.size(3);
-  param.kernel_size = kernel_size;
-  param.group_size = group_size;
-  param.scale_factor = scale_factor;
-  param.Cg = param.Ci / group_size;
-  param.dtype_size = input.itemsize();
-  param.align_size_NRAM = NRAM_ALIGN_SIZE / param.dtype_size;
-  param.align_size_NFU = NFU_ALIGN_SIZE / param.dtype_size;
-  param.kernel_size_sq = param.kernel_size * param.kernel_size;
-  param.kernel_size_half = (param.kernel_size - 1) / 2;
-  param.Ho = param.Hi * param.scale_factor;
-  param.Wo = param.Wi * param.scale_factor;
-  // generate policy
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  CarafeForwardBlockDim block_dim;
-  CarafeForwardGridDim grid_dim;
-  genPolicyForward(&param, &block_dim, &grid_dim, &k_dim, &k_type);
   // convert NCHW to NHWC
   auto memory_format_input_nhwc =
       torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
...
@@ -268,6 +54,12 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
   auto routput_ =
       torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc);
+  // set tensor descriptor
+  MluOpTensorDescriptor input_desc, mask_desc, output_desc;
+  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
+  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
+  output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);
   // get ptr of tensors
   auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
   auto input_ptr = input_impl->cnnlMalloc();
...
@@ -276,45 +68,29 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
   auto output_impl = torch_mlu::getMluTensorImpl(routput_);
   auto output_ptr = output_impl->cnnlMalloc();
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-  // get dtype of input
-  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
-  // launch kernel
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  CNLOG(INFO) << "Launch Kernel KernelCarafeForward<<<Union"
-              << k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
-              << k_dim.z << ">>>";
-  KernelCarafeForward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, param,
-                      block_dim, grid_dim, output_ptr);
+  // set op descriptor
+  auto handle = mluOpGetCurrentHandle();
+  mluOpCarafeDescriptor_t carafe_desc;
+  mluOpCreateCarafeDescriptor(&carafe_desc);
+  mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
+                           scale_factor);
+  // launch kernel
+  mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
+                     mask_desc.desc(), mask_ptr, output_desc.desc(),
+                     output_ptr);
+  // destroy op descriptor
+  mluOpDestroyCarafeDescriptor(carafe_desc);
   // copy output from NHWC back into NCHW
   rinput.copy_(rinput_);
   output.copy_(routput_);
 }
-// Policy Function for Backward
-static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
-  // set Union1 Job
-  *k_type = CNRT_FUNC_TYPE_UNION1;
-  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  k_dim->z = 1;
-}
 void CARAFEBackwardMLUKernelLauncher(
     const Tensor grad_output, const Tensor rinput, const Tensor mask,
     Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input,
     Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask,
     const int kernel_size, const int group_size, const int scale_factor) {
-  const int batch_size = rinput.size(0);
-  const int channels = rinput.size(1);
-  const int hi = rinput.size(2);
-  const int wi = rinput.size(3);
   // data type check
   TORCH_CHECK(grad_output.scalar_type() == at::kFloat ||
                   grad_output.scalar_type() == at::kHalf,
...
@@ -331,11 +107,6 @@ void CARAFEBackwardMLUKernelLauncher(
   TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ",
               kernel_size);
-  // set task dimension
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  policyFuncBackward(&k_dim, &k_type);
   // convert NCHW to NHWC
   auto memory_format_input_nhwc =
       torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim());
...
@@ -363,8 +134,15 @@ void CARAFEBackwardMLUKernelLauncher(
   auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous(
       grad_mask, memory_format_grad_mask_nhwc);
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  // set tensor descriptor
+  MluOpTensorDescriptor input_desc, mask_desc;
+  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
+  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
+  MluOpTensorDescriptor grad_output_desc, grad_input_desc, grad_mask_desc;
+  grad_output_desc.set_with_layout(rgrad_output_, MLUOP_LAYOUT_NHWC);
+  grad_input_desc.set_with_layout(rgrad_input_, MLUOP_LAYOUT_NHWC);
+  grad_mask_desc.set_with_layout(rgrad_mask_, MLUOP_LAYOUT_NHWC);
   // get ptr of tensors
   auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
...
@@ -378,19 +156,19 @@ void CARAFEBackwardMLUKernelLauncher(
   auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_);
   auto grad_mask_ptr = grad_mask_impl->cnnlMalloc();
-  // get dtype of grad_output
-  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(grad_output.dtype());
-  auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  CNLOG(INFO) << "Launch Kernel KernelCarafeBackward<<<Union"
-              << k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
-              << k_dim.z << ">>>";
-  // launch kernel
-  KernelCarafeBackward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr,
-                       grad_output_ptr, grad_input_ptr, grad_mask_ptr,
-                       batch_size, hi, wi, channels, kernel_size, group_size,
-                       scale_factor);
+  // set op descriptor
+  auto handle = mluOpGetCurrentHandle();
+  mluOpCarafeDescriptor_t carafe_desc;
+  mluOpCreateCarafeDescriptor(&carafe_desc);
+  mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
+                           group_size, scale_factor);
+  // launch kernel
+  mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
+                      mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
+                      grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
+                      grad_mask_desc.desc(), grad_mask_ptr);
+  // destroy op descriptor
+  mluOpDestroyCarafeDescriptor(carafe_desc);
   // copy output from NHWC back into NCHW
   grad_input.copy_(rgrad_input_);
...
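Note on descriptor lifetime in the new launchers: mluOpCreateCarafeDescriptor and mluOpDestroyCarafeDescriptor are paired manually around each mluOpCarafeForward / mluOpCarafeBackward call. A small RAII wrapper would tie the destroy call to scope exit; this is a hypothetical illustration, not part of MMCV or mlu-ops, and it assumes the same mlu-ops headers that mlu_common_helper.h pulls in:

// Hypothetical RAII guard around the descriptor calls used above.
struct CarafeDescGuard {
  mluOpCarafeDescriptor_t desc{nullptr};
  CarafeDescGuard(int dims, int kernel_size, int group_size, int scale_factor) {
    mluOpCreateCarafeDescriptor(&desc);
    mluOpSetCarafeDescriptor(desc, dims, kernel_size, group_size, scale_factor);
  }
  ~CarafeDescGuard() { mluOpDestroyCarafeDescriptor(desc); }
  CarafeDescGuard(const CarafeDescGuard &) = delete;
  CarafeDescGuard &operator=(const CarafeDescGuard &) = delete;
};

With such a guard, a launcher could construct CarafeDescGuard guard(input.dim(), kernel_size, group_size, scale_factor) and pass guard.desc to mluOpCarafeForward, so the descriptor is destroyed even on early returns.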