Commit 625e82ce authored by bdf, committed by Zaida Zhou

[Feature] Add carafe op for MLU (#2212)



* [Feature] Support CARAFE with Cambricon MLU backend

* [Docs] Add comments for common functions

* [Test] Add allclose test for carafe

* Remove print
Co-authored-by: zcyKTH <zcy19950525@gmail.com>
Co-authored-by: budefei <budefei@cambricon.com>
parent f5a19ef0
@@ -10,7 +10,7 @@ We implement common ops used in detection, segmentation, etc.
 | BBoxOverlaps        |     | √   | √   | √   |
 | BorderAlign         |     | √   |     |     |
 | BoxIouRotated       | √   | √   |     |     |
-| CARAFE              |     | √   |     |     |
+| CARAFE              |     | √   | √   |     |
 | ChamferDistance     |     | √   |     |     |
 | CrissCrossAttention |     | √   |     |     |
 | ContourExpand       | √   |     |     |     |
......
@@ -10,7 +10,7 @@ MMCV provides operators commonly used in detection, segmentation, and other tasks
 | BBoxOverlaps        |     | √   | √   | √   |
 | BorderAlign         |     | √   |     |     |
 | BoxIouRotated       | √   | √   |     |     |
-| CARAFE              |     | √   |     |     |
+| CARAFE              |     | √   | √   |     |
 | ChamferDistance     |     | √   |     |     |
 | CrissCrossAttention |     | √   |     |     |
 | ContourExpand       | √   |     |     |     |
......
@@ -159,8 +159,6 @@ class CARAFEFunction(Function):
     def backward(
             ctx,
             grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
-        assert grad_output.is_cuda
         features, masks, rfeatures = ctx.saved_tensors
         kernel_size = ctx.kernel_size
         group_size = ctx.group_size
......
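With the CUDA-only assertion dropped from CARAFEFunction.backward, the op no longer rejects non-CUDA tensors and can reach the MLU kernels added below. A minimal usage sketch, not part of this commit, assuming an MLU-enabled build of mmcv with torch_mlu installed (shapes mirror the new test_carafe_allclose test):

    import torch
    from mmcv.ops import CARAFE

    # features (N, C, H, W); masks (N, group_size * kernel_size**2, H * scale, W * scale)
    feat = torch.randn(2, 64, 3, 3, device='mlu', requires_grad=True)
    mask = torch.randn(2, 100, 6, 6, device='mlu', requires_grad=True).sigmoid()

    carafe = CARAFE(5, 4, 2)  # kernel_size=5, group_size=4, scale_factor=2
    out = carafe(feat, mask)  # expected output shape: (2, 64, 6, 6)
    out.backward(torch.ones_like(out))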
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "carafe_utils.hpp"
#include "common_mlu_helper.hpp"
#define INDEX3(n, h, w, c, strN, strH, strW) \
(strN) * (n) + (strH) * (h) + (strW) * (w) + (c)
#define NRAM_BLOCK PAD_DOWN(MAX_NRAM_SIZE / 5, NRAM_ALIGN_SIZE)
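// NRAM_BLOCK is the per-buffer size used by the backward kernel, which splits
// NRAM into five equal, alignment-padded regions (input, mask, grad_input,
// grad_output, grad_mask); the forward kernel instead partitions nram_buf
// dynamically via the sizes in CarafeForwardParam.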
__nram__ char nram_buf[MAX_NRAM_SIZE];
namespace forward {
struct BlockId {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
};
// start indices of block
struct BlockStart {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
int C;
};
struct BlockEnd {
int Ho;
int Wo;
int Kh;
int Kw;
int Hi;
int Wi;
};
struct BlockSize {
int Ho;
int Wo;
int G;
int Cg;
int Kh;
int Kw;
int Hi;
int Wi;
};
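/* Summary of the computation performed by carafeForwardBLOCK (inferred from
 * the loops below, not part of the original kernel documentation):
 *
 *   output[n, ho, wo, g, cg] =
 *       sum over (kh, kw) in [0, K) x [0, K) of
 *           mask[n, ho, wo, g, kh, kw] *
 *           input[n, ho / scale_factor - (K - 1) / 2 + kh,
 *                    wo / scale_factor - (K - 1) / 2 + kw, g, cg]
 *
 * where K = kernel_size, tensors are laid out as NHWC with channels split
 * into (group, channels-per-group), and out-of-range input locations
 * contribute zero.
 */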
template <typename T>
__mlu_func__ void carafeForwardBLOCK(T *input, T *mask,
const CarafeForwardParam param,
const CarafeForwardBlockDim block_dim,
const CarafeForwardGridDim grid_dim,
T *output) {
// data block info
BlockId blkId;
BlockStart blkStart;
BlockEnd blkEnd;
BlockSize blkSize;
// set pointers on NRAM arrays
// input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_(G*Cg)]
T *input_nram = (T *)nram_buf;
// mask_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Kh*Kw)]
T *mask_nram = input_nram + param.input_nram_size;
// output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
T *output_nram = mask_nram + param.mask_nram_size;
// sum_array[blkDim_(G*Cg)]
T *sum_array = output_nram + param.output_nram_size;
/* ===== loop over N, grid_dim(Ho,Wo,G,Cg)
* iterations are distributed over computing cores
*/
for (int loop_index = taskId; loop_index < param.job_num;
loop_index += taskDim) {
// block idx
blkId.Cg = loop_index;
blkId.G = blkId.Cg / grid_dim.Cg;
blkId.Wo = blkId.G / grid_dim.G;
blkId.Ho = blkId.Wo / grid_dim.Wo;
int sample_idx = blkId.Ho / grid_dim.Ho;
blkId.Cg %= grid_dim.Cg;
blkId.G %= grid_dim.G;
blkId.Wo %= grid_dim.Wo;
blkId.Ho %= grid_dim.Ho;
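// i.e. loop_index enumerates (n, Ho, Wo, G, Cg) with Cg fastest:
// loop_index = ((((n * grid_dim.Ho + Ho) * grid_dim.Wo + Wo) * grid_dim.G + G)
//               * grid_dim.Cg + Cg)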
// block starting indices
blkStart.Ho = blkId.Ho * block_dim.Ho;
blkStart.Wo = blkId.Wo * block_dim.Wo;
blkStart.G = blkId.G * block_dim.G;
blkStart.Cg = blkId.Cg * block_dim.Cg;
blkStart.C = blkStart.G * param.Cg + blkStart.Cg;
// block size
blkSize.Ho = block_dim.Ho;
blkSize.Wo = block_dim.Wo;
blkSize.G = block_dim.G;
blkSize.Cg = block_dim.Cg;
// take care of blocks near the end of each dimension
if (blkId.Ho == (grid_dim.Ho - 1)) {
blkSize.Ho = param.Ho - (grid_dim.Ho - 1) * block_dim.Ho;
}
if (blkId.Wo == (grid_dim.Wo - 1)) {
blkSize.Wo = param.Wo - (grid_dim.Wo - 1) * block_dim.Wo;
}
if (blkId.G == (grid_dim.G - 1)) {
blkSize.G = param.group_size - (grid_dim.G - 1) * block_dim.G;
}
if (blkId.Cg == (grid_dim.Cg - 1)) {
blkSize.Cg = param.Cg - (grid_dim.Cg - 1) * block_dim.Cg;
}
// block end indices
blkEnd.Ho = blkStart.Ho + blkSize.Ho - 1;
blkEnd.Wo = blkStart.Wo + blkSize.Wo - 1;
// set output_nram to zero
__nramset(output_nram, param.output_nram_size, T(0));
// loop blocks of kernel window: grid_dim.(Kh, Kw)
for (blkId.Kh = 0; blkId.Kh < grid_dim.Kh; ++blkId.Kh) {
blkStart.Kh = blkId.Kh * block_dim.Kh;
blkSize.Kh = block_dim.Kh;
if (blkId.Kh == (grid_dim.Kh - 1)) {
blkSize.Kh = param.kernel_size - (grid_dim.Kh - 1) * block_dim.Kh;
}
blkEnd.Kh = blkStart.Kh + blkSize.Kh - 1;
blkStart.Hi = blkStart.Ho / param.scale_factor - param.kernel_size_half +
blkStart.Kh;
blkEnd.Hi =
blkEnd.Ho / param.scale_factor - param.kernel_size_half + blkEnd.Kh;
blkSize.Hi = blkEnd.Hi - blkStart.Hi + 1;
for (blkId.Kw = 0; blkId.Kw < grid_dim.Kw; ++blkId.Kw) {
blkStart.Kw = blkId.Kw * block_dim.Kw;
blkSize.Kw = block_dim.Kw;
if (blkId.Kw == (grid_dim.Kw - 1)) {
blkSize.Kw = param.kernel_size - (grid_dim.Kw - 1) * block_dim.Kw;
}
blkEnd.Kw = blkStart.Kw + blkSize.Kw - 1;
blkStart.Wi = blkStart.Wo / param.scale_factor -
param.kernel_size_half + blkStart.Kw;
blkEnd.Wi =
blkEnd.Wo / param.scale_factor - param.kernel_size_half + blkEnd.Kw;
blkSize.Wi = blkEnd.Wi - blkStart.Wi + 1;
// load input block from gdram2nram
//
// input_nram[ | input[ sample_idx,
// 0:blkSize.Hi-1, | blkStart.Hi + 0:blkSize.Hi-1,
// 0:blkSize.Wi-1, | blkStart.Wi + 0:blkSize.Wi-1,
// 0:blkSize.G-1 | blkStart.G + 0:blkSize.G-1
// 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1]
//
// To skip out of bound indices:
//
// input_nram[
// hi_start_local:hi_end_local,
// wi_start_local:wi_end_local, ...]
// = input[n,
// hi_start_global:hi_end_global,
// wi_start_global:wi_end_global, ...]
//
int hi_start_local = 0;
int hi_start_global = blkStart.Hi;
if (blkStart.Hi < 0) {
hi_start_local = -blkStart.Hi;
hi_start_global = 0;
}
int wi_start_local = 0;
int wi_start_global = blkStart.Wi;
if (blkStart.Wi < 0) {
wi_start_local = -blkStart.Wi;
wi_start_global = 0;
}
int hi_end_local = blkSize.Hi - 1;
int hi_end_global = blkEnd.Hi;
if (blkEnd.Hi > param.Hi - 1) {
hi_end_global = param.Hi - 1;
hi_end_local -= blkEnd.Hi - hi_end_global;
}
int wi_end_local = blkSize.Wi - 1;
int wi_end_global = blkEnd.Wi;
if (blkEnd.Wi > param.Wi - 1) {
wi_end_global = param.Wi - 1;
wi_end_local -= blkEnd.Wi - wi_end_global;
}
int dst_offset = param.input_nram_stride_h * hi_start_local +
param.input_nram_stride_w * wi_start_local;
T *dst = input_nram + dst_offset;
int src_offset = INDEX3(sample_idx, hi_start_global, wi_start_global,
blkStart.C, param.input_stride_n,
param.input_stride_h, param.input_stride_w);
T *src = input + src_offset;
int input_seg_num_h = hi_end_local - hi_start_local + 1;
int input_seg_num_w = wi_end_local - wi_start_local + 1;
for (int i = 0; i < input_seg_num_h; ++i) {
loadStr3D(dst, src, blkSize.Cg, blkSize.G, input_seg_num_w,
param.input_nram_stride_g, param.input_nram_stride_w,
param.input_stride_g, param.input_stride_w);
dst += param.input_nram_stride_h;
src += param.input_stride_h;
}
/* load mask block from gdram2nram
*
* mask_nram[ | mask[sample_idx,
* 0:blkSize.Ho-1 , | blkStart.Ho + 0:blkSize.Ho-1,
* 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1,
* 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1,
* 0:blkSize.Kh-1, | blkStart.Kh + 0:blkSize.Kh-1,
* 0:blkSize.Kw-1] | blkStart.Kw + 0:blkSize.Kw-1]
*/
src_offset = INDEX3(blkStart.Wo, blkStart.G, blkStart.Kh, blkStart.Kw,
param.mask_stride_w, param.mask_stride_g,
param.mask_stride_kh);
src_offset += sample_idx * param.mask_stride_n +
blkStart.Ho * param.mask_stride_h;
for (int ho = 0; ho < blkSize.Ho; ++ho) {
dst = mask_nram + ho * param.mask_nram_stride_h;
src = mask + src_offset + ho * param.mask_stride_h;
for (int wo = 0; wo < blkSize.Wo; ++wo) {
loadStr3D(dst, src, blkSize.Kw, blkSize.Kh, blkSize.G,
param.mask_nram_stride_kh, param.mask_nram_stride_g,
param.mask_stride_kh, param.mask_stride_g);
dst += param.mask_nram_stride_w;
src += param.mask_stride_w;
}
}
// loop each pixel of the output block
for (int ho = 0; ho < blkSize.Ho; ++ho) {
int kernel_hi_start_global = (blkStart.Ho + ho) / param.scale_factor -
param.kernel_size_half + blkStart.Kh;
int kernel_hi_start_local = kernel_hi_start_global - blkStart.Hi;
// int kernel_hi_end_global = kernel_hi_start_global + blkSize.Kh - 1;
// int kernel_hi_end_local = kernel_hi_end_global - blkStart.Hi;
// exclude out of bound indices which should be ignored
int kh_min = hi_start_local - kernel_hi_start_local > 0
? hi_start_local - kernel_hi_start_local
: 0;
int kh_max = hi_end_local - kernel_hi_start_local < blkSize.Kh - 1
? hi_end_local - kernel_hi_start_local
: blkSize.Kh - 1;
for (int wo = 0; wo < blkSize.Wo; ++wo) {
int kernel_wi_start_global =
(blkStart.Wo + wo) / param.scale_factor -
param.kernel_size_half + blkStart.Kw;
int kernel_wi_start_local = kernel_wi_start_global - blkStart.Wi;
// exclude out of bound indices which should be ignored
int kw_min = wi_start_local - kernel_wi_start_local > 0
? wi_start_local - kernel_wi_start_local
: 0;
int kw_max = wi_end_local - kernel_wi_start_local < blkSize.Kw - 1
? wi_end_local - kernel_wi_start_local
: blkSize.Kw - 1;
// output_nram[ho, wo, g, c] = sum(mask_nram[ho, wo, g, kh, kw]
// * input_nram[hi+kh, wi+kw, g, c],
// for (kh,kw) in [0:blkSize.Kw-1] x [0:blkSize.Kh-1])
//
// sum(mask_nram[ho, wo, g, kh, kw]
// * input_nram[hi+kh, wi+kw, g, c], (kh,kw))
//
T *mask_array = mask_nram + param.mask_nram_stride_h * ho +
param.mask_nram_stride_w * wo;
for (int kh = kh_min; kh <= kh_max; ++kh) {
for (int kw = kw_min; kw <= kw_max; ++kw) {
T *src =
input_nram +
param.input_nram_stride_h * (kernel_hi_start_local + kh) +
param.input_nram_stride_w * (kernel_wi_start_local + kw);
int mask_index = param.mask_nram_stride_kh * kh + kw;
// multiply mask weight with channels for each channel group
T *sum = sum_array;
for (int g = 0; g < blkSize.G; ++g) {
__bang_mul_const(sum, src, mask_array[mask_index],
param.block_Cg_NFU);
//
// NOTE: Since block_Cg_NFU >= block_Cg_stride,
// overlapped writing may occur on sum_array.
// So this loop must be executed in order to
// avoid data contamination, as shown below.
//
// |-----block_Cg_NFU---------|
// xxxxxxxxxxxxxxxxxxxxyyyzzzzz------------
// |---block_Cg_stride---|^^^^^will be overwritten
// in the next iteration.
//
// x: actual data used, y: not used, z: overwritten
//
sum += param.input_nram_stride_g;
src += param.input_nram_stride_g;
mask_index += param.mask_nram_stride_g;
} // loop blk_G
// add array[blk_G * blk_C] to output_nram
dst = output_nram + param.output_nram_stride_h * ho +
param.output_nram_stride_w * wo;
__bang_add(dst, dst, sum_array, param.output_nram_stride_w);
} // end loop blk_Kw
} // end loop blk_Kh
} // end loop blk_Wo
} // end loop blk_Ho
} // end loop grid_dim.Kw
} // end loop grid_dim.Kh
/* write output from nram2gdram
*
* output_nram[ | output[sample_idx,
* 0:blkSize.Ho-1, | blkStart.Ho + 0:blkSize.Ho-1,
* 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1,
* 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1,
* 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1]
*/
int dst_offset = INDEX3(sample_idx, blkStart.Ho, blkStart.Wo, blkStart.C,
param.output_stride_n, param.output_stride_h,
param.output_stride_w);
T *dst = output + dst_offset;
T *src = output_nram;
for (int i = 0; i < blkSize.Ho; ++i) {
storeStr3D(dst, src, blkSize.Cg, blkSize.G, blkSize.Wo,
param.output_stride_g, param.output_stride_w,
param.output_nram_stride_g, param.output_nram_stride_w);
dst += param.output_stride_h;
src += param.output_nram_stride_h;
}
} // end loop N, grid_dim.(Ho,Wo,G,Cg)
}
template <typename T>
__mlu_global__ void MLUBLOCKKernelCarafeForward(
const void *input, const void *mask, const CarafeForwardParam param,
const CarafeForwardBlockDim block_dim, const CarafeForwardGridDim grid_dim,
void *output) {
carafeForwardBLOCK((T *)input, (T *)mask, param, block_dim, grid_dim,
(T *)output);
}
} // namespace forward
namespace backward {
template <typename T>
__mlu_func__ void CarafeCompute(T *input, T *mask, T *grad_output,
T *grad_input, T *grad_mask, const int n,
const int hi, const int wi, const int c,
const int k_up, const int group,
const int scale) {
char *input_buff = nram_buf;
char *mask_buff = input_buff + NRAM_BLOCK;
char *grad_input_buff = mask_buff + NRAM_BLOCK;
char *grad_output_buff = grad_input_buff + NRAM_BLOCK;
char *grad_mask_buff = grad_output_buff + NRAM_BLOCK;
int wo = wi * scale;
int ho = hi * scale;
int out_num = n * ho * wo * group;
int group_size = c / group;
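// Distribute the out_num (= n * ho * wo * group) jobs over the taskDim cores:
// the first (out_num % taskDim) cores run one extra iteration.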
int repeat = out_num / taskDim + (int)(taskId < out_num % taskDim);
int num_align = PAD_DOWN(NRAM_BLOCK / sizeof(T), NFU_ALIGN_SIZE / sizeof(T));
int num_per_loop = group_size / num_align;
int rem_for_loop = group_size % num_align;
int rem_for_loop_align = PAD_UP(rem_for_loop, NFU_ALIGN_SIZE / sizeof(T));
for (int k = 0; k < repeat; k++) {
int iter = k * taskDim + taskId;
int group_k = iter % group;
int w_k = (iter / group) % wo;
int h_k = (iter / wo / group) % ho;
int n_k = (iter / ho / wo / group) % n;
int h_i = h_k / scale;
int w_i = w_k / scale;
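// The reassembly window of this output pixel covers input rows
// [start_h, end_h) and columns [start_w, end_w), centered at (h_i, w_i);
// positions outside the feature map are skipped in the loops below.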
int start_h = h_i - ((k_up - 1) / 2);
int end_h = h_i + ((k_up - 1) / 2) + 1;
int start_w = w_i - ((k_up - 1) / 2);
int end_w = w_i + ((k_up - 1) / 2) + 1;
T *base_mask = (T *)mask + n_k * ho * wo * group * k_up * k_up +
h_k * wo * group * k_up * k_up + w_k * group * k_up * k_up +
group_k * k_up * k_up;
T *base_grad_mask = (T *)grad_mask + n_k * ho * wo * group * k_up * k_up +
h_k * wo * group * k_up * k_up +
w_k * group * k_up * k_up + group_k * k_up * k_up;
__bang_write_zero((T *)grad_input_buff, NRAM_BLOCK / sizeof(T));
__bang_write_zero((T *)grad_mask_buff, NRAM_BLOCK / sizeof(T));
__bang_write_zero((T *)grad_output_buff, NRAM_BLOCK / sizeof(T));
__memcpy((T *)mask_buff, (T *)base_mask, k_up * k_up * sizeof(T),
GDRAM2NRAM);
for (int i = 0; i < num_per_loop; i++) {
__bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T));
T *base_grad_output = (T *)grad_output + n_k * ho * wo * c +
h_k * wo * c + w_k * c + group_k * group_size +
i * num_align;
__memcpy((T *)grad_output_buff, (T *)base_grad_output,
num_align * sizeof(T), GDRAM2NRAM);
for (int ih = start_h; ih < end_h; ih++) {
for (int iw = start_w; iw < end_w; iw++) {
if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) {
continue;
}
int mask_ih = ih - h_i + (k_up - 1) / 2;
int mask_iw = iw - w_i + (k_up - 1) / 2;
int mask_index = mask_ih * k_up + mask_iw;
int input_index = n_k * hi * wi * c + ih * wi * c + iw * c +
group_k * group_size + i * num_align;
T *base_input = (T *)input + input_index;
T *base_grad_input = (T *)grad_input + input_index;
__memcpy((T *)input_buff, (T *)base_input, num_align * sizeof(T),
GDRAM2NRAM);
__bang_mul_const((T *)grad_input_buff, (T *)grad_output_buff,
((T *)mask_buff)[mask_index], num_align);
__bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input,
(T *)grad_input_buff, num_align);
__bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff,
num_align);
__bang_sumpool((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T),
num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1,
num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1);
__bang_reduce_sum((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T));
((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0];
}
}
}
if (rem_for_loop) {
__bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T));
T *base_grad_output = (T *)grad_output + n_k * ho * wo * c +
h_k * wo * c + w_k * c + group_k * group_size +
num_per_loop * num_align;
__memcpy((T *)grad_output_buff, (T *)base_grad_output,
rem_for_loop * sizeof(T), GDRAM2NRAM);
for (int ih = start_h; ih < end_h; ih++) {
for (int iw = start_w; iw < end_w; iw++) {
if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) {
continue;
}
int mask_ih = ih - h_i + (k_up - 1) / 2;
int mask_iw = iw - w_i + (k_up - 1) / 2;
int mask_index = mask_ih * k_up + mask_iw;
int input_index = n_k * hi * wi * c + ih * wi * c + iw * c +
group_k * group_size + num_per_loop * num_align;
T *base_input = (T *)input + input_index;
T *base_grad_input = (T *)grad_input + input_index;
__memcpy((T *)input_buff, (T *)base_input, rem_for_loop * sizeof(T),
GDRAM2NRAM);
__bang_mul_const((T *)grad_input_buff, (T *)grad_output_buff,
((T *)mask_buff)[mask_index], rem_for_loop_align);
__bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input,
(T *)grad_input_buff, rem_for_loop);
__bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff,
rem_for_loop_align);
__bang_sumpool(
(T *)input_buff, (T *)input_buff, NFU_ALIGN_SIZE / sizeof(T),
rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1,
rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1);
__bang_reduce_sum((T *)input_buff, (T *)input_buff,
NFU_ALIGN_SIZE / sizeof(T));
((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0];
}
}
}
__memcpy((T *)base_grad_mask, (T *)grad_mask_buff, k_up * k_up * sizeof(T),
NRAM2GDRAM);
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelCarafeBackward(
const void *input, const void *mask, const void *grad_output,
void *grad_input, void *grad_mask, const int n, const int hi, const int wi,
const int c, const int k_up, const int group, const int scale) {
CarafeCompute((T *)input, (T *)mask, (T *)grad_output, (T *)grad_input,
(T *)grad_mask, n, hi, wi, c, k_up, group, scale);
}
} // namespace backward
void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *mask,
const CarafeForwardParam &param,
const CarafeForwardBlockDim &block_dim,
const CarafeForwardGridDim &grid_dim, void *output) {
if (d_type == CNRT_FLOAT16) {
forward::MLUBLOCKKernelCarafeForward<half><<<k_dim, k_type, queue>>>(
input, mask, param, block_dim, grid_dim, output);
} else {
forward::MLUBLOCKKernelCarafeForward<float><<<k_dim, k_type, queue>>>(
input, mask, param, block_dim, grid_dim, output);
}
}
void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t dtype,
const void *input, const void *mask,
const void *grad_output, void *grad_input,
void *grad_mask, const int n, const int hi,
const int wi, const int c, const int k_up,
const int group, const int scale) {
if (dtype == CNRT_FLOAT16) {
backward::MLUUnion1KernelCarafeBackward<half>
<<<k_dim, k_type, queue>>>(input, mask, grad_output, grad_input,
grad_mask, n, hi, wi, c, k_up, group, scale);
} else {
backward::MLUUnion1KernelCarafeBackward<float>
<<<k_dim, k_type, queue>>>(input, mask, grad_output, grad_input,
grad_mask, n, hi, wi, c, k_up, group, scale);
}
}
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CARAFE_UTILS_HPP_
#define CARAFE_UTILS_HPP_
#define NRAM_ALIGN_SIZE 64
struct CarafeForwardParam {
int N; // batch size
int Hi; // input height
int Wi; // input width
int Ci; // input channels
int Ho; // output height
int Wo; // output width
int Cg; // channels per group
int kernel_size; // kernel_size
int group_size; // group_size
int scale_factor; // scale_factor
int kernel_size_half; // kernel half size (K-1)/2
int kernel_size_sq; // square of kernel size
int dtype_size; // size of tensor data type
// Host arrays' geometry
int input_stride_g;
int input_stride_w;
int input_stride_h;
int input_stride_n;
int input_size;
int mask_stride_kh;
int mask_stride_g;
int mask_stride_w;
int mask_stride_h;
int mask_stride_n;
int mask_size;
int output_stride_g;
int output_stride_w;
int output_stride_h;
int output_stride_n;
int output_size;
// NRAM arrays' geometry
int input_nram_stride_g;
int input_nram_stride_w;
int input_nram_stride_h;
int input_nram_size;
int mask_nram_stride_kh;
int mask_nram_stride_g;
int mask_nram_stride_w;
int mask_nram_stride_h;
int mask_nram_size;
int output_nram_stride_g;
int output_nram_stride_w;
int output_nram_stride_h;
int output_nram_size;
// for address/compute alignment
int align_size_NRAM; // for addressing on NRAM
int align_size_NFU; // for NFU operation length
int block_Cg_NFU; // for bang_mul_const
int job_num; // total job number
};
struct CarafeForwardBlockDim {
int Ho; // block size of output height
int Wo; // block size of output width
int Kh; // block size of kernel height
int Kw; // block size of kernel width
int G; // block size of groups
int Cg; // block size of channels within a group
int Hi; // block size of input height
int Wi; // block size of input width
};
struct CarafeForwardGridDim {
int Ho; // number of blocks of output height
int Wo;
int Kh;
int Kw;
int G;
int Cg;
};
#endif // CARAFE_UTILS_HPP_
@@ -45,6 +45,148 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}
/*!
* @brief loads data from global DRAM to NRAM with 2D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The number of elements of type T in each segment in the lower dimension.
* @param[in] dst_str
* The stride in elements between segments in the lower dimension of dst.
* @param[in] src_str
* The stride in elements between segments in the lower dimension of src.
* @param[in] seg_num
* The total count of data segments in the lower dimension.
*/
template <typename T>
__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str,
const int src_str, const int seg_num) {
if (dst_str == src_str && size == src_str) {
__memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM);
} else if ((size == src_str || src_str <= dst_str) &&
src_str * sizeof(T) <= 512) {
// gather segments smaller than 512 bytes to improve IO efficiency
T *tmp = (T *)dst + (dst_str - src_str) * seg_num;
__memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T),
GDRAM2NRAM);
if (dst_str != src_str) {
__memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
} else {
__memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
/*!
* @brief loads data from global DRAM to NRAM with 3D pattern.
*
* @param[out] dst
* Pointer to NRAM that stores dst data.
* @param[in] src
* Pointer to global DRAM that stores src data.
* @param[in] size
* The number of elements of type T in each segment in the lowest dimension.
* @param[in] seg_num_in
* The total count of data segments in the lowest dimension.
* @param[in] seg_num_out
* The total count of data segments in the middle dimension.
* @param[in] dst_str_in
* The stride in elements between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The stride in elements between segments in the middle dimension of dst.
* @param[in] src_str_in
* The stride in elements between segments in the lowest dimension of src.
* @param[in] src_str_out
* The stride in elements between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void loadStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*!
* @brief stores data from NRAM to global DRAM with 2D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The number of elements of type T in each segment in the lower dimension.
* @param[in] dst_str
* The stride in elements between segments in the lower dimension of dst.
* @param[in] src_str
* The stride in elements between segments in the lower dimension of src.
* @param[in] seg_num
* The total count of data segments in the lower dimension.
*/
template <typename T>
__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num,
const int dst_str, const int src_str) {
if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) {
// gather segments smaller than 512 bytes to improve IO efficiency
if (dst_str != src_str) {
__memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
__memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM);
} else {
__memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T),
src_str * sizeof(T), seg_num - 1);
}
}
/*!
* @brief stores data from NRAM to global DRAM with 3D pattern.
*
* @param[out] dst
* Pointer to global DRAM that stores dst data.
* @param[in] src
* Pointer to NRAM that stores src data.
* @param[in] size
* The number of elements of type T in each segment in the lowest dimension.
* @param[in] seg_num_in
* The total count of data segments in the lowest dimension.
* @param[in] seg_num_out
* The total count of data segments in the middle dimension.
* @param[in] dst_str_in
* The stride in elements between segments in the lowest dimension of dst.
* @param[in] dst_str_out
* The stride in elements between segments in the middle dimension of dst.
* @param[in] src_str_in
* The stride in elements between segments in the lowest dimension of src.
* @param[in] src_str_out
* The stride in elements between segments in the middle dimension of src.
*/
template <typename T>
__mlu_func__ void storeStr3D(T *dst, T *src, const int size,
const int seg_num_in, const int seg_num_out,
const int dst_str_in, const int dst_str_out,
const int src_str_in, const int src_str_out) {
T *tmp_dst = dst;
T *tmp_src = src;
for (int i = 0; i < seg_num_out; ++i) {
storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in);
tmp_src += src_str_out;
tmp_dst += dst_str_out;
}
}
/*!
 * @brief Converts int32 to float32 data type.
 *
......
@@ -21,8 +21,10 @@
 #define PAD_DOWN(x, y) (((x) / (y)) * (y))
+#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
 #define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
-#endif
+#endif  // MMCV_WITH_MLU
 #endif  // PYTORCH_MLU_HELPER_HPP_
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "carafe_utils.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, const cnrtDataType_t d_type,
const void *input, const void *mask,
const CarafeForwardParam &param,
const CarafeForwardBlockDim &block_dim,
const CarafeForwardGridDim &grid_dim, void *output);
void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t dtype,
const void *input, const void *mask,
const void *grad_output, void *grad_input,
void *grad_mask, const int n, const int hi,
const int wi, const int c, const int k_up,
const int group, const int scale);
// Get total NRAM usage and set strides of NRAM arrays.
static void getNramUsage(CarafeForwardParam *param,
CarafeForwardBlockDim *block_dim, int *nram_usage) {
// input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_G, blkDim_Cg]
block_dim->Hi = CEIL_DIV(block_dim->Ho, param->scale_factor) + 1;
block_dim->Wi = CEIL_DIV(block_dim->Wo, param->scale_factor) + 1;
param->input_nram_stride_g = PAD_UP(block_dim->Cg, param->align_size_NRAM);
param->input_nram_stride_w = param->input_nram_stride_g * block_dim->G;
param->input_nram_stride_h =
(block_dim->Wi + block_dim->Kw - 1) * param->input_nram_stride_w;
param->input_nram_size =
(block_dim->Hi + block_dim->Kh - 1) * param->input_nram_stride_h;
// mask_nram[blkDim_Ho, blkDim_Wo, blkDim_G, blkDim_Kh, blkDim_Kw]
param->mask_nram_stride_kh = block_dim->Kw;
param->mask_nram_stride_g = block_dim->Kh * param->mask_nram_stride_kh;
param->mask_nram_stride_w = block_dim->G * param->mask_nram_stride_g;
param->mask_nram_stride_h = block_dim->Wo * param->mask_nram_stride_w;
param->mask_nram_size =
PAD_UP(block_dim->Ho * param->mask_nram_stride_h, param->align_size_NRAM);
// output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)]
param->output_nram_stride_g = param->input_nram_stride_g;
param->output_nram_stride_w =
PAD_UP(param->input_nram_stride_w, param->align_size_NFU);
param->output_nram_stride_h = block_dim->Wo * param->output_nram_stride_w;
param->output_nram_size = block_dim->Ho * param->output_nram_stride_h;
// sum_array[blkDim_(G*Cg)]
// ensure the last mul_const on Cg does not exceed memory boundary
int sum_array_size_bang_mul_const =
(block_dim->G - 1) * param->input_nram_stride_g +
PAD_UP(param->input_nram_stride_g, param->align_size_NFU);
int sum_array_size =
std::max(param->output_nram_stride_w, sum_array_size_bang_mul_const);
*nram_usage = param->input_nram_size + param->mask_nram_size +
param->output_nram_size + sum_array_size;
}
// Policy Function for Forward
static void genPolicyForward(CarafeForwardParam *param,
CarafeForwardBlockDim *block_dim,
CarafeForwardGridDim *grid_dim, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
// device info
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
auto core_num = core_dim * cluster_num;
// maximum NRAM size, measured in number of <dtype> elements
auto max_nram_size =
torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) / param->dtype_size;
// determine grid and block dimensions
// set initial values for block_dim and grid_dim
block_dim->Ho = param->Ho;
block_dim->Wo = param->Wo;
block_dim->Kh = param->kernel_size;
block_dim->Kw = param->kernel_size;
block_dim->G = param->group_size;
block_dim->Cg = param->Cg;
grid_dim->Ho = 1;
grid_dim->Wo = 1;
grid_dim->Kh = 1;
grid_dim->Kw = 1;
grid_dim->G = 1;
grid_dim->Cg = 1;
// decrease the block size to fit in the NRAM.
int nram_usage = 0;
while (true) {
getNramUsage(param, block_dim, &nram_usage);
if (nram_usage > max_nram_size) {
// decrease Ho
// decrease block_Ho and block_Wo evenly
// so that the block is close to a square.
if (block_dim->Ho > 1 && block_dim->Ho >= block_dim->Wo) {
grid_dim->Ho += 1;
block_dim->Ho = CEIL_DIV(param->Ho, grid_dim->Ho);
} else if (block_dim->Wo > 1 && block_dim->Wo > block_dim->Ho) {
// decrease Wo
grid_dim->Wo += 1;
block_dim->Wo = CEIL_DIV(param->Wo, grid_dim->Wo);
} else if (block_dim->Kh > 1) {
// decrease Kh
grid_dim->Kh += 1;
block_dim->Kh = CEIL_DIV(param->kernel_size, grid_dim->Kh);
// reset Ho, Wo to maximize NRAM usage
grid_dim->Ho = 1;
block_dim->Ho = param->Ho;
grid_dim->Wo = 1;
block_dim->Wo = param->Wo;
} else if (block_dim->Kw > 1) {
// decrease Kw
grid_dim->Kw += 1;
block_dim->Kw = CEIL_DIV(param->kernel_size, grid_dim->Kw);
// reset Kh
grid_dim->Kh = 1;
block_dim->Kh = param->kernel_size;
} else if (block_dim->G > 1) {
// decrease G
grid_dim->G += 1;
block_dim->G = CEIL_DIV(param->group_size, grid_dim->G);
// reset Kw
grid_dim->Kw = 1;
block_dim->Kw = param->kernel_size;
} else if (block_dim->Cg > 1) {
// decrease block_Cg
// This is done last since c is the contiguous dim
// (input layout is NHWC) and large c can improve
// IO & compute efficiency.
grid_dim->Cg += 1;
block_dim->Cg = CEIL_DIV(param->Cg, grid_dim->Cg);
// reset G
grid_dim->G = 1;
block_dim->G = param->group_size;
} else {
// the block volume is one now, cannot decrease the block size anymore!
// this situation should not occur.
break;
}
} else {
break;
}
}
// define parameters depending on block_dim, grid_dim
param->block_Cg_NFU = PAD_UP(block_dim->Cg, param->align_size_NFU);
// define host arrays' strides
// input[N,H,W,G,Cg]
param->input_stride_g = param->Cg;
param->input_stride_w = param->Ci;
param->input_stride_h = param->Wi * param->input_stride_w;
param->input_stride_n = param->Hi * param->input_stride_h;
// mask[N,Ho,Wo,G,Kh,Kw]
param->mask_stride_kh = param->kernel_size;
param->mask_stride_g = param->kernel_size * param->mask_stride_kh;
param->mask_stride_w = param->group_size * param->mask_stride_g;
param->mask_stride_h = param->Wo * param->mask_stride_w;
param->mask_stride_n = param->Ho * param->mask_stride_h;
// output[N,Ho,Wo,G,Cg]
param->output_stride_g = param->Cg;
param->output_stride_w = param->Ci;
param->output_stride_h = param->Wo * param->output_stride_w;
param->output_stride_n = param->Ho * param->output_stride_h;
param->job_num =
param->N * grid_dim->Ho * grid_dim->Wo * grid_dim->G * grid_dim->Cg;
// determine task type and dims
*k_type = CNRT_FUNC_TYPE_BLOCK;
k_dim->x = std::min(param->job_num, static_cast<int>(core_num));
k_dim->y = 1;
k_dim->z = 1;
}
void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
Tensor rinput, Tensor routput, Tensor rmask,
Tensor output, const int kernel_size,
const int group_size,
const int scale_factor) {
const int batch_size = output.size(0);
const int channels = output.size(1);
const int ho = output.size(2);
const int wo = output.size(3);
// check tensor data type
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
"Data type of input should be Float or Half. But now input type is ",
input.scalar_type(), ".");
TORCH_CHECK(mask.scalar_type() == input.scalar_type(),
"Data types of input and mask should be the same, but got ",
input.scalar_type(), " and ", mask.scalar_type());
// check number of dimensions
TORCH_CHECK(input.dim() == 4, "input should be a 4-D tensor, but has ",
input.dim(), "D.");
TORCH_CHECK(mask.dim() == 4, "mask should be a 4-D tensor, but has ",
mask.dim(), "D.");
// return fast on zero-element tensor
if (output.numel() == 0) {
output = at::zeros({batch_size, channels, ho, wo}, output.options());
return;
}
// set param
CarafeForwardParam param;
param.N = input.size(0);
param.Ci = input.size(1);
param.Hi = input.size(2);
param.Wi = input.size(3);
param.kernel_size = kernel_size;
param.group_size = group_size;
param.scale_factor = scale_factor;
param.Cg = param.Ci / group_size;
param.dtype_size = input.itemsize();
param.align_size_NRAM = NRAM_ALIGN_SIZE / param.dtype_size;
param.align_size_NFU = NFU_ALIGN_SIZE / param.dtype_size;
param.kernel_size_sq = param.kernel_size * param.kernel_size;
param.kernel_size_half = (param.kernel_size - 1) / 2;
param.Ho = param.Hi * param.scale_factor;
param.Wo = param.Wi * param.scale_factor;
// generate policy
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
CarafeForwardBlockDim block_dim;
CarafeForwardGridDim grid_dim;
genPolicyForward(&param, &block_dim, &grid_dim, &k_dim, &k_type);
// convert NCHW to NHWC
auto memory_format_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto rinput_ =
torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format_input_nhwc);
auto memory_format_mask_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(mask.dim());
auto rmask_ =
torch_mlu::cnnl::ops::cnnl_contiguous(mask, memory_format_mask_nhwc);
auto memory_format_output_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(output.dim());
auto routput_ =
torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc);
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
auto input_ptr = input_impl->cnnlMalloc();
auto mask_impl = torch_mlu::getMluTensorImpl(rmask_);
auto mask_ptr = mask_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(routput_);
auto output_ptr = output_impl->cnnlMalloc();
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get dtype of input
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype());
// launch kernel
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
CNLOG(INFO) << "Launch Kernel KernelCarafeForward<<<Union"
<< k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
<< k_dim.z << ">>>";
KernelCarafeForward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, param,
block_dim, grid_dim, output_ptr);
// copy output from NHWC back into NCHW
rinput.copy_(rinput_);
output.copy_(routput_);
}
// Policy Function for Backward
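// (launches a Union1 task sized to core_dim x cluster_num, i.e. every core of
// every cluster)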
static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
// set Union1 Job
*k_type = CNRT_FUNC_TYPE_UNION1;
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
k_dim->z = 1;
}
void CARAFEBackwardMLUKernelLauncher(
const Tensor grad_output, const Tensor rinput, const Tensor mask,
Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input,
Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask,
const int kernel_size, const int group_size, const int scale_factor) {
const int batch_size = rinput.size(0);
const int channels = rinput.size(1);
const int hi = rinput.size(2);
const int wi = rinput.size(3);
// data type check
TORCH_CHECK(grad_output.scalar_type() == at::kFloat ||
grad_output.scalar_type() == at::kHalf,
"grad_output type should be Float or Half, got ",
grad_output.scalar_type());
TORCH_CHECK(grad_output.scalar_type() == mask.scalar_type(),
"mask should have the same type as grad_output");
// dim check
TORCH_CHECK(grad_output.dim() == 4, "grad_output should be a 4d tensor, got ",
grad_output.dim(), "D");
// param check
TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ",
kernel_size);
// set task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncBackward(&k_dim, &k_type);
// convert NCHW to NHWC
auto memory_format_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim());
auto rinput_ =
torch_mlu::cnnl::ops::cnnl_contiguous(rinput, memory_format_input_nhwc);
auto memory_format_mask_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(mask.dim());
auto rmask_ =
torch_mlu::cnnl::ops::cnnl_contiguous(mask, memory_format_mask_nhwc);
auto memory_format_grad_output_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
auto rgrad_output_ = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_output, memory_format_grad_output_nhwc);
auto memory_format_grad_input_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_input.dim());
auto rgrad_input_ = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_input, memory_format_grad_input_nhwc)
.zero_();
auto memory_format_grad_mask_nhwc =
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_mask.dim());
auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_mask, memory_format_grad_mask_nhwc);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
auto input_ptr = input_impl->cnnlMalloc();
auto mask_impl = torch_mlu::getMluTensorImpl(rmask_);
auto mask_ptr = mask_impl->cnnlMalloc();
auto grad_output_impl = torch_mlu::getMluTensorImpl(rgrad_output_);
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
auto grad_input_impl = torch_mlu::getMluTensorImpl(rgrad_input_);
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_);
auto grad_mask_ptr = grad_mask_impl->cnnlMalloc();
// get dtype of grad_output
cnrtDataType_t d_type = torch_mlu::toCnrtDtype(grad_output.dtype());
auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
CNLOG(INFO) << "Launch Kernel KernelCarafeBackward<<<Union"
<< k_type / core_dim << ", " << k_dim.x << ", " << k_dim.y << ", "
<< k_dim.z << ">>>";
// launch kernel
KernelCarafeBackward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr,
grad_output_ptr, grad_input_ptr, grad_mask_ptr,
batch_size, hi, wi, channels, kernel_size, group_size,
scale_factor);
// copy output from NHWC back into NCHW
grad_input.copy_(rgrad_input_);
grad_mask.copy_(rgrad_mask_);
}
void carafe_forward_mlu(Tensor features, Tensor masks, Tensor rfeatures,
Tensor routput, Tensor rmasks, Tensor output,
int kernel_size, int group_size, int scale_factor) {
CARAFEForwardMLUKernelLauncher(features, masks, rfeatures, routput, rmasks,
output, kernel_size, group_size, scale_factor);
}
void carafe_backward_mlu(Tensor top_grad, Tensor rfeatures, Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs,
Tensor rbottom_grad, Tensor rmask_grad,
Tensor bottom_grad, Tensor mask_grad, int kernel_size,
int group_size, int scale_factor) {
CARAFEBackwardMLUKernelLauncher(top_grad, rfeatures, masks, rtop_grad,
rbottom_grad_hs, rbottom_grad, rmask_grad,
bottom_grad, mask_grad, kernel_size,
group_size, scale_factor);
}
void carafe_forward_impl(Tensor features, Tensor masks, Tensor rfeatures,
Tensor routput, Tensor rmasks, Tensor output,
int kernel_size, int group_size, int scale_factor);
void carafe_backward_impl(Tensor top_grad, Tensor rfeatures, Tensor masks,
Tensor rtop_grad, Tensor rbottom_grad_hs,
Tensor rbottom_grad, Tensor rmask_grad,
Tensor bottom_grad, Tensor mask_grad, int kernel_size,
int group_size, int scale_factor);
REGISTER_DEVICE_IMPL(carafe_forward_impl, MLU, carafe_forward_mlu);
REGISTER_DEVICE_IMPL(carafe_backward_impl, MLU, carafe_backward_mlu);
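// These registrations route the generic carafe_forward_impl /
// carafe_backward_impl dispatchers to the MLU launchers above whenever the
// input tensors live on an MLU device.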
(Binary test data added under tests/data/for_carafe/: carafe_feat.bin, carafe_mask.bin, carafe_output.bin, carafe_feat_grad.bin, carafe_mask_grad.bin; raw contents omitted.)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from torch.autograd import gradcheck

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

class TestCarafe:
@@ -26,3 +30,56 @@ class TestCarafe:
2, 100, 6, 6, requires_grad=True,
device='cuda').sigmoid().double()
gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_carafe_allclose(self, device):
try:
from mmcv.ops import CARAFE
except ModuleNotFoundError:
pytest.skip('test requires compilation')
np_feat = np.fromfile(
'tests/data/for_carafe/carafe_feat.bin', dtype=np.float32)
np_mask = np.fromfile(
'tests/data/for_carafe/carafe_mask.bin', dtype=np.float32)
np_output = np.fromfile(
'tests/data/for_carafe/carafe_output.bin', dtype=np.float32)
np_feat_grad = np.fromfile(
'tests/data/for_carafe/carafe_feat_grad.bin', dtype=np.float32)
np_mask_grad = np.fromfile(
'tests/data/for_carafe/carafe_mask_grad.bin', dtype=np.float32)
np_feat = np_feat.reshape((2, 64, 3, 3))
np_mask = np_mask.reshape((2, 100, 6, 6))
np_output = np_output.reshape((2, 64, 6, 6))
np_feat_grad = np_feat_grad.reshape((2, 64, 3, 3))
np_mask_grad = np_mask_grad.reshape((2, 100, 6, 6))
feat = torch.tensor(
np_feat, dtype=torch.float, device=device, requires_grad=True)
mask = torch.tensor(
np_mask, dtype=torch.float, device=device, requires_grad=True)
carafe = CARAFE(5, 4, 2)
output = carafe(feat, mask)
output.backward(torch.ones_like(output))
assert np.allclose(
output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
assert np.allclose(
feat.grad.data.type(torch.float).cpu().numpy(),
np_feat_grad,
atol=1e-3)
assert np.allclose(
mask.grad.data.type(torch.float).cpu().numpy(),
np_mask_grad,
atol=1e-3)