OpenDAS / MMCV, commit 92b3e861 (unverified)
Authored Jun 01, 2023 by liuduanhui; committed by GitHub on Jun 01, 2023

[Refactor] Replace the implementation of psa_mask with mlu-ops. (#2810)
Parent: 2611b990

Showing 3 changed files with 14 additions and 883 deletions:
  mmcv/ops/csrc/common/mlu/psamask_mlu_kernel.mlu   +0 / -615
  mmcv/ops/csrc/common/mlu/psamask_utils.hpp        +0 / -55
  mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp         +14 / -213
mmcv/ops/csrc/common/mlu/psamask_mlu_kernel.mlu (deleted, 100644 → 0)
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include "psamask_utils.hpp"
#define COMPUTE_COUNT_ALIGN 64
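// single on-chip NRAM scratch buffer; every per-segment tile below (x_nram,
// y_nram and the temporaries) is carved out of this one allocation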
__nram__ char buf[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void swap(T &a, T &b) {
T tmp = a;
a = b;
b = tmp;
}
template <typename T>
__mlu_func__ void storeDataFromNramToDram(T *dst, const T *src,
const PositionInCore &position,
const Shape &shape_full) {
int n_offset = shape_full.h * shape_full.w * shape_full.c;
int h_offset = shape_full.w * shape_full.c;
int w_offset = shape_full.c;
int n_seg = position.n_end - position.n_start;
int h_seg = position.h_end - position.h_start;
int w_seg = position.w_end - position.w_start;
int size = h_seg * w_seg * shape_full.c;
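  // batch-strided store: n_seg segments of `size` elements, packed in NRAM,
  // are written to GDRAM with a stride of one full batch (n_offset)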
__memcpy(dst + position.n_start * n_offset + position.h_start * h_offset +
position.w_start * w_offset,
src, size * sizeof(T), NRAM2GDRAM, n_offset * sizeof(T),
size * sizeof(T), n_seg - 1);
}
template <typename T>
__mlu_func__ void loadDataFromDramToNram(T *dst, const T *src,
const PositionInCore &position,
const Shape &shape_full) {
int n_offset = shape_full.h * shape_full.w * shape_full.c;
int h_offset = shape_full.w * shape_full.c;
int w_offset = shape_full.c;
int n_seg = position.n_end - position.n_start;
int h_seg = position.h_end - position.h_start;
int w_seg = position.w_end - position.w_start;
int size = h_seg * w_seg * shape_full.c;
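  // batch-strided load: gather n_seg segments from GDRAM (stride n_offset)
  // into a packed NRAM tile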
__memcpy(dst, src + position.n_start * n_offset +
position.h_start * h_offset + position.w_start * w_offset,
size * sizeof(T), GDRAM2NRAM, size * sizeof(T), n_offset * sizeof(T),
n_seg - 1);
}
// transpose the data from A*B*C*(D*E) to A*D*E*(B*C)
template <typename T>
__mlu_func__ void transposeData(T *dst, T *src, const Shape &shape_seg) {
int align_c = CEIL_ALIGN(shape_seg.c, COMPUTE_COUNT_ALIGN / sizeof(T));
int align_hw =
CEIL_ALIGN(shape_seg.h * shape_seg.w, COMPUTE_COUNT_ALIGN / sizeof(T));
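  // transpose each batch independently as an (align_hw x align_c) tile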
for (int i = 0; i < shape_seg.n; ++i) {
__bang_transpose(dst, src, align_hw, align_c);
dst += align_hw * align_c;
src += align_hw * align_c;
}
}
template <typename T>
__mlu_func__ void psamaskCollectForward(
const T *x_dram, T *y_dram, const PositionInCore &position,
const Shape &x_full, const Shape &y_full, const Shape &shape_seg,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask) {
T *x_nram = (T *)buf;
T *y_nram =
x_nram + CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * x_full.c,
COMPUTE_COUNT_ALIGN / sizeof(T));
loadDataFromDramToNram(x_nram, x_dram, position, x_full);
// fill zeros to output
int elem_count =
CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * y_full.c,
NFU_ALIGN_SIZE / sizeof(T));
__bang_write_value(y_nram, elem_count, (T)0);
int y_n_offset = shape_seg.h * shape_seg.w * shape_seg.c;
int y_h_offset = shape_seg.w * shape_seg.c;
int y_w_offset = shape_seg.c;
int x_n_offset = shape_seg.h * shape_seg.w * x_full.c;
int y_c_offset = 1;
int x_h_offset = shape_seg.w * x_full.c;
int x_w_offset = x_full.c;
int x_c_offset = 1;
int x_start = 0;
int y_start = 0;
for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
for (int widx = 0; widx < shape_seg.w; ++widx) {
int h_abs = hidx + position.h_start;
int w_abs = widx + position.w_start;
int y_offset = y_start;
int x_offset = x_start;
y_offset += hidx * y_h_offset + widx * y_w_offset;
x_offset += hidx * x_h_offset + widx * x_w_offset;
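        // clip the mask window so it only covers positions that fall inside
        // the feature map at (h_abs, w_abs)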
const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
const int hend = x_full.h + half_h_mask - h_abs < h_mask
? x_full.h + half_h_mask - h_abs
: h_mask;
const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
const int wend = x_full.w + half_w_mask - w_abs < w_mask
? x_full.w + half_w_mask - w_abs
: w_mask;
        // (h, w) is an index into the mask;
        // (h + h_abs - half_h_mask, w + w_abs - half_w_mask) is the
        // corresponding index into the feature map
y_offset += ((hstart + h_abs - half_h_mask) * x_full.w + wstart +
w_abs - half_w_mask) *
y_c_offset;
x_offset += (hstart * w_mask + wstart) * x_c_offset;
int count = wend - wstart;
__memcpy(y_nram + y_offset, x_nram + x_offset, count * sizeof(T),
NRAM2NRAM, y_c_offset * x_full.w * sizeof(T),
x_c_offset * w_mask * sizeof(T), hend - hstart - 1);
}
}
y_start += y_n_offset;
x_start += x_n_offset;
}
storeDataFromNramToDram(y_dram, y_nram, position, y_full);
}
template <typename T>
__mlu_func__ void psamaskDistributeForward(
const T *x_dram, T *y_dram, const PositionInCore &position,
const Shape &x_full, const Shape &y_full, const Shape &shape_seg,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask) {
T *x_nram = (T *)buf;
T *y_nram_temp =
x_nram + CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * x_full.c,
COMPUTE_COUNT_ALIGN / sizeof(T));
loadDataFromDramToNram(x_nram, x_dram, position, x_full);
// fill zeros to output
int align_c = CEIL_ALIGN(y_full.c, COMPUTE_COUNT_ALIGN / sizeof(T));
int align_hw =
CEIL_ALIGN(shape_seg.h * shape_seg.w, COMPUTE_COUNT_ALIGN / sizeof(T));
int elem_count =
CEIL_ALIGN(shape_seg.n * align_c * align_hw, NFU_ALIGN_SIZE / sizeof(T));
__bang_write_value(y_nram_temp, elem_count, (T)0);
int y_n_offset = align_hw * align_c;
int y_h_offset = shape_seg.w * align_c;
int y_w_offset = align_c;
int y_c_offset = 1;
int x_n_offset = shape_seg.h * shape_seg.w * x_full.c;
int x_h_offset = shape_seg.w * x_full.c;
int x_w_offset = x_full.c;
int x_c_offset = 1;
int h_feature = y_full.h;
int w_feature = y_full.w;
int y_start = 0;
int x_start = 0;
for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
for (int widx = 0; widx < shape_seg.w; ++widx) {
int h_abs = hidx + position.h_start;
int w_abs = widx + position.w_start;
int y_offset = y_start;
int x_offset = x_start;
y_offset += hidx * y_h_offset + widx * y_w_offset;
x_offset += hidx * x_h_offset + widx * x_w_offset;
const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
const int hend = h_feature + half_h_mask - h_abs < h_mask
? h_feature + half_h_mask - h_abs
: h_mask;
const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
const int wend = w_feature + half_w_mask - w_abs < w_mask
? w_feature + half_w_mask - w_abs
: w_mask;
        // (h, w) is an index into the mask;
        // (h + h_abs - half_h_mask, w + w_abs - half_w_mask) is the
        // corresponding index into the feature map
y_offset += ((hstart + h_abs - half_h_mask) * x_full.w + wstart +
w_abs - half_w_mask) *
y_c_offset;
x_offset += (hstart * w_mask + wstart) * x_c_offset;
int count = wend - wstart;
__memcpy(y_nram_temp + y_offset, x_nram + x_offset, count * sizeof(T),
NRAM2NRAM, y_c_offset * w_feature * sizeof(T),
x_c_offset * w_mask * sizeof(T), hend - hstart - 1);
}
}
y_start += y_n_offset;
x_start += x_n_offset;
}
// transpose y
T *y_nram = y_nram_temp + shape_seg.n * align_hw * align_c;
Shape y_seg{shape_seg.n, shape_seg.h, shape_seg.w, y_full.c};
transposeData(y_nram, y_nram_temp, y_seg);
swap(align_c, align_hw);
// store y from nram to dram
int y_n_offset_full = y_full.h * y_full.w * y_full.c;
int y_w_offset_full = y_full.c;
int y_c_offset_full = 1;
int y_dram_start =
position.n_start * y_n_offset_full +
(position.h_start * y_full.w + position.w_start) * y_c_offset_full;
int y_nram_start = 0;
for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
int y_dram_offset = y_dram_start + nidx * y_n_offset_full;
int y_nram_offset = y_nram_start + nidx * align_hw * align_c;
__memcpy(y_dram + y_dram_offset, y_nram + y_nram_offset,
shape_seg.h * shape_seg.w * sizeof(T), NRAM2GDRAM,
y_w_offset_full * sizeof(T), align_c * sizeof(T),
h_feature * w_feature - 1);
}
}
template <typename T>
__mlu_func__ void psamaskCollectBackward(
const T *dy_dram, T *dx_dram, const PositionInCore &position,
const Shape &dy_full, const Shape &dx_full, const Shape &shape_seg,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask) {
T *dy_nram = (T *)buf;
T *dx_nram =
dy_nram + CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * dy_full.c,
COMPUTE_COUNT_ALIGN / sizeof(T));
loadDataFromDramToNram(dy_nram, dy_dram, position, dy_full);
// fill zeros to output
int elem_count =
CEIL_ALIGN(shape_seg.n * shape_seg.h * shape_seg.w * shape_seg.c,
NFU_ALIGN_SIZE / sizeof(T));
__bang_write_value(dx_nram, elem_count, (T)0);
int dy_n_offset = shape_seg.h * shape_seg.w * dy_full.c;
int dy_h_offset = shape_seg.w * dy_full.c;
int dy_w_offset = dy_full.c;
int dy_c_offset = 1;
int dx_n_offset = shape_seg.h * shape_seg.w * dx_full.c;
int dx_h_offset = shape_seg.w * dx_full.c;
int dx_w_offset = dx_full.c;
int dx_c_offset = 1;
int h_feature = dy_full.h;
int w_feature = dy_full.w;
int dy_start = 0;
int dx_start = 0;
for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
for (int widx = 0; widx < shape_seg.w; ++widx) {
int h_abs = hidx + position.h_start;
int w_abs = widx + position.w_start;
int dy_offset = dy_start;
int dx_offset = dx_start;
dy_offset += hidx * dy_h_offset + widx * dy_w_offset;
dx_offset += hidx * dx_h_offset + widx * dx_w_offset;
const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
const int hend = h_feature + half_h_mask - h_abs < h_mask
? h_feature + half_h_mask - h_abs
: h_mask;
const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
const int wend = w_feature + half_w_mask - w_abs < w_mask
? w_feature + half_w_mask - w_abs
: w_mask;
        // (h, w) is an index into the mask;
        // (h + h_abs - half_h_mask, w + w_abs - half_w_mask) is the
        // corresponding index into the feature map
dy_offset += ((hstart + h_abs - half_h_mask) * w_feature + wstart +
w_abs - half_w_mask) *
dy_c_offset;
dx_offset += (hstart * w_mask + wstart) * dx_c_offset;
int count = wend - wstart;
__memcpy(dx_nram + dx_offset, dy_nram + dy_offset, count * sizeof(T),
NRAM2NRAM, dx_c_offset * w_mask * sizeof(T),
dy_c_offset * w_feature * sizeof(T), hend - hstart - 1);
}
}
dy_start += dy_n_offset;
dx_start += dx_n_offset;
}
storeDataFromNramToDram(dx_dram, dx_nram, position, dx_full);
}
template <typename T>
__mlu_func__ void psamaskDistributeBackward(
const T *dy_dram, T *dx_dram, const PositionInCore &position,
const Shape &dy_full, const Shape &dx_full, const Shape &shape_seg,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask) {
// load dy from dram to nram
T *dy_nram_temp = (T *)buf;
int dy_n_offset_full = dy_full.h * dy_full.w * dy_full.c;
int dy_c_offset_full = 1;
int h_feature = dy_full.h;
int w_feature = dy_full.w;
int align_c =
CEIL_ALIGN(shape_seg.h * shape_seg.w, COMPUTE_COUNT_ALIGN / sizeof(T));
int align_hw =
CEIL_ALIGN(h_feature * w_feature, COMPUTE_COUNT_ALIGN / sizeof(T));
int dy_dram_start =
position.n_start * dy_n_offset_full +
(position.h_start * w_feature + position.w_start) * dy_c_offset_full;
int dy_nram_start = 0;
for (int i = 0; i < shape_seg.n; ++i) {
int dy_nram_offset = dy_nram_start + i * (align_hw * align_c);
int dy_dram_offset = dy_dram_start + i * dy_n_offset_full;
__memcpy(dy_nram_temp + dy_nram_offset, dy_dram + dy_dram_offset,
shape_seg.h * shape_seg.w * sizeof(T), GDRAM2NRAM,
align_c * sizeof(T), dy_full.c * sizeof(T),
h_feature * w_feature - 1);
}
T *dy_nram = dy_nram_temp + shape_seg.n * align_hw * align_c;
Shape dy_seg{shape_seg.n, h_feature, w_feature, shape_seg.h * shape_seg.w};
transposeData(dy_nram, dy_nram_temp, dy_seg);
swap(align_c, align_hw);
// fill zeros to dx
T *dx_nram = dy_nram + shape_seg.n * align_hw * align_c;
int dx_size = shape_seg.n * shape_seg.h * shape_seg.w * dx_full.c;
__bang_write_value(dx_nram, CEIL_ALIGN(dx_size, NFU_ALIGN_SIZE / sizeof(T)),
(T)0);
int dy_n_offset_seg = align_hw * align_c;
int dy_h_offset_seg = shape_seg.w * align_c;
int dy_w_offset_seg = align_c;
int dy_c_offset_seg = 1;
int dx_n_offset_seg = shape_seg.h * shape_seg.w * shape_seg.c;
int dx_h_offset_seg = shape_seg.w * shape_seg.c;
int dx_w_offset_seg = shape_seg.c;
int dx_c_offset_seg = 1;
int dy_start = 0;
int dx_start = 0;
for (int nidx = 0; nidx < shape_seg.n; ++nidx) {
for (int hidx = 0; hidx < shape_seg.h; ++hidx) {
for (int widx = 0; widx < shape_seg.w; ++widx) {
int h_abs = hidx + position.h_start;
int w_abs = widx + position.w_start;
int dy_offset = dy_start;
int dx_offset = dx_start;
dy_offset += hidx * dy_h_offset_seg + widx * dy_w_offset_seg;
dx_offset += hidx * dx_h_offset_seg + widx * dx_w_offset_seg;
const int hstart = half_h_mask - h_abs > 0 ? half_h_mask - h_abs : 0;
const int hend = h_feature + half_h_mask - h_abs < h_mask
? h_feature + half_h_mask - h_abs
: h_mask;
const int wstart = half_w_mask - w_abs > 0 ? half_w_mask - w_abs : 0;
const int wend = w_feature + half_w_mask - w_abs < w_mask
? w_feature + half_w_mask - w_abs
: w_mask;
        // (h, w) is an index into the mask;
        // (h + h_abs - half_h_mask, w + w_abs - half_w_mask) is the
        // corresponding index into the feature map
dy_offset += ((hstart + h_abs - half_h_mask) * w_feature + wstart +
w_abs - half_w_mask) *
dy_c_offset_seg;
dx_offset += (hstart * w_mask + wstart) * dx_c_offset_seg;
int count = wend - wstart;
__memcpy(dx_nram + dx_offset, dy_nram + dy_offset, count * sizeof(T),
NRAM2NRAM, w_mask * dx_c_offset_seg * sizeof(T),
w_feature * dy_c_offset_seg * sizeof(T), hend - hstart - 1);
}
}
dy_start += dy_n_offset_seg;
dx_start += dx_n_offset_seg;
}
storeDataFromNramToDram(dx_dram, dx_nram, position, dx_full);
}
template <typename T>
__mlu_func__ void psamaskBase(const T *input_dram, T *output_dram,
const Shape &input_full, const Shape &output_full,
LimitParam &limit, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition,
const bool is_forward, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask, const int n_per_core,
const int h_per_core, const int n_per_cluster,
const int h_per_cluster) {
PositionInCore position_full;
PositionInCore position_seg;
position_full.w_start = 0;
position_full.w_end = output_full.w;
int n_num_in_cluster = n_per_cluster;
int h_num_in_cluster = h_per_cluster;
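  // split the work over clusters (taskIdY) along N or H here, then over the
  // cores of a cluster (taskIdX) in the next switch below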
switch (cluster_partition) {
case PARTITION_N: {
position_full.h_start = 0;
position_full.h_end = input_full.h;
position_full.n_start = taskIdY * n_per_cluster;
int cluster_need = (input_full.n + n_per_cluster - 1) / n_per_cluster;
if (taskIdY >= cluster_need) return;
int n_remainder = input_full.n - (cluster_need - 1) * n_per_cluster;
n_num_in_cluster =
(taskIdY == cluster_need - 1) ? n_remainder : n_per_cluster;
position_full.n_end = position_full.n_start + n_num_in_cluster;
}; break;
case PARTITION_H: {
position_full.n_start = 0;
position_full.n_end = input_full.n;
position_full.h_start = taskIdY * h_per_cluster;
int cluster_need = (input_full.h + h_per_cluster - 1) / h_per_cluster;
if (taskIdY >= cluster_need) return;
int h_remainder = input_full.h - (cluster_need - 1) * h_per_cluster;
h_num_in_cluster =
(taskIdY == cluster_need - 1) ? h_remainder : h_per_cluster;
position_full.h_end = position_full.h_start + h_num_in_cluster;
}; break;
}
switch (core_partition) {
case PARTITION_N: {
position_full.n_start += taskIdX * n_per_core;
int core_need = (n_num_in_cluster + n_per_core - 1) / n_per_core;
if (taskIdX >= core_need) return;
int n_remainder = n_num_in_cluster - (core_need - 1) * n_per_core;
position_full.n_end =
position_full.n_start +
((taskIdX == core_need - 1) ? n_remainder : n_per_core);
}; break;
case PARTITION_H: {
position_full.h_start += taskIdX * h_per_core;
int core_need = (h_num_in_cluster + h_per_core - 1) / h_per_core;
if (taskIdX >= core_need) return;
int h_remainder = h_num_in_cluster - (core_need - 1) * h_per_core;
position_full.h_end =
position_full.h_start +
((taskIdX == core_need - 1) ? h_remainder : h_per_core);
}; break;
}
  // the numbers of n, h and w elements to be processed by the current core
int shape_core_n = position_full.n_end - position_full.n_start;
int shape_core_h = position_full.h_end - position_full.h_start;
int shape_core_w = input_full.w;
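  // clamp the per-iteration NRAM limits to this core's share of the work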
limit.n = limit.n < shape_core_n ? limit.n : shape_core_n;
limit.h = limit.h < shape_core_h ? limit.h : shape_core_h;
limit.w = limit.w < shape_core_w ? limit.w : shape_core_w;
// load the data to nram according to the limit
for (int nidx = position_full.n_start; nidx < position_full.n_end;
nidx += limit.n) {
position_seg.n_start = nidx;
position_seg.n_end =
position_seg.n_start + (position_full.n_end - nidx < limit.n
? position_full.n_end - nidx
: limit.n);
for (int hidx = position_full.h_start; hidx < position_full.h_end;
hidx += limit.h) {
position_seg.h_start = hidx;
position_seg.h_end =
position_seg.h_start + (position_full.h_end - hidx < limit.h
? position_full.h_end - hidx
: limit.h);
for (int widx = position_full.w_start; widx < position_full.w_end;
widx += limit.w) {
position_seg.w_start = widx;
position_seg.w_end =
position_seg.w_start + (position_full.w_end - widx < limit.w
? position_full.w_end - widx
: limit.w);
        // record the output segment; only the spatial dims are split, so the
        // channel dims of both input and output stay whole
Shape shape_seg;
shape_seg.n = position_seg.n_end - position_seg.n_start;
shape_seg.h = position_seg.h_end - position_seg.h_start;
shape_seg.w = position_seg.w_end - position_seg.w_start;
shape_seg.c = output_full.c;
switch (psa_type) {
case COLLECT: {
if (is_forward) {
psamaskCollectForward(input_dram, output_dram, position_seg,
input_full, output_full, shape_seg, h_mask,
w_mask, half_h_mask, half_w_mask);
} else {
psamaskCollectBackward(input_dram, output_dram, position_seg,
input_full, output_full, shape_seg, h_mask,
w_mask, half_h_mask, half_w_mask);
}
} break;
case DISTRIBUTE: {
if (is_forward) {
psamaskDistributeForward(input_dram, output_dram, position_seg,
input_full, output_full, shape_seg,
h_mask, w_mask, half_h_mask,
half_w_mask);
} else {
psamaskDistributeBackward(input_dram, output_dram, position_seg,
input_full, output_full, shape_seg,
h_mask, w_mask, half_h_mask,
half_w_mask);
}
} break;
}
}
}
}
}
template <typename T>
__mlu_global__ void MLUUnion1KernelPsamaskForward(
const T *x, T *y, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition, const int batch,
const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int x_c, const int y_c, const int half_h_mask,
const int half_w_mask, const int n_per_core, const int h_per_core,
const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
const int limit_h_seg, const int limit_w_seg) {
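  // coreId 0x80 is the cluster's MPU core; only the IPU compute cores
  // execute the kernel body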
if (coreId == 0x80) {
return;
}
Shape x_full, y_full;
x_full.n = batch;
x_full.h = h_feature;
x_full.w = w_feature;
x_full.c = x_c;
y_full.n = batch;
y_full.h = h_feature;
y_full.w = w_feature;
y_full.c = y_c;
LimitParam limit;
limit.n = limit_n_seg;
limit.h = limit_h_seg;
limit.w = limit_w_seg;
psamaskBase(x, y, x_full, y_full, limit, psa_type, core_partition,
cluster_partition, true, h_mask, w_mask, half_h_mask, half_w_mask,
n_per_core, h_per_core, n_per_cluster, h_per_cluster);
}
template <typename T>
__mlu_global__ void MLUUnion1KernelPsamaskBackward(
const T *dy, T *dx, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition, const int batch,
const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
const int half_w_mask, const int n_per_core, const int h_per_core,
const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
const int limit_h_seg, const int limit_w_seg) {
if (coreId == 0x80) {
return;
}
Shape dy_full, dx_full;
dx_full.n = batch;
dx_full.h = h_feature;
dx_full.w = w_feature;
dx_full.c = dx_c;
dy_full.n = batch;
dy_full.h = h_feature;
dy_full.w = w_feature;
dy_full.c = dy_c;
LimitParam limit;
limit.n = limit_n_seg;
limit.h = limit_h_seg;
limit.w = limit_w_seg;
psamaskBase(dy, dx, dy_full, dx_full, limit, psa_type, core_partition,
cluster_partition, false, h_mask, w_mask, half_h_mask,
half_w_mask, n_per_core, h_per_core, n_per_cluster,
h_per_cluster);
}
void KernelPsamaskForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *x, void *y, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition, const int batch,
const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int x_c, const int y_c, const int half_h_mask,
const int half_w_mask, const int n_per_core, const int h_per_core,
const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
const int limit_h_seg, const int limit_w_seg) {
MLUUnion1KernelPsamaskForward<<<k_dim, k_type, queue>>>(
static_cast<const float *>(x), static_cast<float *>(y), psa_type,
core_partition, cluster_partition, batch, h_feature, w_feature, h_mask,
w_mask, x_c, y_c, half_h_mask, half_w_mask, n_per_core, h_per_core,
n_per_cluster, h_per_cluster, limit_n_seg, limit_h_seg, limit_w_seg);
}
void KernelPsamaskBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *dy, void *dx, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition, const int batch,
const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
const int half_w_mask, const int n_per_core, const int h_per_core,
const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
const int limit_h_seg, const int limit_w_seg) {
MLUUnion1KernelPsamaskBackward<<<k_dim, k_type, queue>>>(
static_cast<const float *>(dy), static_cast<float *>(dx), psa_type,
core_partition, cluster_partition, batch, h_feature, w_feature, h_mask,
w_mask, dx_c, dy_c, half_h_mask, half_w_mask, n_per_core, h_per_core,
n_per_cluster, h_per_cluster, limit_n_seg, limit_h_seg, limit_w_seg);
}
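For orientation, the COLLECT forward path above amounts to the following scalar mapping. This is a minimal NCHW reference sketch derived from psamaskCollectForward; the function name and the raw-pointer interface are illustrative only, and y is assumed to be zero-initialized by the caller.

  #include <cstddef>

  // x: (num, h_mask * w_mask, h_feature, w_feature)
  // y: (num, h_feature * w_feature, h_feature, w_feature), pre-filled with zeros
  void psamask_collect_forward_ref(const float *x, float *y, int num,
                                   int h_feature, int w_feature, int h_mask,
                                   int w_mask, int half_h_mask,
                                   int half_w_mask) {
    const int x_c = h_mask * w_mask;
    const int y_c = h_feature * w_feature;
    for (int n = 0; n < num; ++n)
      for (int h = 0; h < h_feature; ++h)
        for (int w = 0; w < w_feature; ++w)
          for (int hm = 0; hm < h_mask; ++hm)
            for (int wm = 0; wm < w_mask; ++wm) {
              // feature-map position the mask element (hm, wm) points at
              const int h_out = h + hm - half_h_mask;
              const int w_out = w + wm - half_w_mask;
              if (h_out < 0 || h_out >= h_feature || w_out < 0 ||
                  w_out >= w_feature)
                continue;  // mask element falls outside the feature map
              const int yc = h_out * w_feature + w_out;
              const int xc = hm * w_mask + wm;
              const std::size_t y_idx =
                  ((static_cast<std::size_t>(n) * y_c + yc) * h_feature + h) *
                      w_feature + w;
              const std::size_t x_idx =
                  ((static_cast<std::size_t>(n) * x_c + xc) * h_feature + h) *
                      w_feature + w;
              y[y_idx] = x[x_idx];
            }
  }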
mmcv/ops/csrc/common/mlu/psamask_utils.hpp (deleted, 100644 → 0)
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef PSAMASK_UTILS_HPP_
#define PSAMASK_UTILS_HPP_

typedef enum {
  COLLECT = 0,
  DISTRIBUTE = 1,
} PsamaskType;

typedef enum {
  PARTITION_N = 0,
  PARTITION_H = 1,
} DimPartitionType;

struct PartitionSeg {
  int h_per_cluster;
  int n_per_cluster;
  int h_per_core;
  int n_per_core;
  DimPartitionType cluster_partition;
  DimPartitionType core_partition;
};

struct Shape {
  int n;
  int h;
  int w;
  int c;
};

struct LimitParam {
  int n;
  int h;
  int w;
};

struct PositionInCore {
  int n_start;
  int n_end;
  int h_start;
  int h_end;
  int w_start;
  int w_end;
};
#endif  // PSAMASK_UTILS_HPP_
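These types only existed to parameterize the old launch path: policyFunc filled a PartitionSeg, findLimit derived per-segment NRAM limits from it, and both were forwarded to the kernel entry points. A condensed sketch of that flow, abridged from the launcher code removed in the diff below:

  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
  PartitionSeg partition_info;
  policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
  int n_limit_seg, h_limit_seg, w_limit_seg;
  bool ret = findLimit(partition_info.n_per_core, partition_info.h_per_core,
                       w_feature, x_c, y_c, &n_limit_seg, &h_limit_seg,
                       &w_limit_seg, psa_type);
  // on success, k_dim/k_type, the partition and the limits are passed on to
  // KernelPsamaskForward / KernelPsamaskBackward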
mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
@@ -9,136 +9,7 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include <algorithm>
-#include "psamask_utils.hpp"
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
-#define COMPUTE_COUNT_ALIGN 64
+#include "mlu_common_helper.h"
-void KernelPsamaskForward(
-    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-    const void *x, void *y, const PsamaskType psa_type,
-    const DimPartitionType core_partition,
-    const DimPartitionType cluster_partition, const int batch,
-    const int h_feature, const int w_feature, const int h_mask,
-    const int w_mask, const int x_c, const int y_c, const int half_h_mask,
-    const int half_w_mask, const int n_per_core, const int h_per_core,
-    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
-    const int limit_h_seg, const int limit_w_seg);
-void KernelPsamaskBackward(
-    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-    const void *dy, void *dx, const PsamaskType psa_type,
-    const DimPartitionType core_partition,
-    const DimPartitionType cluster_partition, const int batch,
-    const int h_feature, const int w_feature, const int h_mask,
-    const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
-    const int half_w_mask, const int n_per_core, const int h_per_core,
-    const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
-    const int limit_h_seg, const int limit_w_seg);
-namespace {
-void policyFunc(cnrtDim3_t *k_dim_ptr, cnrtFunctionType_t *f_type_ptr,
-                PartitionSeg *partition_ptr, const int n,
-                const int h_feature) {
-  unsigned int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  unsigned int use_cluster_num = cluster_num;
-  unsigned int use_core_num = core_dim;
-  if (n >= cluster_num || n >= h_feature) {
-    partition_ptr->cluster_partition = PARTITION_N;
-    partition_ptr->n_per_cluster = (n + cluster_num - 1) / cluster_num;
-    partition_ptr->h_per_cluster = h_feature;
-    use_cluster_num =
-        (n + partition_ptr->n_per_cluster - 1) / partition_ptr->n_per_cluster;
-  } else {
-    partition_ptr->cluster_partition = PARTITION_H;
-    partition_ptr->h_per_cluster = (h_feature + cluster_num - 1) / cluster_num;
-    partition_ptr->n_per_cluster = n;
-    use_cluster_num = (h_feature + partition_ptr->h_per_cluster - 1) /
-                      partition_ptr->h_per_cluster;
-  }
-  if (partition_ptr->n_per_cluster >= core_dim ||
-      partition_ptr->n_per_cluster >= partition_ptr->h_per_cluster) {
-    partition_ptr->core_partition = PARTITION_N;
-    partition_ptr->n_per_core =
-        (partition_ptr->n_per_cluster + core_dim - 1) / core_dim;
-    partition_ptr->h_per_core = partition_ptr->h_per_cluster;
-    use_core_num = (partition_ptr->n_per_cluster +
-                    partition_ptr->n_per_core - 1) /
-                   partition_ptr->n_per_core;
-  } else {
-    partition_ptr->core_partition = PARTITION_H;
-    partition_ptr->h_per_core =
-        (partition_ptr->h_per_cluster + core_dim - 1) / core_dim;
-    partition_ptr->n_per_core = partition_ptr->n_per_cluster;
-    use_core_num = (partition_ptr->h_per_cluster +
-                    partition_ptr->h_per_core - 1) /
-                   partition_ptr->h_per_core;
-  }
-  *k_dim_ptr = {core_dim, use_cluster_num, 1};
-}
-}  // namespace
-bool findLimit(const int shape_core_n, const int shape_core_h,
-               const int shape_core_w, const int shape_core_ci,
-               const int shape_core_co, int *limit_n_seg_ptr,
-               int *limit_h_seg_ptr, int *limit_w_seg_ptr,
-               const int psa_type) {
-  const bool need_temp = psa_type == 1;
-  const int input_bytes = sizeof(float);
-  int limit_n_seg = shape_core_n;
-  int limit_h_seg = shape_core_h;
-  int limit_w_seg = shape_core_w;
-  const int max_nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const int align_base_128 = NFU_ALIGN_SIZE / input_bytes;
-  const int align_base_64 = COMPUTE_COUNT_ALIGN / input_bytes;
-  const int align_co = CEIL_ALIGN(shape_core_co, align_base_64);
-  const int align_w = CEIL_ALIGN(shape_core_w, align_base_64);
-  const int align_hw = CEIL_ALIGN(shape_core_h * shape_core_w, align_base_64);
-  const int max_num = max_nram_size / input_bytes;
-  int n_limit =
-      max_num / (CEIL_ALIGN(shape_core_h * shape_core_w * shape_core_ci,
-                            align_base_128) +
-                 align_hw * align_co * (1 + need_temp));
-  if (n_limit > 0) {
-    n_limit = std::min(n_limit, shape_core_n);
-    limit_n_seg = n_limit;
-  } else {
-    int h_limit =
-        max_num / (CEIL_ALIGN(shape_core_w * shape_core_ci, align_base_128) +
-                   align_w * align_co * (1 + need_temp));
-    if (h_limit > 0) {
-      h_limit = std::min(h_limit, shape_core_h);
-      limit_h_seg = h_limit;
-      limit_n_seg = 1;
-    } else {
-      int w_limit =
-          max_num / (CEIL_ALIGN(shape_core_ci, align_base_128) +
-                     CEIL_ALIGN(align_co, align_base_128) * (1 + need_temp));
-      if (w_limit > 0 && w_limit >= (COMPUTE_COUNT_ALIGN / input_bytes)) {
-        w_limit = std::min(w_limit, shape_core_w);
-        w_limit = w_limit / (COMPUTE_COUNT_ALIGN / input_bytes) *
-                  (COMPUTE_COUNT_ALIGN / input_bytes);
-        limit_w_seg = w_limit;
-        limit_h_seg = 1;
-        limit_n_seg = 1;
-      } else {
-        CNLOG(INFO) << "The size of input channel is too large.";
-        return false;
-      }
-    }
-  }
-  *limit_n_seg_ptr = limit_n_seg;
-  *limit_h_seg_ptr = limit_h_seg;
-  *limit_w_seg_ptr = limit_w_seg;
-  return true;
-}
 void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                      Tensor y, const int num_,
@@ -146,39 +17,7 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                      const int h_mask, const int w_mask,
                                      const int half_h_mask,
                                      const int half_w_mask) {
-  // params check
-  TORCH_CHECK(x.scalar_type() == at::kFloat, "x type should be Float, got ",
-              x.scalar_type());
-  TORCH_CHECK(y.scalar_type() == x.scalar_type(),
-              "y should have the same type as x");
-  TORCH_CHECK(x.dim() == 4, "x should be a 4d tensor, got ", x.dim(), "D");
-  TORCH_CHECK(y.dim() == 4, "y should be a 4d tensor, got ", y.dim(), "D");
-  int x_c = x.size(1);
   int y_c = y.size(1);
-  TORCH_CHECK(h_mask * w_mask == x_c,
-              "channel of x should be the same as h_mask * w_mask");
-  TORCH_CHECK(h_feature * w_feature == y_c,
-              "channel of y should be the same as h_feature * w_feature");
-  TORCH_CHECK(psa_type == 0 || psa_type == 1,
-              "psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
-  if (x.numel() == 0) {
-    CNLOG(INFO) << "skip zero-element tensor";
-    return;
-  }
-  cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
-  cnrtDim3_t k_dim;
-  PartitionSeg partition_info;
-  policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
-  int n_limit_seg, h_limit_seg, w_limit_seg;
-  bool ret =
-      findLimit(partition_info.n_per_core, partition_info.h_per_core,
-                w_feature, x_c, y_c, &n_limit_seg, &h_limit_seg, &w_limit_seg,
-                psa_type);
-  if (ret != true) {
-    return;
-  }
   auto memory_format =
       torch_mlu::cnnl::ops::get_channels_last_memory_format(x.dim());
@@ -186,22 +25,18 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
   at::Tensor y_tmp =
       at::empty({num_, y_c, h_feature, w_feature}, x.options(), memory_format);
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-  // get ptr of tensors
+  MluOpTensorDescriptor x_desc, y_desc;
+  x_desc.set_with_layout(x_tensor, MLUOP_LAYOUT_NHWC);
+  y_desc.set_with_layout(y_tmp, MLUOP_LAYOUT_NHWC);
+  auto handle = mluOpGetCurrentHandle();
   auto x_impl = torch_mlu::getMluTensorImpl(x_tensor);
   auto x_ptr = x_impl->cnnlMalloc();
   auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
   auto y_ptr = y_impl->cnnlMalloc();
-  KernelPsamaskForward(
-      k_dim, k_type, queue, x_ptr, y_ptr, (PsamaskType)psa_type,
-      partition_info.core_partition, partition_info.cluster_partition, num_,
-      h_feature, w_feature, h_mask, w_mask, x_c, y_c, half_h_mask,
-      half_w_mask, partition_info.n_per_core, partition_info.h_per_core,
-      partition_info.n_per_cluster, partition_info.h_per_cluster, n_limit_seg,
-      h_limit_seg, w_limit_seg);
+  mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
+                      y_desc.desc(), y_ptr);
   y.copy_(y_tmp);
 }
@@ -212,39 +47,7 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
                                       const int h_mask, const int w_mask,
                                       const int half_h_mask,
                                       const int half_w_mask) {
-  // params check
-  TORCH_CHECK(dy.scalar_type() == at::kFloat, "dy type should be Float, got ",
-              dy.scalar_type());
-  TORCH_CHECK(dx.scalar_type() == dy.scalar_type(),
-              "dx should have the same type as dy");
-  TORCH_CHECK(dy.dim() == 4, "dy should be a 4d tensor, got ", dy.dim(), "D");
-  TORCH_CHECK(dx.dim() == 4, "dx should be a 4d tensor, got ", dx.dim(), "D");
-  int dy_c = dy.size(1);
   int dx_c = dx.size(1);
-  TORCH_CHECK(h_feature * w_feature == dy_c,
-              "channel of dy should be the same as h_feature * w_feature");
-  TORCH_CHECK(h_mask * w_mask == dx_c,
-              "channel of dx should be the same as h_mask * w_mask");
-  TORCH_CHECK(psa_type == 0 || psa_type == 1,
-              "psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
-  if (dx.numel() == 0) {
-    CNLOG(INFO) << "skip zero-element tensor";
-    return;
-  }
-  cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
-  cnrtDim3_t k_dim;
-  PartitionSeg partition_info;
-  policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
-  int n_limit_seg, h_limit_seg, w_limit_seg;
-  bool ret =
-      findLimit(partition_info.n_per_core, partition_info.h_per_core,
-                w_feature, dx_c, dy_c, &n_limit_seg, &h_limit_seg,
-                &w_limit_seg, psa_type);
-  if (ret != true) {
-    return;
-  }
   auto memory_format =
       torch_mlu::cnnl::ops::get_channels_last_memory_format(dy.dim());
@@ -252,8 +55,11 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
   at::Tensor dx_tmp = at::empty({num_, dx_c, h_feature, w_feature},
                                 dy.options(), memory_format);
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  MluOpTensorDescriptor dy_desc, dx_tmp_desc;
+  dy_desc.set_with_layout(dy_tensor, MLUOP_LAYOUT_NHWC);
+  dx_tmp_desc.set_with_layout(dx_tmp, MLUOP_LAYOUT_NHWC);
+  auto handle = mluOpGetCurrentHandle();
   // get ptr of tensors
   auto dx_impl = torch_mlu::getMluTensorImpl(dx_tmp);
@@ -261,13 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
   auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
   auto dy_ptr = dy_impl->cnnlMalloc();
-  KernelPsamaskBackward(
-      k_dim, k_type, queue, dy_ptr, dx_ptr, (PsamaskType)psa_type,
-      partition_info.core_partition, partition_info.cluster_partition, num_,
-      h_feature, w_feature, h_mask, w_mask, dx_c, dy_c, half_h_mask,
-      half_w_mask, partition_info.n_per_core, partition_info.h_per_core,
-      partition_info.n_per_cluster, partition_info.h_per_cluster, n_limit_seg,
-      h_limit_seg, w_limit_seg);
+  mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
+                       dx_tmp_desc.desc(), dx_ptr);
   dx.copy_(dx_tmp);
 }
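Reassembling the surviving lines, the forward launcher after this commit reduces to roughly the sketch below. It is stitched together from the hunks above; unchanged context that the diff collapses (notably the creation of x_tensor from x) is marked rather than guessed, and the exact line wrapping is approximate.

  void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                       Tensor y, const int num_,
                                       const int h_feature, const int w_feature,
                                       const int h_mask, const int w_mask,
                                       const int half_h_mask,
                                       const int half_w_mask) {
    int y_c = y.size(1);
    auto memory_format =
        torch_mlu::cnnl::ops::get_channels_last_memory_format(x.dim());
    // ... collapsed context: x_tensor is prepared from x here ...
    at::Tensor y_tmp = at::empty({num_, y_c, h_feature, w_feature},
                                 x.options(), memory_format);
    MluOpTensorDescriptor x_desc, y_desc;
    x_desc.set_with_layout(x_tensor, MLUOP_LAYOUT_NHWC);
    y_desc.set_with_layout(y_tmp, MLUOP_LAYOUT_NHWC);
    auto handle = mluOpGetCurrentHandle();
    auto x_impl = torch_mlu::getMluTensorImpl(x_tensor);
    auto x_ptr = x_impl->cnnlMalloc();
    auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
    auto y_ptr = y_impl->cnnlMalloc();
    mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
                        y_desc.desc(), y_ptr);
    y.copy_(y_tmp);
  }

The backward launcher follows the same pattern with dy/dx, dy_desc/dx_tmp_desc and mluOpPsamaskBackward.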