Unverified Commit 92b3e861 authored by liuduanhui, committed by GitHub

[Refactor] Replace the implementation of psa_mask with mlu-ops. (#2810)
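
This drops the hand-written MLU kernel path (the KernelPsamaskForward/KernelPsamaskBackward launches together with the policyFunc and findLimit tiling helpers) and calls the psamask operators from the mlu-ops library instead. A minimal sketch of the shape of the change, assuming the descriptors and pointers are set up as in the launchers below:

    // before: explicit tiling + Bang kernel launch
    KernelPsamaskForward(k_dim, k_type, queue, x_ptr, y_ptr, /* tiling params */ ...);
    // after: a single mlu-ops call on NHWC tensor descriptors
    mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
                        y_desc.desc(), y_ptr);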

parent 2611b990
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef PSAMASK_UTILS_HPP_
#define PSAMASK_UTILS_HPP_
typedef enum {
COLLECT = 0,
DISTRIBUTE = 1,
} PsamaskType;
typedef enum {
PARTITION_N = 0,
PARTITION_H = 1,
} DimPartitionType;
struct PartitionSeg {
int h_per_cluster;
int n_per_cluster;
int h_per_core;
int n_per_core;
DimPartitionType cluster_partition;
DimPartitionType core_partition;
};
struct Shape {
int n;
int h;
int w;
int c;
};
struct LimitParam {
int n;
int h;
int w;
};
struct PositionInCore {
int n_start;
int n_end;
int h_start;
int h_end;
int w_start;
int w_end;
};
#endif // PSAMASK_UTILS_HPP_
@@ -9,136 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <algorithm>
#include "psamask_utils.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define COMPUTE_COUNT_ALIGN 64
void KernelPsamaskForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *x, void *y, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition, const int batch,
const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int x_c, const int y_c, const int half_h_mask,
const int half_w_mask, const int n_per_core, const int h_per_core,
const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
const int limit_h_seg, const int limit_w_seg);
void KernelPsamaskBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *dy, void *dx, const PsamaskType psa_type,
const DimPartitionType core_partition,
const DimPartitionType cluster_partition, const int batch,
const int h_feature, const int w_feature, const int h_mask,
const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
const int half_w_mask, const int n_per_core, const int h_per_core,
const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
const int limit_h_seg, const int limit_w_seg);
namespace {
void policyFunc(cnrtDim3_t *k_dim_ptr, cnrtFunctionType_t *f_type_ptr,
PartitionSeg *partition_ptr, const int n, const int h_feature) {
unsigned int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
unsigned int use_cluster_num = cluster_num;
unsigned int use_core_num = core_dim;
if (n >= cluster_num || n >= h_feature) {
partition_ptr->cluster_partition = PARTITION_N;
partition_ptr->n_per_cluster = (n + cluster_num - 1) / cluster_num;
partition_ptr->h_per_cluster = h_feature;
use_cluster_num =
(n + partition_ptr->n_per_cluster - 1) / partition_ptr->n_per_cluster;
} else {
partition_ptr->cluster_partition = PARTITION_H;
partition_ptr->h_per_cluster = (h_feature + cluster_num - 1) / cluster_num;
partition_ptr->n_per_cluster = n;
use_cluster_num = (h_feature + partition_ptr->h_per_cluster - 1) /
partition_ptr->h_per_cluster;
}
if (partition_ptr->n_per_cluster >= core_dim ||
partition_ptr->n_per_cluster >= partition_ptr->h_per_cluster) {
partition_ptr->core_partition = PARTITION_N;
partition_ptr->n_per_core =
(partition_ptr->n_per_cluster + core_dim - 1) / core_dim;
partition_ptr->h_per_core = partition_ptr->h_per_cluster;
use_core_num =
(partition_ptr->n_per_cluster + partition_ptr->n_per_core - 1) /
partition_ptr->n_per_core;
} else {
partition_ptr->core_partition = PARTITION_H;
partition_ptr->h_per_core =
(partition_ptr->h_per_cluster + core_dim - 1) / core_dim;
partition_ptr->n_per_core = partition_ptr->n_per_cluster;
use_core_num =
(partition_ptr->h_per_cluster + partition_ptr->h_per_core - 1) /
partition_ptr->h_per_core;
}
*k_dim_ptr = {core_dim, use_cluster_num, 1};
}
} // namespace
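// A worked example of the partition logic above, with hypothetical device
// attributes (illustration only): cluster_num = 4, core_dim = 4, and inputs
// n = 6, h_feature = 10.
//   n >= cluster_num, so cluster_partition = PARTITION_N with
//   n_per_cluster = ceil(6 / 4) = 2, h_per_cluster = 10, and
//   use_cluster_num = ceil(6 / 2) = 3.
//   n_per_cluster (2) is smaller than both core_dim (4) and h_per_cluster
//   (10), so core_partition = PARTITION_H with h_per_core = ceil(10 / 4) = 3
//   and n_per_core = 2.
//   The resulting launch dimension is k_dim = {4, 3, 1}.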
bool findLimit(const int shape_core_n, const int shape_core_h,
const int shape_core_w, const int shape_core_ci,
const int shape_core_co, int *limit_n_seg_ptr,
int *limit_h_seg_ptr, int *limit_w_seg_ptr, const int psa_type) {
const bool need_temp = psa_type == 1;
const int input_bytes = sizeof(float);
int limit_n_seg = shape_core_n;
int limit_h_seg = shape_core_h;
int limit_w_seg = shape_core_w;
const int max_nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
const int align_base_128 = NFU_ALIGN_SIZE / input_bytes;
const int align_base_64 = COMPUTE_COUNT_ALIGN / input_bytes;
const int align_co = CEIL_ALIGN(shape_core_co, align_base_64);
const int align_w = CEIL_ALIGN(shape_core_w, align_base_64);
const int align_hw = CEIL_ALIGN(shape_core_h * shape_core_w, align_base_64);
const int max_num = max_nram_size / input_bytes;
int n_limit =
max_num /
(CEIL_ALIGN(shape_core_h * shape_core_w * shape_core_ci, align_base_128) +
align_hw * align_co * (1 + need_temp));
if (n_limit > 0) {
n_limit = std::min(n_limit, shape_core_n);
limit_n_seg = n_limit;
} else {
int h_limit =
max_num / (CEIL_ALIGN(shape_core_w * shape_core_ci, align_base_128) +
align_w * align_co * (1 + need_temp));
if (h_limit > 0) {
h_limit = std::min(h_limit, shape_core_h);
limit_h_seg = h_limit;
limit_n_seg = 1;
} else {
int w_limit =
max_num / (CEIL_ALIGN(shape_core_ci, align_base_128) +
CEIL_ALIGN(align_co, align_base_128) * (1 + need_temp));
if (w_limit > 0 && w_limit >= (COMPUTE_COUNT_ALIGN / input_bytes)) {
w_limit = std::min(w_limit, shape_core_w);
w_limit = w_limit / (COMPUTE_COUNT_ALIGN / input_bytes) *
(COMPUTE_COUNT_ALIGN / input_bytes);
limit_w_seg = w_limit;
limit_h_seg = 1;
limit_n_seg = 1;
} else {
CNLOG(INFO) << "The input channel size is too large.";
return false;
}
}
}
*limit_n_seg_ptr = limit_n_seg;
*limit_h_seg_ptr = limit_h_seg;
*limit_w_seg_ptr = limit_w_seg;
return true;
}
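// Note on the fallback order above: findLimit first tries to fit whole
// batches (n) of the per-core workload into NRAM; if even one batch does not
// fit, it fixes limit_n_seg = 1 and tries whole feature rows (h); if one row
// does not fit, it also fixes limit_h_seg = 1 and splits along w, rounding
// the w segment down to a multiple of COMPUTE_COUNT_ALIGN / sizeof(float)
// elements. Only when a single aligned w segment still does not fit does it
// log a message and return false.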
#include "mlu_common_helper.h"
void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
Tensor y, const int num_,
@@ -146,39 +17,7 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
const int h_mask, const int w_mask,
const int half_h_mask,
const int half_w_mask) {
// params check
TORCH_CHECK(x.scalar_type() == at::kFloat, "x type should be Float, got ",
x.scalar_type());
TORCH_CHECK(y.scalar_type() == x.scalar_type(),
"y should have the same type as x");
TORCH_CHECK(x.dim() == 4, "x should be a 4d tensor, got ", x.dim(), "D");
TORCH_CHECK(y.dim() == 4, "y should be a 4d tensor, got ", y.dim(), "D");
int x_c = x.size(1);
int y_c = y.size(1);
TORCH_CHECK(h_mask * w_mask == x_c,
"channel of x should be the same as h_mask * w_mask");
TORCH_CHECK(h_feature * w_feature == y_c,
"channel of y should be the same as h_feature * w_feature");
TORCH_CHECK(psa_type == 0 || psa_type == 1,
"psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
if (x.numel() == 0) {
CNLOG(INFO) << "skip zero-element tensor";
return;
}
cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
cnrtDim3_t k_dim;
PartitionSeg partition_info;
policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
int n_limit_seg, h_limit_seg, w_limit_seg;
bool ret =
findLimit(partition_info.n_per_core, partition_info.h_per_core, w_feature,
x_c, y_c, &n_limit_seg, &h_limit_seg, &w_limit_seg, psa_type);
if (!ret) {
return;
}
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(x.dim());
@@ -186,22 +25,18 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
at::Tensor y_tmp =
at::empty({num_, y_c, h_feature, w_feature}, x.options(), memory_format);
// get compute queue
auto queue = torch_mlu::getCurQueue();
MluOpTensorDescriptor x_desc, y_desc;
x_desc.set_with_layout(x_tensor, MLUOP_LAYOUT_NHWC);
y_desc.set_with_layout(y_tmp, MLUOP_LAYOUT_NHWC);
// get ptr of tensors
auto handle = mluOpGetCurrentHandle();
auto x_impl = torch_mlu::getMluTensorImpl(x_tensor);
auto x_ptr = x_impl->cnnlMalloc();
auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
auto y_ptr = y_impl->cnnlMalloc();
KernelPsamaskForward(
k_dim, k_type, queue, x_ptr, y_ptr, (PsamaskType)psa_type,
partition_info.core_partition, partition_info.cluster_partition, num_,
h_feature, w_feature, h_mask, w_mask, x_c, y_c, half_h_mask, half_w_mask,
partition_info.n_per_core, partition_info.h_per_core,
partition_info.n_per_cluster, partition_info.h_per_cluster, n_limit_seg,
h_limit_seg, w_limit_seg);
mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
y_desc.desc(), y_ptr);
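// y_tmp holds the NHWC result computed by mlu-ops; copy it back so that y
// keeps its original memory format.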
y.copy_(y_tmp);
}
@@ -212,39 +47,7 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
const int h_mask, const int w_mask,
const int half_h_mask,
const int half_w_mask) {
// params check
TORCH_CHECK(dy.scalar_type() == at::kFloat, "dy type should be Float, got ",
dy.scalar_type());
TORCH_CHECK(dx.scalar_type() == dy.scalar_type(),
"dx should have the same type as dy");
TORCH_CHECK(dy.dim() == 4, "dy should be a 4d tensor, got ", dy.dim(), "D");
TORCH_CHECK(dx.dim() == 4, "dx should be a 4d tensor, got ", dx.dim(), "D");
int dy_c = dy.size(1);
int dx_c = dx.size(1);
TORCH_CHECK(h_feature * w_feature == dy_c,
"channel of dy should be the same as h_feature * w_feature");
TORCH_CHECK(h_mask * w_mask == dx_c,
"channel of dx should be the same as h_mask * w_mask");
TORCH_CHECK(psa_type == 0 || psa_type == 1,
"psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
if (dx.numel() == 0) {
CNLOG(INFO) << "skip zero-element tensor";
return;
}
cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
cnrtDim3_t k_dim;
PartitionSeg partition_info;
policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
int n_limit_seg, h_limit_seg, w_limit_seg;
bool ret =
findLimit(partition_info.n_per_core, partition_info.h_per_core, w_feature,
dx_c, dy_c, &n_limit_seg, &h_limit_seg, &w_limit_seg, psa_type);
if (!ret) {
return;
}
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(dy.dim());
@@ -252,8 +55,11 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
at::Tensor dx_tmp = at::empty({num_, dx_c, h_feature, w_feature},
dy.options(), memory_format);
// get compute queue
auto queue = torch_mlu::getCurQueue();
MluOpTensorDescriptor dy_desc, dx_tmp_desc;
dy_desc.set_with_layout(dy_tensor, MLUOP_LAYOUT_NHWC);
dx_tmp_desc.set_with_layout(dx_tmp, MLUOP_LAYOUT_NHWC);
auto handle = mluOpGetCurrentHandle();
// get ptr of tensors
auto dx_impl = torch_mlu::getMluTensorImpl(dx_tmp);
@@ -261,13 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
auto dy_ptr = dy_impl->cnnlMalloc();
KernelPsamaskBackward(
k_dim, k_type, queue, dy_ptr, dx_ptr, (PsamaskType)psa_type,
partition_info.core_partition, partition_info.cluster_partition, num_,
h_feature, w_feature, h_mask, w_mask, dx_c, dy_c, half_h_mask,
half_w_mask, partition_info.n_per_core, partition_info.h_per_core,
partition_info.n_per_cluster, partition_info.h_per_cluster, n_limit_seg,
h_limit_seg, w_limit_seg);
mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
dx_tmp_desc.desc(), dx_ptr);
dx.copy_(dx_tmp);
}