Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
99cb8535
Unverified
Commit
99cb8535
authored
Aug 28, 2023
by
qirun-uiuc
Committed by
GitHub
Aug 28, 2023
Browse files
[Refactor] Replace focal_loss_sigmoid op of MLU backend with mlu-ops (#2855)
parent
ee93530a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
1110 deletions
+67
-1110
mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu
mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu
+0
-888
mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp
mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp
+67
-222
No files found.
mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu
deleted
100644 → 0
View file @
ee93530a
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <float.h>
#include "common_mlu_helper.hpp"
#define PING 0
#define PONG 1
__nram__ char nram_buffer[MAX_NRAM_SIZE];
namespace forward {
template <typename T>
__mlu_func__ void loadInput(char *nram_input, T *dram_input, const int32_t size,
const int32_t dst_stride = 0,
const int32_t src_stride = 0,
const int32_t count = 1) {
if (dst_stride == src_stride) {
__memcpy_async(nram_input, dram_input, size * count, GDRAM2NRAM);
} else {
__memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride,
src_stride, count - 1);
}
}
template <typename T>
__mlu_func__ void loadWeight(char *nram_input, T *dram_input, const int32_t t,
const int32_t c, const int32_t has_weight,
const int32_t partition_nc) {
if (has_weight && partition_nc && t >= 0 && t < c) {
__memcpy_async(nram_input, (T *)dram_input + t, sizeof(T), GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void storeOutput(T *dram_output, char *nram_output,
const int32_t size, const int32_t dst_stride = 0,
const int32_t src_stride = 0,
const int32_t count = 1) {
if (dst_stride == src_stride) {
__memcpy_async(dram_output, nram_output, size * count, NRAM2GDRAM);
} else {
__memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride,
src_stride, count - 1);
}
}
template <typename T>
__mlu_func__ void compute(T *input, const int32_t *target, const T *weight,
const int32_t has_weight, const int32_t partition_nc,
const int32_t deal_num, const int32_t n_seg,
const int32_t c, const int32_t c_seg,
const int32_t c_start_index, const float alpha,
const float gamma, T *compute_a, T *compute_b,
T *output) {
// set params
const int32_t c_num =
has_weight ? PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T)) : c_seg;
const int32_t c_end_index = c_start_index + c_seg;
const int32_t half_epsilon = 0x0400;
const T epsilon_f =
sizeof(T) == sizeof(float) ? FLT_MIN : *((half *)&half_epsilon);
// 0. alpha_t * p_t^r = alpha * (1 - p) ^ gamma if t == c_i
// = (1 - alpha) * p ^ gamma if t != c_i
__nramset((T *)output, deal_num, (T)(1 - alpha));
__bang_active_sigmoid((T *)compute_b, (T *)input, deal_num);
for (int32_t i = 0; i < n_seg; ++i) {
const int32_t t = *((uint32_t *)target + i);
if (t >= c_start_index && t < c_end_index) {
const uint32_t index = i * c_num + t - c_start_index;
*((T *)input + index) = -1.0 * (*((T *)input + index));
*((T *)compute_b + index) = 1.0 - (*((T *)compute_b + index)) + epsilon_f;
*((T *)output + index) = alpha;
}
}
if (sizeof(T) == sizeof(half)) {
__bang_half2float((float *)compute_a, (half *)compute_b, deal_num);
__bang_active_loghp((float *)compute_a, (float *)compute_a, deal_num);
__bang_mul_const((float *)compute_a, (float *)compute_a, (float)gamma,
deal_num);
__bang_active_exphp((float *)compute_a, (float *)compute_a, deal_num);
__bang_float2half_rd((half *)compute_a, (float *)compute_a, deal_num);
} else {
__bang_active_loghp((T *)compute_a, (T *)compute_b, deal_num);
__bang_mul_const((T *)compute_a, (T *)compute_a, (T)gamma, deal_num);
__bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num);
}
__bang_mul((T *)output, (T *)compute_a, (T *)output, deal_num);
// 1. max = max(0, -x) if t == c_i
// = max(0, x) if t != c_i
__nramset((T *)compute_b, deal_num, (T)0);
__bang_maxequal((T *)compute_b, (T *)compute_b, (T *)input, deal_num);
// 2. -log(p_t) = ln(e^(-max)+ e^(-max-x) + max if t == c_i
// = ln(e^(-max)+ e^(-max+x) + max if t != c_i
__bang_mul_const((T *)compute_a, (T *)compute_b, (T)-1.0, deal_num);
__bang_add((T *)input, (T *)compute_a, (T *)input, deal_num);
__bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num);
__bang_active_exphp((T *)input, (T *)input, deal_num);
__bang_add((T *)compute_a, (T *)compute_a, (T *)input, deal_num);
__bang_active_loghp((T *)compute_a, (T *)compute_a, deal_num);
__bang_add((T *)input, (T *)compute_a, (T *)compute_b, deal_num);
// 3. output = alpha_t * p_t^r * [-log(p_t)]
__bang_mul((T *)output, (T *)output, (T *)input, deal_num);
// 4. with weight
if (has_weight) {
for (int32_t i = 0; i < n_seg; ++i) {
int32_t t = *((int32_t *)target + i);
if (t >= 0 && t < c) {
t = partition_nc ? 0 : t;
__bang_mul_const((T *)output + i * c_num, (T *)output + i * c_num,
*((T *)weight + t), c_num);
}
}
}
}
template <typename T>
__mlu_func__ void startPipeline(
const T *input, const int32_t *target, const T *weight,
char *nram_compute_a, char *nram_compute_b, char *nram_input,
char *nram_target, char *nram_weight, char *nram_output,
const int32_t has_weight, const int32_t partition_nc,
const int32_t pingpong_offset, const int32_t pingpong_weight_offset,
const int32_t c_offset_num, const int32_t n, const int32_t n_seg,
const int32_t c, const int32_t c_seg, const float alpha, const float gamma,
T *output) {
// with offset
input = (T *)((char *)input + c_offset_num * sizeof(T));
output = (T *)((char *)output + c_offset_num * sizeof(T));
const int32_t c_seg_align_num = PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T));
const int32_t c_num = has_weight ? c_seg_align_num : c_seg;
const int32_t deal_num = PAD_UP(n_seg * c_num, NFU_ALIGN_SIZE / sizeof(T));
const int32_t load_size = c_seg * sizeof(T);
const int32_t dram_stride = c * sizeof(T);
const int32_t nram_stride = c_num * sizeof(T);
if (has_weight && !partition_nc) {
loadInput<T>(nram_weight, (T *)weight, load_size, nram_stride, dram_stride,
1);
__asm__ volatile("sync;\n\t");
}
const int32_t repeat = n / n_seg;
const int32_t remain = n % n_seg;
/*
* Pipeline: The pipeline is processed in three stages: Load, Compute, Store.
* The allocated memory space of NRAM is divided into two parts:
* PING and Pong. In a single time slice, PING is used to process
* IO stream and PONG is used for computation. Both of them are
* processed synchronously until finished.
*
* diagram of PINGPONG:
* |------|-----------------------------------------------------------------|
* | | space |
* |------|-----------------------------------------------------------------|
* | time | Ping | Pong | Ping | Pong | Ping | Pong |
* |------|-----------------------------------------------------------------|
* | 0 | L0 | | | | | |
* | 1 | C0 | L1 | | | | |
* | 2 | S0 | C1 | L2 | | | |
* | 3 | | S1 | C2 | L3 | | |
* | 4 | | | S2 | C3 | L4 | |
* | 5 | | | | S3 | C4 | L5 |
* | 6 | | | | | S4 | C5 |
* | 7 | | | | | | S5 |
* |------|-----------------------------------------------------------------|
*/
// diagram of PINGPONG: L0
if (repeat > 0) {
loadInput<T>(nram_input, (T *)input, load_size, nram_stride, dram_stride,
n_seg);
loadInput<int32_t>(nram_target, (int32_t *)target, n_seg * sizeof(int32_t));
loadWeight<T>(nram_weight, (T *)weight, *((int32_t *)target), c, has_weight,
partition_nc);
__asm__ volatile("sync;\n\t");
}
// diagram of PINGPONG: C0 and L1
if (repeat > 1) {
compute((T *)nram_input, (int32_t *)nram_target, (T *)nram_weight,
has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)nram_output);
loadInput<T>((char *)nram_input + pingpong_offset, (T *)input + c * n_seg,
load_size, nram_stride, dram_stride, n_seg);
loadInput<int32_t>((char *)nram_target + pingpong_offset,
(int32_t *)target + n_seg, n_seg * sizeof(int32_t));
loadWeight<T>((char *)nram_weight + pingpong_weight_offset, (T *)weight,
*((int32_t *)target + n_seg), c, has_weight, partition_nc);
__asm__ volatile("sync;\n\t");
}
for (int32_t i = 0; i < repeat - 2; ++i) {
storeOutput<T>((T *)output + i * c * n_seg,
nram_output + (i % 2) * pingpong_offset, load_size,
dram_stride, nram_stride, n_seg);
loadInput<T>((char *)nram_input + (i % 2) * pingpong_offset,
(T *)(input) + (i + 2) * c * n_seg, load_size, nram_stride,
dram_stride, n_seg);
loadInput<int32_t>((char *)nram_target + (i % 2) * pingpong_offset,
(int32_t *)target + (i + 2) * n_seg,
n_seg * sizeof(int32_t));
loadWeight<T>((char *)nram_weight + (i % 2) * pingpong_weight_offset,
(T *)weight, *((int32_t *)target + (i + 2) * n_seg), c,
has_weight, partition_nc);
compute((T *)(nram_input + ((i + 1) % 2) * pingpong_offset),
(int32_t *)(nram_target + ((i + 1) % 2) * pingpong_offset),
(T *)(nram_weight +
partition_nc * ((i + 1) % 2) * pingpong_weight_offset),
has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)(nram_output + ((i + 1) % 2) * pingpong_offset));
__asm__ volatile("sync;\n\t");
}
if (repeat > 1) {
storeOutput<T>((T *)output + (repeat - 2) * c * n_seg,
(char *)nram_output + (repeat % 2) * pingpong_offset,
load_size, dram_stride, nram_stride, n_seg);
}
if (remain > 0) {
loadInput<T>((char *)nram_input + (repeat % 2) * pingpong_offset,
(T *)input + repeat * c * n_seg, load_size, nram_stride,
dram_stride, remain);
loadInput<int32_t>((char *)nram_target + (repeat % 2) * pingpong_offset,
(int32_t *)target + repeat * n_seg,
remain * sizeof(int32_t));
loadWeight<T>((char *)nram_weight + (repeat % 2) * pingpong_weight_offset,
(T *)weight, *((int32_t *)target + repeat * n_seg), c,
has_weight, partition_nc);
}
if (repeat > 0) {
compute((T *)(nram_input + ((repeat - 1) % 2) * pingpong_offset),
(int32_t *)(nram_target + ((repeat - 1) % 2) * pingpong_offset),
(T *)(nram_weight +
partition_nc * ((repeat - 1) % 2) * pingpong_weight_offset),
has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)(nram_output + ((repeat - 1) % 2) * pingpong_offset));
}
__asm__ volatile("sync;\n\t");
if (repeat > 0) {
storeOutput<T>((T *)output + (repeat - 1) * c * n_seg,
(char *)nram_output + ((repeat - 1) % 2) * pingpong_offset,
load_size, dram_stride, nram_stride, n_seg);
}
if (remain > 0) {
int32_t rem_num = PAD_UP(remain * c_num, NFU_ALIGN_SIZE / sizeof(T));
compute((T *)(nram_input + (repeat % 2) * pingpong_offset),
(int32_t *)(nram_target + (repeat % 2) * pingpong_offset),
(T *)(nram_weight +
partition_nc * (repeat % 2) * pingpong_weight_offset),
has_weight, partition_nc, rem_num, remain, c, c_seg, c_offset_num,
alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b,
(T *)(nram_output + (repeat % 2) * pingpong_offset));
__asm__ volatile("sync;\n\t");
storeOutput<T>((T *)output + repeat * c * n_seg,
(char *)nram_output + (repeat % 2) * pingpong_offset,
load_size, dram_stride, nram_stride, remain);
}
__asm__ volatile("sync;\n\t");
}
template <typename T>
__mlu_func__ void focalLossSigmoidForwardBlock(
const T *input, const int32_t *target, const T *weight, const int32_t n,
const int32_t c, const float alpha, const float gamma, T *output) {
/*
* NRAM partition
* |-----------------------------------------------------------------------|
* | weight |
* |------------------------------- COMPUTE -------------------------------|
* | | |
* | computeA | computeB |
* | | |
* |------------- PING ------------------------------- PONG ---------------|
* | | |
* | input | input |
* | | |
* |-----------------------------------|-----------------------------------|
* | | |
* | output | output |
* | | |
* |-----------------------------------|-----------------------------------|
* | target | target |
* |-----------------------------------|-----------------------------------|
*
* split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output),
* PONG(input,output).
* split_target_num is 2: PING(target), PONG(target).
* weight is not NULL:
* The nram-size of weight is equal to c_align_size when partition input-N.
* The nram-size of weight is equal to NFU_ALIGN_SIZE when partition
* input-NC.
*/
// calculate threshold of c
const int32_t split_pipeline_num = 6;
const int32_t split_target_num = 2;
const int32_t has_weight = weight != NULL;
const int32_t threshold_c =
PAD_DOWN((MAX_NRAM_SIZE - split_target_num * sizeof(int32_t)) /
(split_pipeline_num + has_weight),
NFU_ALIGN_SIZE) /
sizeof(T);
const int32_t c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T));
const int32_t c_align_size = c_align * sizeof(T);
if (c <= threshold_c) {
// partition inputN
int32_t c_num = c;
int32_t reservered_align_size =
(split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE;
int32_t weight_size = 0;
if (has_weight) {
c_num = c_align;
reservered_align_size = split_target_num * NFU_ALIGN_SIZE;
weight_size = c_align_size;
}
const int32_t remain_size =
MAX_NRAM_SIZE - weight_size - reservered_align_size;
const int32_t n_seg =
remain_size / (split_pipeline_num * c_num * sizeof(T) +
split_target_num * sizeof(int32_t));
const int32_t split_pipeline_size =
PAD_UP(c_num * n_seg * sizeof(T), NFU_ALIGN_SIZE);
const int32_t compute_size = 2 * split_pipeline_size;
const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2;
char *nram_weight = (char *)nram_buffer;
char *nram_compute_a = nram_weight + has_weight * c_align_size;
char *nram_compute_b = nram_compute_a + split_pipeline_size;
char *nram_input = nram_compute_b + split_pipeline_size;
char *nram_output = nram_input + split_pipeline_size;
char *nram_target = nram_output + split_pipeline_size;
startPipeline<T>(input, target, weight, nram_compute_a, nram_compute_b,
nram_input, nram_target, nram_weight, nram_output,
has_weight, 0, pingpong_offset, 0, 0, n, n_seg, c, c,
alpha, gamma, output);
} else {
// partition inputNC
const int32_t weight_size = has_weight * NFU_ALIGN_SIZE;
const int32_t remain_size = MAX_NRAM_SIZE - weight_size;
const int32_t split_pipeline_size = PAD_DOWN(
(remain_size - split_target_num * NFU_ALIGN_SIZE) / split_pipeline_num,
NFU_ALIGN_SIZE);
const int32_t c_seg = split_pipeline_size / sizeof(T);
const int32_t n_seg = 1;
const int32_t compute_size = 2 * split_pipeline_size;
const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2;
const int32_t pingpong_weight_offset = weight_size / 2;
char *nram_weight = (char *)nram_buffer;
char *nram_compute_a = nram_weight + weight_size;
char *nram_compute_b = nram_compute_a + split_pipeline_size;
char *nram_input = nram_compute_b + split_pipeline_size;
char *nram_output = nram_input + split_pipeline_size;
char *nram_target = nram_output + split_pipeline_size;
const int32_t loop_num = (c + c_seg - 1) / c_seg;
const int32_t partition_nc = 1;
for (int32_t i = 0; i < loop_num; ++i) {
const int32_t c_index = i * c_seg;
const int32_t c_seg_curr = i == (loop_num - 1) ? c - c_index : c_seg;
startPipeline<T>(input, target, weight, nram_compute_a, nram_compute_b,
nram_input, nram_target, nram_weight, nram_output,
has_weight, partition_nc, pingpong_offset,
pingpong_weight_offset, c_index, n, n_seg, c, c_seg_curr,
alpha, gamma, output);
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelFocalLossSigmoidForward(
const void *input, const void *target, const void *weight, const int32_t N,
const int32_t C, const float alpha, const float gamma, void *output) {
const int32_t n_seg = N / taskDim + (taskId == taskDim - 1) * (N % taskDim);
const T *input_offset = (T *)input + N / taskDim * taskId * C;
const int32_t *target_offset = (int32_t *)target + N / taskDim * taskId;
T *output_offset = (T *)output + N / taskDim * taskId * C;
focalLossSigmoidForwardBlock((T *)input_offset, (int32_t *)target_offset,
(T *)weight, n_seg, C, alpha, gamma,
(T *)output_offset);
}
} // namespace forward
namespace backward {
template <typename T>
__mlu_func__ void loadInput(char *nram_input, char *nram_target,
const T *gdram_input, const int32_t *gdram_target,
const int32_t deal_n, const int32_t total_c,
const bool pingping_flag, const bool has_weight,
const int32_t nram_offset,
const int32_t gdram_offset) {
if (pingping_flag == PONG) {
nram_input += nram_offset;
nram_target += nram_offset;
}
__memcpy_async(nram_target, gdram_target + gdram_offset / total_c,
deal_n * sizeof(int32_t), GDRAM2NRAM);
char *nram_input_load = nram_input;
int32_t compute_align_size = 2 * NFU_ALIGN_SIZE;
if (has_weight) {
if (sizeof(T) == sizeof(half)) {
int32_t compute_align_num = compute_align_size / sizeof(float);
int32_t align_c = PAD_UP(total_c, compute_align_num);
int32_t compute_size = deal_n * align_c * sizeof(float);
nram_input_load += compute_size / 2;
}
int32_t align_c = PAD_UP(total_c, NFU_ALIGN_SIZE / sizeof(T));
int32_t total_c_size = total_c * sizeof(T);
int32_t align_c_size = align_c * sizeof(T);
__memcpy_async(nram_input_load, gdram_input + gdram_offset, total_c_size,
GDRAM2NRAM, align_c_size, total_c_size, deal_n - 1);
} else {
if (sizeof(T) == sizeof(half)) {
int32_t compute_size =
PAD_UP(deal_n * total_c * sizeof(float), compute_align_size);
nram_input_load += compute_size / 2;
}
int32_t load_size = deal_n * total_c * sizeof(T);
__memcpy_async(nram_input_load, gdram_input + gdram_offset, load_size,
GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void sigmoid(T *dst_data, const T *src_data,
const int32_t elem_count) {
__bang_mul_const(dst_data, (T *)src_data, T(-1), elem_count);
__bang_active_exphp(dst_data, dst_data, elem_count);
__bang_add_const(dst_data, dst_data, T(1), elem_count);
__bang_active_reciphp(dst_data, dst_data, elem_count);
}
template <typename T>
__mlu_func__ void coreCompute(char *nram_input, const T *nram_weight,
const float *nram_flt_min, char *nram_pt,
char *nram_alpha_t, char *nram_temp,
char *nram_target, const float *nram_gamma,
char *nram_output, const float alpha,
const int32_t compute_num, const int32_t deal_n,
const int32_t total_c, const bool pingpong_flag,
const int32_t nram_offset,
const bool has_weight) {
if (pingpong_flag == PONG) {
nram_input += nram_offset;
nram_pt += nram_offset;
nram_alpha_t += nram_offset;
nram_temp += nram_offset;
nram_output += nram_offset;
nram_target += nram_offset;
}
if (sizeof(T) == sizeof(half)) {
const int32_t compute_size = compute_num * sizeof(float);
char *nram_input_load = nram_input + compute_size / 2;
__bang_half2float((float *)nram_input, (half *)nram_input_load,
compute_num);
}
// 0. alpha_t = alpha - 1
__nramset((float *)nram_alpha_t, compute_num, (float)(alpha - 1.0));
// 1. pt = 1 - sigmoid(x)
sigmoid((float *)nram_pt, (float *)nram_input, compute_num);
__bang_mul_const((float *)nram_pt, (float *)nram_pt, (float)(-1),
compute_num);
__bang_add_const((float *)nram_pt, (float *)nram_pt, (float)1, compute_num);
// 2. pt = target[n] == c ? sigmoid(x) : 1 - sigmoid(x)
// alpha_t = target[n] == c ? alpha : alpha - 1
const int32_t nfu_align_num = NFU_ALIGN_SIZE / sizeof(float);
for (int n = 0; n < deal_n; n++) {
const int32_t target_value = ((int32_t *)nram_target)[n];
if (target_value >= total_c || target_value < 0) continue;
int32_t c_offset = 0;
if (has_weight) {
int32_t c_align_num = nfu_align_num;
if (sizeof(T) == sizeof(half)) {
c_align_num += nfu_align_num;
}
c_offset = PAD_UP(total_c, c_align_num);
} else {
c_offset = total_c;
}
int32_t idx = n * c_offset + target_value;
*((float *)nram_pt + idx) = 1.0 - *((float *)nram_pt + idx);
*((float *)nram_alpha_t + idx) = alpha;
}
// 3. temp = -alpha_t * e^(gamma * log(max(1 - pt, FLT_MIN))
__bang_mul_const((float *)nram_temp, (float *)nram_pt, (float)(-1),
compute_num);
__bang_add_const((float *)nram_temp, (float *)nram_temp, (float)(1),
compute_num);
__bang_cycle_maxequal((float *)nram_temp, (float *)nram_temp,
(float *)nram_flt_min, compute_num, nfu_align_num);
__bang_active_loghp((float *)nram_temp, (float *)nram_temp, compute_num);
__bang_cycle_mul((float *)nram_temp, (float *)nram_temp, (float *)nram_gamma,
compute_num, nfu_align_num);
__bang_active_exphp((float *)nram_temp, (float *)nram_temp, compute_num);
__bang_mul((float *)nram_temp, (float *)nram_temp, (float *)nram_alpha_t,
compute_num);
__bang_mul_const((float *)nram_temp, (float *)nram_temp, (float)(-1),
compute_num);
// 4. output = 1 - pt - gamma * pt * log(max(pt, FLT_MIN))
__bang_cycle_maxequal((float *)nram_output, (float *)nram_pt,
(float *)nram_flt_min, compute_num, nfu_align_num);
__bang_active_loghp((float *)nram_output, (float *)nram_output, compute_num);
__bang_mul((float *)nram_output, (float *)nram_output, (float *)nram_pt,
compute_num);
__bang_cycle_mul((float *)nram_output, (float *)nram_output,
(float *)nram_gamma, compute_num, nfu_align_num);
__bang_add((float *)nram_output, (float *)nram_output, (float *)nram_pt,
compute_num);
__bang_mul_const((float *)nram_output, (float *)nram_output, (float)(-1),
compute_num);
__bang_add_const((float *)nram_output, (float *)nram_output, (float)(1),
compute_num);
// 5. output = output * temp
__bang_mul((float *)nram_output, (float *)nram_output, (float *)nram_temp,
compute_num);
if (sizeof(T) == sizeof(half)) {
__bang_float2half_rd((half *)nram_output, (float *)nram_output,
compute_num);
}
if (has_weight) {
// with weight
for (int n = 0; n < deal_n; n++) {
int32_t c_align_num = nfu_align_num;
if (sizeof(T) == sizeof(half)) {
c_align_num += nfu_align_num;
}
int32_t align_c = PAD_UP(total_c, c_align_num);
int32_t target_value = ((int32_t *)nram_target)[n];
T weight_value = nram_weight[target_value];
__bang_mul_const((T *)nram_output + n * align_c,
(T *)nram_output + n * align_c, weight_value, align_c);
}
}
}
template <typename T>
__mlu_func__ void storeOutput(T *gdram_output, const char *nram_output,
const int32_t deal_n, const int32_t total_c,
const bool pingpong_flag, const bool has_weight,
const int32_t nram_offset,
const int32_t gdram_offset) {
if (pingpong_flag == PONG) {
nram_output += nram_offset;
}
const int32_t store_size = deal_n * total_c * sizeof(T);
if (has_weight) {
int32_t align_c = PAD_UP(total_c, NFU_ALIGN_SIZE / sizeof(T));
int32_t total_c_size = total_c * sizeof(T);
int32_t align_c_size = align_c * sizeof(T);
__memcpy_async(gdram_output + gdram_offset, nram_output, total_c_size,
NRAM2GDRAM, total_c_size, align_c_size, deal_n - 1);
} else {
__memcpy_async(gdram_output + gdram_offset, nram_output, store_size,
NRAM2GDRAM);
}
}
template <typename T>
__mlu_func__ void focalLossSigmoidBackwardBlock(
const T *input, const int32_t *target, const T *weight, const float gamma,
const float alpha, const int32_t total_n, const int32_t deal_n,
const int32_t total_c, T *output) {
// params per time slice
int32_t deal_num = deal_n * total_c;
int32_t deal_size = deal_num * sizeof(float);
int32_t compute_num = 0;
int32_t compute_size = 0;
int32_t compute_align_size = NFU_ALIGN_SIZE;
const int32_t nfu_align_num = NFU_ALIGN_SIZE / sizeof(T);
if (sizeof(T) == sizeof(half)) {
compute_align_size += NFU_ALIGN_SIZE;
}
const int32_t compute_align_num = compute_align_size / sizeof(float);
bool has_weight = false;
if (weight != NULL) {
has_weight = true;
int32_t align_c = PAD_UP(total_c, compute_align_num);
compute_num = deal_n * align_c;
compute_size = compute_num * sizeof(float);
} else {
compute_size = PAD_UP(deal_size, compute_align_size);
compute_num = compute_size / sizeof(float);
}
// params per core
int32_t total_num = total_n * total_c;
int32_t num_per_core = PAD_DOWN(total_num / taskDim, deal_num);
int32_t loop_per_core = num_per_core / deal_num;
/* NRAM partition:
*
* |-----------------ping pong--------------------|
* |input | pt | alpha_t | temp | output | target | flt_min | gamma | weight|
*
* split_pipeline_num is 5: input, pt, alpha_t, temp, output.
* nram_reserved_line_num is 2: flt_min, gamma.
*/
const int32_t split_pipeline_num = 5;
const int32_t nram_reserved_line_num = 2;
int32_t target_deal_size = deal_n * sizeof(int32_t);
int32_t target_deal_size_align = PAD_UP(target_deal_size, NFU_ALIGN_SIZE);
// nram PING/PONG offset
int32_t ping_pong_offset =
compute_size * split_pipeline_num + target_deal_size_align;
// gdram addr
int32_t *base_addr_target =
(int32_t *)target + taskId * loop_per_core * deal_n;
T *base_addr_input = (T *)input + taskId * num_per_core;
T *base_addr_output = output + taskId * num_per_core;
// nram addr
char *nram_input = (char *)nram_buffer;
char *nram_pt = nram_input + compute_size;
char *nram_alpha_t = nram_pt + compute_size;
char *nram_temp = nram_alpha_t + compute_size;
char *nram_output = nram_temp + compute_size;
char *nram_target = nram_output + compute_size;
float *nram_flt_min = NULL;
float *nram_gamma = NULL;
T *nram_weight = NULL;
if (!has_weight) {
nram_flt_min = (float *)(nram_buffer + MAX_NRAM_SIZE -
nram_reserved_line_num * NFU_ALIGN_SIZE);
nram_gamma = nram_flt_min + nfu_align_num;
} else {
int32_t weight_space = PAD_UP(total_c * sizeof(T), NFU_ALIGN_SIZE);
nram_flt_min =
(float *)(nram_buffer + MAX_NRAM_SIZE -
nram_reserved_line_num * NFU_ALIGN_SIZE - weight_space);
nram_gamma = nram_flt_min + nfu_align_num;
nram_weight = (T *)(nram_gamma + nfu_align_num);
__memcpy_async(nram_weight, weight, total_c * sizeof(T), GDRAM2NRAM);
}
// nram set gamma and FLT_MIN
__nramset(nram_gamma, nfu_align_num, gamma);
__nramset(nram_flt_min, nfu_align_num, FLT_MIN);
/*
* Pipeline: The pipeline is processed in three stages: Load, Compute, Store.
* The allocated memory space of NRAM is divided into two parts:
* PING and Pong. In a single time slice, PING is used to process
* IO stream and PONG is used for computation. Both of them are
* processed synchronously until finished.
*
* diagram of PINGPONG:
* |------|-----------------------------------------------------------------|
* | | space |
* |------|-----------------------------------------------------------------|
* | time | Ping | Pong | Ping | Pong | Ping | Pong |
* |------|-----------------------------------------------------------------|
* | 0 | L0 | | | | | |
* | 1 | C0 | L1 | | | | |
* | 2 | S0 | C1 | L2 | | | |
* | 3 | | S1 | C2 | L3 | | |
* | 4 | | | S2 | C3 | L4 | |
* | 5 | | | | S3 | C4 | L5 |
* | 6 | | | | | S4 | C5 |
* | 7 | | | | | | S5 |
* |------|-----------------------------------------------------------------|
*/
// diagram of PINGPONG: L0
if (loop_per_core > 0) {
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PING, has_weight, ping_pong_offset, 0);
__asm__ volatile("sync;");
}
// diagram of PINGPONG: C0 and L1
if (loop_per_core > 1) {
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PONG, has_weight, ping_pong_offset, deal_num);
__asm__ volatile("sync;");
}
for (int i = 0; i < loop_per_core - 2; ++i) {
if (i % 2 == PING) {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PING,
has_weight, ping_pong_offset, i * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PONG, ping_pong_offset,
has_weight);
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PING, has_weight, ping_pong_offset,
(i + 2) * deal_num);
} else {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG,
has_weight, ping_pong_offset, i * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
deal_n, total_c, PONG, has_weight, ping_pong_offset,
(i + 2) * deal_num);
}
__asm__ volatile("sync;");
}
if (loop_per_core > 1) {
if ((loop_per_core - 2) % 2 == PING) {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PING,
has_weight, ping_pong_offset, (loop_per_core - 2) * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PONG, ping_pong_offset,
has_weight);
} else {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG,
has_weight, ping_pong_offset, (loop_per_core - 2) * deal_num);
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
}
__asm__ volatile("sync;");
}
if (loop_per_core > 0) {
if (loop_per_core == 1) {
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
compute_num, deal_n, total_c, PING, ping_pong_offset,
has_weight);
__asm__ volatile("sync;");
}
if ((loop_per_core - 1) % 2 == PING) {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PING,
has_weight, ping_pong_offset, (loop_per_core - 1) * deal_num);
} else {
storeOutput(base_addr_output, nram_output, deal_n, total_c, PONG,
has_weight, ping_pong_offset, (loop_per_core - 1) * deal_num);
}
}
// process the remaining data which N remainder per core is less than deal_n
int32_t rem_for_all = total_num - num_per_core * taskDim;
if (rem_for_all == 0) return;
int32_t rem_n_for_all = rem_for_all / total_c;
int32_t rem_n_per_core = (rem_n_for_all + taskDim - 1) / taskDim;
int32_t rem_num_per_core = rem_n_per_core * total_c;
int32_t rem_num_per_core_align = 0;
int32_t rem_core_num = rem_for_all / rem_num_per_core;
int32_t rem_n_for_last = rem_n_for_all % rem_n_per_core;
int32_t rem_num_for_last = rem_n_for_last * total_c;
int32_t rem_num_for_last_align = 0;
if (has_weight) {
int32_t align_c = PAD_UP(total_c, compute_align_num);
rem_num_per_core_align = rem_n_per_core * align_c;
rem_num_for_last_align = rem_n_for_last * align_c;
} else {
rem_num_per_core_align = PAD_UP(rem_num_per_core, compute_align_num);
rem_num_for_last_align = PAD_UP(rem_num_for_last, compute_align_num);
}
int32_t rem_addr_base = num_per_core * taskDim;
int32_t rem_target_addr_base = loop_per_core * deal_n * taskDim;
base_addr_target = (int32_t *)target + rem_target_addr_base;
base_addr_input = (T *)input + rem_addr_base;
base_addr_output = output + rem_addr_base;
if (taskId < rem_core_num) {
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
rem_n_per_core, total_c, PING, has_weight, ping_pong_offset,
taskId * rem_num_per_core);
__asm__ volatile("sync;");
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
rem_num_per_core_align, rem_n_per_core, total_c, PING,
ping_pong_offset, has_weight);
__asm__ volatile("sync;");
storeOutput(base_addr_output, nram_output, rem_n_per_core, total_c, PING,
has_weight, ping_pong_offset, taskId * rem_num_per_core);
} else if (taskId == rem_core_num) {
if (rem_num_for_last == 0) return;
loadInput(nram_input, nram_target, base_addr_input, base_addr_target,
rem_n_for_last, total_c, PING, has_weight, ping_pong_offset,
taskId * rem_num_per_core);
__asm__ volatile("sync;");
coreCompute(nram_input, nram_weight, nram_flt_min, nram_pt, nram_alpha_t,
nram_temp, nram_target, nram_gamma, nram_output, alpha,
rem_num_for_last_align, rem_n_for_last, total_c, PING,
ping_pong_offset, has_weight);
__asm__ volatile("sync;");
storeOutput(base_addr_output, nram_output, rem_n_for_last, total_c, PING,
has_weight, ping_pong_offset, taskId * rem_num_per_core);
} else {
return;
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelFocalLossSigmoidBackward(
const void *input, const void *target, const void *weight,
const float gamma, const float alpha, const int32_t total_n,
const int32_t deal_n, const int32_t total_c, void *output) {
focalLossSigmoidBackwardBlock((T *)input, (int32_t *)target, (T *)weight,
gamma, alpha, total_n, deal_n, total_c,
(T *)output);
}
} // namespace backward
void KernelFocalLossSigmoidForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue,
const cnrtDataType_t d_type,
const void *input, const void *target,
const void *weight, const int32_t N,
const int32_t C, const float alpha,
const float gamma, void *output) {
if (d_type == CNRT_FLOAT16) {
forward::MLUUnion1KernelFocalLossSigmoidForward<
half><<<k_dim, k_type, queue>>>(input, target, weight, N, C, alpha,
gamma, output);
} else {
forward::MLUUnion1KernelFocalLossSigmoidForward<
float><<<k_dim, k_type, queue>>>(input, target, weight, N, C, alpha,
gamma, output);
}
}
void KernelFocalLossSigmoidBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue,
const cnrtDataType_t d_type,
const void *input, const void *target,
const void *weight, const float gamma,
const float alpha, const int32_t dim_n,
const int32_t deal_n, const int32_t dim_c,
void *output) {
if (d_type == CNRT_FLOAT16) {
backward::MLUUnion1KernelFocalLossSigmoidBackward<
half><<<k_dim, k_type, queue>>>(input, target, weight, gamma, alpha,
dim_n, deal_n, dim_c, output);
} else {
backward::MLUUnion1KernelFocalLossSigmoidBackward<
float><<<k_dim, k_type, queue>>>(input, target, weight, gamma, alpha,
dim_n, deal_n, dim_c, output);
}
}
mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp
View file @
99cb8535
...
...
@@ -12,87 +12,11 @@
#include <string>
#include <vector>
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#include "mlu_common_helper.h"
void
KernelFocalLossSigmoidForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
void
*
input
,
const
void
*
target
,
const
void
*
weight
,
const
int32_t
N
,
const
int32_t
C
,
const
float
alpha
,
const
float
gamma
,
void
*
output
);
void
KernelFocalLossSigmoidBackward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
void
*
input
,
const
void
*
target
,
const
void
*
weight
,
const
float
gamma
,
const
float
alpha
,
const
int32_t
dim_n
,
const
int32_t
deal_n
,
const
int32_t
dim_c
,
void
*
output
);
// Policy Function for Forward
static
void
policyFuncForward
(
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
,
const
Tensor
&
input
,
const
Tensor
&
target
,
const
Tensor
&
weight
)
{
auto
N
=
input
.
size
(
0
);
auto
C
=
input
.
size
(
1
);
const
size_t
nram_size
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrNramSizePerMcore
);
const
size_t
c_align_size
=
PAD_UP
((
C
*
input
.
itemsize
()),
NFU_ALIGN_SIZE
);
const
int
split_target_num
=
2
;
const
int
split_pipeline_num
=
6
;
const
int
has_weight
=
weight
.
data_ptr
()
!=
nullptr
;
const
int
target_data_width
=
target
.
scalar_type
()
==
at
::
kLong
?
target
.
itemsize
()
/
2
:
target
.
itemsize
();
const
int
threshold_c
=
PAD_DOWN
((
nram_size
-
split_target_num
*
sizeof
(
int
))
/
(
split_pipeline_num
+
has_weight
),
NFU_ALIGN_SIZE
)
/
input
.
itemsize
();
int
n_seg
=
1
;
if
(
C
<=
threshold_c
)
{
int
c_size
=
C
*
input
.
itemsize
();
int
reservered_align_size
=
(
split_target_num
+
split_pipeline_num
)
*
NFU_ALIGN_SIZE
;
int
wegiht_size
=
0
;
if
(
has_weight
)
{
c_size
=
c_align_size
;
reservered_align_size
=
split_target_num
*
NFU_ALIGN_SIZE
;
wegiht_size
=
c_align_size
;
}
// n_seg * c_size * split_pipeline_num + n_seg * target.itemsize() *
// split_target_num
// + weight_size + reservered_align_size <= nram_size
n_seg
=
(
nram_size
-
wegiht_size
-
reservered_align_size
)
/
(
split_pipeline_num
*
c_size
+
split_target_num
*
sizeof
(
int32_t
));
}
auto
seg_num
=
n_seg
==
0
?
N
:
(
N
+
n_seg
-
1
)
/
n_seg
;
auto
core_dim
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
auto
cluster_num
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
auto
core_num
=
core_dim
*
cluster_num
;
k_dim
->
x
=
*
k_type
;
k_dim
->
y
=
seg_num
>
core_num
?
cluster_num
:
(
seg_num
+
core_dim
-
1
)
/
core_dim
;
k_dim
->
z
=
1
;
}
// Policy Function for Backward
static
void
policyFuncBackward
(
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
)
{
// set Union1 Job
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
k_dim
->
x
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
k_dim
->
y
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
k_dim
->
z
=
1
;
}
void
SigmoidFocalLossForwardMLUKernelLauncher
(
Tensor
input
,
Tensor
target
,
void
sigmoid_focal_loss_forward_mlu
(
Tensor
input
,
Tensor
target
,
Tensor
weight
,
Tensor
output
,
const
float
gamma
,
const
float
alpha
)
{
const
float
gamma
,
const
float
alpha
)
{
// params check
TORCH_CHECK
(
gamma
>=
0
,
"gamma should be greater than or equal to 0. "
,
"But now gamma is "
,
gamma
,
"."
);
...
...
@@ -123,103 +47,50 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target,
return
;
}
// calculate task dimension
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
=
CNRT_FUNC_TYPE_UNION1
;
policyFuncForward
(
&
k_dim
,
&
k_type
,
input
,
target
,
weight
);
auto
core_dim
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
// contiguous
auto
input_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
input
.
suggest_memory_format
());
// target only support in32
auto
target_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
target
.
toType
(
at
::
kInt
),
target
.
suggest_memory_format
());
auto
weight_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
weight
,
weight
.
suggest_memory_format
());
auto
output_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
output
,
output
.
suggest_memory_format
());
// set tensor descriptor
MluOpTensorDescriptor
input_desc
,
target_desc
,
weight_desc
,
output_desc
;
input_desc
.
set
(
input_contiguous
);
target_desc
.
set
(
target_contiguous
);
weight_desc
.
set
(
weight_contiguous
);
output_desc
.
set
(
output_contiguous
);
// get ptr of tensors
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input
);
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input
_contiguous
);
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
target_impl
=
torch_mlu
::
getMluTensorImpl
(
target
);
auto
target_impl
=
torch_mlu
::
getMluTensorImpl
(
target
_contiguous
);
auto
target_ptr
=
target_impl
->
cnnlMalloc
();
auto
weight_impl
=
torch_mlu
::
getMluTensorImpl
(
weight
);
auto
weight_impl
=
torch_mlu
::
getMluTensorImpl
(
weight
_contiguous
);
auto
weight_ptr
=
weight_impl
->
cnnlMalloc
();
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output
);
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output
_contiguous
);
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
// get dtype of input
cnrtDataType_t
d_type
=
torch_mlu
::
toCnrtDtype
(
input
.
dtype
());
CNLOG
(
INFO
)
<<
"Launch Kernel KernelFocalLossSigmoidForward<<<Union"
<<
k_type
/
core_dim
<<
", "
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
// launch kernel
KernelFocalLossSigmoidForward
(
k_dim
,
k_type
,
queue
,
d_type
,
input_ptr
,
target_ptr
,
weight_ptr
,
input
.
size
(
0
),
input
.
size
(
1
),
alpha
,
gamma
,
output_ptr
);
}
void
getDealNAndThresholdC
(
const
int
compute_data_bytes
,
const
int
target_data_bytes
,
const
int
total_c
,
int
*
deal_n_ptr
,
int
*
threshold_c_ptr
,
const
bool
has_weight
,
const
bool
is_half
)
{
/* NRAM partition:
*
* |-----------------ping pong--------------------|
* |input | pt | alpha_t | temp | output | target | flt_min | gamma | weight|
*
* split_pipeline_num is 5: including input, pt, alpha_t, temp, output.
*/
const
int
nram_split_num
=
5
;
const
int
nram_split_pingpong
=
2
;
const
int
max_nram_size
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrNramSizePerMcore
);
int32_t
compute_align_size
=
NFU_ALIGN_SIZE
;
if
(
is_half
)
{
compute_align_size
+=
NFU_ALIGN_SIZE
;
}
const
int32_t
compute_align_num
=
compute_align_size
/
compute_data_bytes
;
// reservered_align_size: including input(ping pong), pt(ping pong),
// alpha_t(ping pong), temp(ping pong),
// output(ping pong), target(ping pong),
// flt_min and gamma.
const
int
reservered_align_size
=
((
nram_split_num
+
1
)
*
nram_split_pingpong
+
2
)
*
compute_align_size
;
int
nram_pingpong_size
=
max_nram_size
-
reservered_align_size
;
int
compute_c
=
total_c
;
int
threshold_c
=
0
;
if
(
has_weight
)
{
// reserved space for weight to align
nram_pingpong_size
-=
NFU_ALIGN_SIZE
;
// set prefer computation performance and redcuntion approach
mluOpComputationPreference_t
prefer
=
MLUOP_COMPUTATION_FAST
;
mluOpLossReduction_t
reduction
=
MLUOP_LOSS_REDUCTION_NONE
;
// threshold_c * nram_split_pingpong * compute_data_bytes * nram_split_num +
// nram_split_pingpong * target_data_bytes +
// threshold_c * compute_data_bytes <= nram_pingpong_size
threshold_c
=
(
nram_pingpong_size
-
nram_split_pingpong
*
target_data_bytes
)
/
(
compute_data_bytes
*
(
nram_split_num
*
nram_split_pingpong
+
1
));
threshold_c
=
PAD_DOWN
(
threshold_c
,
compute_align_num
);
int
weight_space
=
PAD_UP
(
total_c
*
compute_data_bytes
,
NFU_ALIGN_SIZE
);
auto
handle
=
mluOpGetCurrentHandle
();
// reserved space for weight
nram_pingpong_size
-=
weight_space
;
compute_c
=
PAD_UP
(
total_c
,
compute_align_num
);
}
else
{
// threshold_c * nram_split_pingpong * compute_data_bytes * nram_split_num +
// nram_split_pingpong * target_data_bytes <= nram_pingpong_size
threshold_c
=
(
nram_pingpong_size
/
nram_split_pingpong
-
target_data_bytes
)
/
(
nram_split_num
*
compute_data_bytes
);
}
// deal_n * compute_c * nram_split_pingpong * compute_data_bytes *
// nram_split_num + deal_n * nram_split_pingpong * target_data_bytes <=
// nram_pingpong_size
*
deal_n_ptr
=
nram_pingpong_size
/
((
nram_split_num
*
compute_c
*
compute_data_bytes
+
target_data_bytes
)
*
nram_split_pingpong
);
*
threshold_c_ptr
=
threshold_c
;
// launch kernel
TORCH_MLUOP_CHECK
(
mluOpFocalLossSigmoidForward
(
handle
,
prefer
,
reduction
,
input_desc
.
desc
(),
input_ptr
,
target_desc
.
desc
(),
target_ptr
,
weight_desc
.
desc
(),
weight_ptr
,
alpha
,
gamma
,
output_desc
.
desc
(),
output_ptr
));
}
void
S
igmoid
F
ocal
L
oss
B
ackward
MLUKernelLauncher
(
Tensor
input
,
Tensor
target
,
void
s
igmoid
_f
ocal
_l
oss
_b
ackward
_mlu
(
Tensor
input
,
Tensor
target
,
Tensor
weight
,
Tensor
output
,
const
float
gamma
,
const
float
alpha
)
{
const
float
gamma
,
const
float
alpha
)
{
// params check
TORCH_CHECK
(
gamma
>=
0
,
"gamma should be greater than or equal to 0. "
,
"But now gamma is "
,
gamma
,
"."
);
...
...
@@ -246,77 +117,51 @@ void SigmoidFocalLossBackwardMLUKernelLauncher(Tensor input, Tensor target,
CNLOG
(
INFO
)
<<
"weight is a empty tensor."
;
}
auto
dim_c
=
input
.
size
(
1
);
const
int
compute_data_bytes
=
sizeof
(
float
);
// target supports only INT on MLU device while it keeps LONG on host side,
// so target.itemsize() / 2
const
int
target_data_bytes
=
target
.
scalar_type
()
==
at
::
kLong
?
(
target
.
itemsize
()
/
2
)
:
target
.
itemsize
();
int
deal_n
=
0
;
int
threshold_c
=
0
;
bool
is_half
=
false
;
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
is_half
=
true
;
}
// calculate deal_n and threshold_c
getDealNAndThresholdC
(
compute_data_bytes
,
target_data_bytes
,
dim_c
,
&
deal_n
,
&
threshold_c
,
has_weight
,
is_half
);
// check C
TORCH_CHECK
(
threshold_c
>=
dim_c
,
"input.size(1) should be in the range of [0, "
,
threshold_c
,
"]. "
,
"But now input.size(1) is "
,
dim_c
,
"."
);
if
(
input
.
numel
()
==
0
||
target
.
numel
()
==
0
||
output
.
numel
()
==
0
)
{
// return if zero-element
return
;
}
// set task dimension
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
;
policyFuncBackward
(
&
k_dim
,
&
k_type
);
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
// contiguous
auto
input_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
input
.
suggest_memory_format
());
// only support in32
auto
target_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
target
.
toType
(
at
::
kInt
),
target
.
suggest_memory_format
());
auto
weight_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
weight
,
weight
.
suggest_memory_format
());
auto
output_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
output
,
output
.
suggest_memory_format
());
// set tensor descriptor
MluOpTensorDescriptor
input_desc
,
target_desc
,
weight_desc
,
output_desc
;
input_desc
.
set
(
input_contiguous
);
target_desc
.
set
(
target_contiguous
);
weight_desc
.
set
(
weight_contiguous
);
output_desc
.
set
(
output_contiguous
);
// get ptr of tensors
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input
);
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input
_contiguous
);
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
target_impl
=
torch_mlu
::
getMluTensorImpl
(
target
);
auto
target_impl
=
torch_mlu
::
getMluTensorImpl
(
target
_contiguous
);
auto
target_ptr
=
target_impl
->
cnnlMalloc
();
auto
weight_impl
=
torch_mlu
::
getMluTensorImpl
(
weight
);
auto
weight_impl
=
torch_mlu
::
getMluTensorImpl
(
weight
_contiguous
);
auto
weight_ptr
=
weight_impl
->
cnnlMalloc
();
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output
);
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output
_contiguous
);
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
//
g
et
dtype of input
cnrtDataType_t
d_type
=
torch_mlu
::
toCnrtDtype
(
input
.
dtype
());
auto
core_dim
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
)
;
auto
dim_n
=
input
.
size
(
0
)
;
//
s
et
prefer computation performance and redcuntion approach
// backward only support MLUOP_COMPUTATION_HIGH_PRECISION
mluOpComputationPreference_t
prefer
=
MLUOP_COMPUTATION_HIGH_PRECISION
;
mluOpLossReduction_t
reduction
=
MLUOP_LOSS_REDUCTION_NONE
;
CNLOG
(
INFO
)
<<
"Launch Kernel KernelFocalLossSigmoidBackward<<<Union"
<<
k_type
/
core_dim
<<
", "
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
auto
handle
=
mluOpGetCurrentHandle
();
// launch kernel
KernelFocalLossSigmoidBackward
(
k_dim
,
k_type
,
queue
,
d_type
,
input_ptr
,
target_ptr
,
weight_ptr
,
gamma
,
alpha
,
dim_n
,
deal_n
,
dim_c
,
output_ptr
);
}
void
sigmoid_focal_loss_forward_mlu
(
Tensor
input
,
Tensor
target
,
Tensor
weight
,
Tensor
output
,
float
gamma
,
float
alpha
)
{
SigmoidFocalLossForwardMLUKernelLauncher
(
input
,
target
,
weight
,
output
,
gamma
,
alpha
);
}
void
sigmoid_focal_loss_backward_mlu
(
Tensor
input
,
Tensor
target
,
Tensor
weight
,
Tensor
grad_input
,
float
gamma
,
float
alpha
)
{
SigmoidFocalLossBackwardMLUKernelLauncher
(
input
,
target
,
weight
,
grad_input
,
gamma
,
alpha
);
TORCH_MLUOP_CHECK
(
mluOpFocalLossSigmoidBackward
(
handle
,
prefer
,
reduction
,
input_desc
.
desc
(),
input_ptr
,
target_desc
.
desc
(),
target_ptr
,
weight_desc
.
desc
(),
weight_ptr
,
alpha
,
gamma
,
output_desc
.
desc
(),
output_ptr
));
}
void
sigmoid_focal_loss_forward_impl
(
Tensor
input
,
Tensor
target
,
Tensor
weight
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment