Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
513a8502
Unverified
Commit
513a8502
authored
Feb 11, 2026
by
thatPepe
Committed by
GitHub
Feb 11, 2026
Browse files
Merge pull request #1010 from InfiniTensor/issue/899
issue/899 - fix: fix causal_softmax and rearrange bug
parents
c312f175
e4bce369
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
114 additions
and
244 deletions
+114
-244
src/infiniop/ops/causal_softmax/moore/causal_softmax_kernel.h
...infiniop/ops/causal_softmax/moore/causal_softmax_kernel.h
+1
-1
src/infiniop/ops/rearrange/moore/rearrange_kernel.h
src/infiniop/ops/rearrange/moore/rearrange_kernel.h
+16
-50
src/infiniop/ops/rearrange/moore/rearrange_moore.mu
src/infiniop/ops/rearrange/moore/rearrange_moore.mu
+97
-193
No files found.
src/infiniop/ops/causal_softmax/moore/causal_softmax_kernel.h
View file @
513a8502
...
@@ -28,7 +28,7 @@ __device__ void causalSoftmaxKernel(
...
@@ -28,7 +28,7 @@ __device__ void causalSoftmaxKernel(
// 1 | * * * ... * * |
// 1 | * * * ... * * |
// 2 | * * * ... * * * |
// 2 | * * * ... * * * |
// height: 3 col_id->
// height: 3 col_id->
if
(
width
+
blockIdx
.
x
>=
threadIdx
.
x
+
height
)
{
if
(
width
+
blockIdx
.
x
>=
col
+
height
)
{
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
||
std
::
is_same_v
<
Tdata
,
cuda_bfloat16
>
)
{
if
constexpr
(
std
::
is_same_v
<
Tdata
,
half
>
||
std
::
is_same_v
<
Tdata
,
cuda_bfloat16
>
)
{
/*
/*
* MUSA does not support CUDA's native `hexp` function.
* MUSA does not support CUDA's native `hexp` function.
...
...
src/infiniop/ops/rearrange/moore/rearrange_kernel.h
View file @
513a8502
...
@@ -7,7 +7,6 @@
...
@@ -7,7 +7,6 @@
#define ARRAY_TYPE_STRIDE ptrdiff_t
#define ARRAY_TYPE_STRIDE ptrdiff_t
#define ARRAY_TYPE_SIZE size_t
#define ARRAY_TYPE_SIZE size_t
// 与 DEFINE_KERNELS_BY_CONSTRAINT 耦合,需要同时修改
#define MAX_BLOCK_ARRAY_SIZE 5
#define MAX_BLOCK_ARRAY_SIZE 5
#define MAX_GRID_ARRAY_SIZE 5
#define MAX_GRID_ARRAY_SIZE 5
...
@@ -16,7 +15,6 @@ struct ArrayStruct {
...
@@ -16,7 +15,6 @@ struct ArrayStruct {
ArrayType
a
[
ArrSize
];
ArrayType
a
[
ArrSize
];
};
};
// 各个元素分别代表:[grid_idx, block_idx, grid的stride相对于block的倍数,总的len限制]
template
<
typename
ElementType
>
template
<
typename
ElementType
>
struct
Constraint
{
struct
Constraint
{
ElementType
grid_idx
;
ElementType
grid_idx
;
...
@@ -29,9 +27,8 @@ struct Constraint {
...
@@ -29,9 +27,8 @@ struct Constraint {
#define IF_CONSTRAINT_1 , const ArrayStruct<1, Constraint<ARRAY_TYPE_SIZE>> constraints
#define IF_CONSTRAINT_1 , const ArrayStruct<1, Constraint<ARRAY_TYPE_SIZE>> constraints
#define IF_CONSTRAINT_2 , const ArrayStruct<2, Constraint<ARRAY_TYPE_SIZE>> constraints
#define IF_CONSTRAINT_2 , const ArrayStruct<2, Constraint<ARRAY_TYPE_SIZE>> constraints
// 定义宏生成内核函数
#define DEFINE_REARRANGE_KERNEL(Tmem_type, constraint_num, block_array_size, grid_array_size) \
#define DEFINE_REARRANGE_KERNEL(Tmem_type, constraint_num, block_array_size, grid_array_size) \
extern "C"
__global__ void
rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num(
\
extern "C"
INFINIOP_MOORE_KERNEL
rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num( \
void *__restrict__ dst, \
void *__restrict__ dst, \
const void *__restrict__ src, \
const void *__restrict__ src, \
const size_t block_dim, \
const size_t block_dim, \
...
@@ -48,15 +45,14 @@ struct Constraint {
...
@@ -48,15 +45,14 @@ struct Constraint {
return; \
return; \
} \
} \
\
\
/* 声明共享内存 */
\
__shared__ ptrdiff_t shared_src_offset; \
__shared__ ptrdiff_t shared_src_offset; \
__shared__ ptrdiff_t shared_dst_offset; \
__shared__ ptrdiff_t shared_dst_offset; \
\
\
if (constraint_num > 0) { \
if (constraint_num > 0) { \
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
\
\
if (threadIdx.x == 0) {
/* 只让0号线程计算 */
\
if (threadIdx.x == 0) {
\
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */
\
\
ptrdiff_t src_offset = 0; \
ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \
ptrdiff_t dst_offset = 0; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
...
@@ -64,13 +60,13 @@ struct Constraint {
...
@@ -64,13 +60,13 @@ struct Constraint {
size_t remaining \
size_t remaining \
= blockIdx.x; \
= blockIdx.x; \
\
\
for (
ssize
_t i = grid_array_size - 1; i >= 0; i--) {
\
for (
ptrdiff
_t i = grid_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % grid_len.a[i]; \
size_t idx = remaining % grid_len.a[i]; \
remaining /= grid_len.a[i]; \
remaining /= grid_len.a[i]; \
src_offset += idx * src_grid_stride.a[i]; \
src_offset += idx * src_grid_stride.a[i]; \
dst_offset += idx * dst_grid_stride.a[i]; \
dst_offset += idx * dst_grid_stride.a[i]; \
if (constraint_num > 0) { \
if (constraint_num > 0) { \
for (
ssize
_t j = 0; j < constraint_num; j++) {
\
for (
ptrdiff
_t j = 0; j < constraint_num; j++) { \
if (i == constraints.a[j].grid_idx) { \
if (i == constraints.a[j].grid_idx) { \
constraints_grid_idx_multiple[j] = idx * constraints.a[j].grid_div_block; \
constraints_grid_idx_multiple[j] = idx * constraints.a[j].grid_div_block; \
} \
} \
...
@@ -78,33 +74,30 @@ struct Constraint {
...
@@ -78,33 +74,30 @@ struct Constraint {
} \
} \
} \
} \
\
\
/* 将结果存入共享内存 */
\
shared_src_offset = src_offset; \
shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \
shared_dst_offset = dst_offset; \
for (
ssize
_t j = 0; j < constraint_num; j++) {
\
for (
ptrdiff
_t j = 0; j < constraint_num; j++) { \
shared_constraints_grid_idx_multiple[j] = constraints_grid_idx_multiple[j]; \
shared_constraints_grid_idx_multiple[j] = constraints_grid_idx_multiple[j]; \
} \
} \
} \
} \
\
\
/* 确保所有线程都能看到共享内存中的值 */
\
__syncthreads(); \
__syncthreads(); \
\
\
/* 所有线程直接使用计算好的偏移值 */
\
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
for (
ssize
_t j = 0; j < constraint_num; j++) {
\
for (
ptrdiff
_t j = 0; j < constraint_num; j++) { \
constraints_grid_idx_multiple[j] = shared_constraints_grid_idx_multiple[j]; \
constraints_grid_idx_multiple[j] = shared_constraints_grid_idx_multiple[j]; \
} \
} \
\
\
for (
ssize
_t i = block_array_size - 1; i >= 0; i--) {
\
for (
ptrdiff
_t i = block_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % block_len.a[i]; \
size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \
remaining /= block_len.a[i]; \
/* 计算偏移量 */
\
\
src_offset += idx * src_block_stride.a[i]; \
src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
if (constraint_num > 0) { \
if (constraint_num > 0) { \
for (
ssize
_t j = 0; j < constraint_num; j++) {
\
for (
ptrdiff
_t j = 0; j < constraint_num; j++) { \
if (i == constraints.a[j].block_idx) { \
if (i == constraints.a[j].block_idx) { \
if (constraints_grid_idx_multiple[j] + idx >= constraints.a[j].total_len) { \
if (constraints_grid_idx_multiple[j] + idx >= constraints.a[j].total_len) { \
return; \
return; \
...
@@ -116,7 +109,7 @@ struct Constraint {
...
@@ -116,7 +109,7 @@ struct Constraint {
\
\
src_offset += remaining * src_block_stride.a[0]; \
src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \
for (
ssize
_t j = 0; j < constraint_num; j++) {
\
for (
ptrdiff
_t j = 0; j < constraint_num; j++) { \
if (0 == constraints.a[j].block_idx) { \
if (0 == constraints.a[j].block_idx) { \
if (constraints_grid_idx_multiple[j] + remaining >= constraints.a[j].total_len) { \
if (constraints_grid_idx_multiple[j] + remaining >= constraints.a[j].total_len) { \
return; \
return; \
...
@@ -124,39 +117,35 @@ struct Constraint {
...
@@ -124,39 +117,35 @@ struct Constraint {
} \
} \
} \
} \
\
\
/* 执行数据拷贝,注意offset已经是字节偏移 */
\
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
\
\
} else { \
} else { \
if (threadIdx.x == 0) {
/* 只让0号线程计算 */
\
if (threadIdx.x == 0) {
\
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */
\
\
ptrdiff_t src_offset = 0; \
ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \
ptrdiff_t dst_offset = 0; \
size_t remaining = blockIdx.x; \
size_t remaining = blockIdx.x; \
\
\
for (
ssize
_t i = grid_array_size - 1; i >= 0; i--) {
\
for (
ptrdiff
_t i = grid_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % grid_len.a[i]; \
size_t idx = remaining % grid_len.a[i]; \
remaining /= grid_len.a[i]; \
remaining /= grid_len.a[i]; \
src_offset += idx * src_grid_stride.a[i]; \
src_offset += idx * src_grid_stride.a[i]; \
dst_offset += idx * dst_grid_stride.a[i]; \
dst_offset += idx * dst_grid_stride.a[i]; \
} \
} \
\
\
/* 将结果存入共享内存 */
\
shared_src_offset = src_offset; \
shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \
shared_dst_offset = dst_offset; \
} \
} \
\
\
/* 确保所有线程都能看到共享内存中的值 */
\
__syncthreads(); \
__syncthreads(); \
\
\
/* 所有线程直接使用计算好的偏移值 */
\
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
\
\
for (
ssize
_t i = block_array_size - 1; i > 0; i--) {
\
for (
ptrdiff
_t i = block_array_size - 1; i > 0; i--) { \
size_t idx = remaining % block_len.a[i]; \
size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \
remaining /= block_len.a[i]; \
/* 计算偏移量 */
\
\
src_offset += idx * src_block_stride.a[i]; \
src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
} \
} \
...
@@ -164,18 +153,15 @@ struct Constraint {
...
@@ -164,18 +153,15 @@ struct Constraint {
src_offset += remaining * src_block_stride.a[0]; \
src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \
\
\
/* 执行数据拷贝,注意offset已经是字节偏移 */
\
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
} \
} \
}
}
// 定义支持的约束条件数量组合
#define DEFINE_KERNELS_BY_CONSTRAINT(block_array_size, grid_array_size) \
#define DEFINE_KERNELS_BY_CONSTRAINT(block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(0, block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(0, block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(1, block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(1, block_array_size, grid_array_size) \
DEFINE_KERNELS_BY_TYPE(2, block_array_size, grid_array_size)
DEFINE_KERNELS_BY_TYPE(2, block_array_size, grid_array_size)
// 定义支持的类型
#define DEFINE_KERNELS_BY_TYPE(constraint_num, block_array_size, grid_array_size) \
#define DEFINE_KERNELS_BY_TYPE(constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(uchar1, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(uchar1, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(uchar2, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(uchar2, constraint_num, block_array_size, grid_array_size) \
...
@@ -184,8 +170,6 @@ struct Constraint {
...
@@ -184,8 +170,6 @@ struct Constraint {
DEFINE_REARRANGE_KERNEL(float4, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(float4, constraint_num, block_array_size, grid_array_size) \
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)
DEFINE_REARRANGE_KERNEL(double4, constraint_num, block_array_size, grid_array_size)
// 与 MAX_BLOCK_ARRAY_SIZE 和 MAX_GRID_ARRAY_SIZE 耦合,需要同时修改
// 为1-5和1-5的所有组合生成内核
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
1
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
2
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
3
)
DEFINE_KERNELS_BY_CONSTRAINT
(
1
,
3
)
...
@@ -212,7 +196,6 @@ DEFINE_KERNELS_BY_CONSTRAINT(5, 3)
...
@@ -212,7 +196,6 @@ DEFINE_KERNELS_BY_CONSTRAINT(5, 3)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
4
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
5
)
DEFINE_KERNELS_BY_CONSTRAINT
(
5
,
5
)
// 准备参数结构体
struct
RearrangeParams
{
struct
RearrangeParams
{
std
::
vector
<
ARRAY_TYPE_SIZE
>
block_len
;
std
::
vector
<
ARRAY_TYPE_SIZE
>
block_len
;
std
::
vector
<
ARRAY_TYPE_STRIDE
>
src_block_stride
;
std
::
vector
<
ARRAY_TYPE_STRIDE
>
src_block_stride
;
...
@@ -234,25 +217,8 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams ¶ms) {
...
@@ -234,25 +217,8 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams ¶ms) {
CHECK_OR_RETURN
(
grid_num
<=
MAX_GRID_ARRAY_SIZE
&&
grid_num
!=
0
,
INFINI_STATUS_BAD_PARAM
);
CHECK_OR_RETURN
(
grid_num
<=
MAX_GRID_ARRAY_SIZE
&&
grid_num
!=
0
,
INFINI_STATUS_BAD_PARAM
);
CHECK_OR_RETURN
(
block_num
<=
MAX_BLOCK_ARRAY_SIZE
&&
block_num
!=
0
,
INFINI_STATUS_BAD_PARAM
);
CHECK_OR_RETURN
(
block_num
<=
MAX_BLOCK_ARRAY_SIZE
&&
block_num
!=
0
,
INFINI_STATUS_BAD_PARAM
);
CHECK_OR_RETURN
(
constraint_num
<=
2
,
INFINI_STATUS_BAD_PARAM
);
CHECK_OR_RETURN
(
constraint_num
<=
2
,
INFINI_STATUS_BAD_PARAM
);
/*
* These variables were originally part of the CUDA implementation for this kernel.
* They have been commented out because they are not currently used in the MUSA kernel logic.
*
* This change resolves "unused variable" warnings during compilation, ensuring a clean build.
* The original declarations are preserved here for for MUSA/CUDA platform alignment.
*/
// auto block_len = params.block_len.data();
// auto src_block_stride = params.src_block_stride.data();
// auto dst_block_stride = params.dst_block_stride.data();
// auto grid_len = params.grid_len.data();
// auto src_grid_stride = params.src_grid_stride.data();
// auto dst_grid_stride = params.dst_grid_stride.data();
// auto constrain = params.constraints.data();
void
*
kernel_func
=
nullptr
;
void
*
kernel_func
=
nullptr
;
#define GET_REARRANGE_KERNEL(Tmem_type, block_array_size, grid_array_size, constraint_num) \
#define GET_REARRANGE_KERNEL(Tmem_type, block_array_size, grid_array_size, constraint_num) \
kernel_func = (void *)rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num;
kernel_func = (void *)rearrange_unit_##Tmem_type##_block_##block_array_size##_grid_##grid_array_size##_constrain_##constraint_num;
...
...
src/infiniop/ops/rearrange/moore/rearrange_moore.mu
View file @
513a8502
...
@@ -28,7 +28,7 @@ infiniStatus_t Descriptor::create(
...
@@ -28,7 +28,7 @@ infiniStatus_t Descriptor::create(
CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(x_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(x_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE);
// 保存临时vector对象
auto x_shape = x_desc->shape();
auto x_shape = x_desc->shape();
auto y_shape = y_desc->shape();
auto y_shape = y_desc->shape();
auto y_strides = y_desc->strides();
auto y_strides = y_desc->strides();
...
@@ -52,14 +52,12 @@ infiniStatus_t Descriptor::create(
...
@@ -52,14 +52,12 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
}
}
// 维度信息结构
struct Dim {
struct Dim {
size_t len;
size_t len;
ARRAY_TYPE_STRIDE src_stride;
ARRAY_TYPE_STRIDE src_stride;
ARRAY_TYPE_STRIDE dst_stride;
ARRAY_TYPE_STRIDE dst_stride;
};
};
// 分割维度结构
struct SplitDim {
struct SplitDim {
size_t choose_idx;
size_t choose_idx;
size_t num_per_block;
size_t num_per_block;
...
@@ -69,28 +67,17 @@ struct SplitDim {
...
@@ -69,28 +67,17 @@ struct SplitDim {
size_t dim_len;
size_t dim_len;
};
};
/**
* 根据给定的元数据准备张量重排参数,该函数主要完成以下工作:
* 1. 根据原始元数据调整单元大小,获取更适合GPU处理的单元大小
* 2. 将维度分配为块(block)维度和网格(grid)维度:
* 该步骤是核心,目标是为每个block分配尽可能多的相对连续的数据进行处理,
* 对无法完整放入块的维度进行分割,并记录分割维度信息,用于防止kernel访问越界,最大化内存访问局部性和计算效率
*/
utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta &original_meta, int max_threads) {
utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta &original_meta, int max_threads) {
RearrangeParams params;
RearrangeParams params;
// 获取更适合GPU处理的单元大小,这里使用2的幂次方
auto meta_result = original_meta.distributeUnit({32, 16, 8, 4, 2, 1});
auto meta_result = original_meta.distributeUnit({32, 16, 8, 4, 2, 1});
CHECK_RESULT(meta_result);
CHECK_RESULT(meta_result);
const utils::RearrangeMeta &meta = meta_result.take();
const utils::RearrangeMeta &meta = meta_result.take();
// 获取维度信息
const size_t ndim = meta.ndim();
const size_t ndim = meta.ndim();
const size_t unit = meta.unit();
const size_t unit = meta.unit();
// 特殊情况:无维度,只需要简单复制
if (ndim == 0) {
if (ndim == 0) {
params.block_dim = 0;
params.block_dim = 0;
params.block_len_total = 1;
params.block_len_total = 1;
...
@@ -104,12 +91,10 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
...
@@ -104,12 +91,10 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
return utils::Result<RearrangeParams>(params);
return utils::Result<RearrangeParams>(params);
}
}
// 从元数据中提取必要的信息
const ptrdiff_t *idx_strides = meta.idx_strides();
const ptrdiff_t *idx_strides = meta.idx_strides();
const ptrdiff_t *dst_strides = meta.dst_strides();
const ptrdiff_t *dst_strides = meta.dst_strides();
const ptrdiff_t *src_strides = meta.src_strides();
const ptrdiff_t *src_strides = meta.src_strides();
// 准备维度信息
std::vector<Dim> dims;
std::vector<Dim> dims;
std::vector<size_t> shape;
std::vector<size_t> shape;
dims.reserve(ndim);
dims.reserve(ndim);
...
@@ -123,153 +108,93 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
...
@@ -123,153 +108,93 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
prev_idx_stride = idx_strides[i];
prev_idx_stride = idx_strides[i];
}
}
// 计算src_strides的降序排序索引,类似于Rust版本中的src_strides_desc_idx
std::vector<bool> block_dim_choose(ndim, false);
std::vector<size_t> src_strides_desc_idx(ndim);
std::vector<SplitDim> split_dims;
std::vector<size_t> dim_order(ndim);
for (size_t i = 0; i < ndim; ++i) {
for (size_t i = 0; i < ndim; ++i) {
src_strides_desc_idx
[i] = i;
dim_order
[i] = i;
}
}
std::sort(src_strides_desc_idx.begin(), src_strides_desc_idx.end(),
std::sort(dim_order.begin(), dim_order.end(),
[&dims](size_t a, size_t b) {
[&dims](size_t a, size_t b) {
return std::abs(dims[a].src_stride)
>
std::abs(dims[b].src_stride);
return std::abs(dims[a].src_stride)
<
std::abs(dims[b].src_stride);
});
});
// 根据最大线程数选择block和grid维度
constexpr size_t MAX_BLOCK_DIM = MAX_BLOCK_ARRAY_SIZE;
const size_t block_size = max_threads;
std::vector<bool> block_dim_choose(ndim, false);
// 初始化计数器
size_t block_elements = 1;
size_t block_elements = 1;
size_t block_src_elements = 1;
size_t chosen_block_dims = 0;
size_t block_dst_elements = 1;
size_t src_choose_idx = ndim;
size_t dst_choose_idx = ndim;
// 用于存储分割维度信息
std::vector<SplitDim> split_dims;
// 维度选择循环
for (size_t i = 0; i < ndim; ++i) {
while (src_choose_idx > 0 && dst_choose_idx > 0) {
size_t dim_idx = dim_order[i];
// 获取当前需要处理的维度索引
size_t dim_len = shape[dim_idx];
size_t src_idx = src_strides_desc_idx[src_choose_idx - 1];
size_t dst_idx = dst_choose_idx - 1;
if (chosen_block_dims < MAX_BLOCK_DIM &&
block_elements * dim_len <= (size_t)max_threads) {
if (src_idx == dst_idx) {
// 源和目标维度相同,可以一起处理
block_dim_choose[dim_idx] = true;
size_t idx = src_idx;
block_elements *= dim_len;
size_t len = shape[idx];
chosen_block_dims++;
continue;
// 检查是否可以将此维度完全添加到block中
}
if (block_elements * len <= block_size) {
// 选择此维度
if (block_elements > 1 && dim_len > 1) {
block_dim_choose[idx] = true;
block_elements *= len;
if (chosen_block_dims + 1 > MAX_BLOCK_DIM) {
block_src_elements *= len;
break;
block_dst_elements *= len;
src_choose_idx--;
dst_choose_idx--;
} else {
// 需要分割此维度
size_t num_per_block = block_size / block_elements;
// 确保num_per_block > 0且len >= num_per_block
if (num_per_block > 0 && len >= num_per_block && num_per_block > 1) {
size_t num_per_grid = (len + num_per_block - 1) / num_per_block; // 向上取整
SplitDim split_dim = {
idx, // choose_idx
num_per_block, // num_per_block
num_per_grid, // num_per_grid
0, // array_struct_idx_block (待更新)
0, // array_struct_idx_grid (待更新)
len // 原始维度长度
};
split_dims.push_back(split_dim);
}
break;
}
}
} else {
// 源和目标维度不同,需要分别处理
size_t num_per_block =
// 计算块比例
std::min(dim_len, (size_t)max_threads / block_elements);
double src_div_dst = static_cast<double>(block_src_elements) / block_dst_elements;
double src_num_per_block = std::sqrt(block_size / (double)block_elements / src_div_dst);
if (num_per_block > 0) {
double dst_num_per_block = src_num_per_block * src_div_dst;
size_t num_per_grid =
(dim_len + num_per_block - 1) / num_per_block;
size_t src_current_dim_len = shape[src_idx];
size_t dst_current_dim_len = shape[dst_idx];
split_dims.push_back({
dim_idx,
if (static_cast<double>(src_current_dim_len) < src_num_per_block) {
num_per_block,
// 源维度可以完全添加到block
num_per_grid,
block_dim_choose[src_idx] = true;
0,
block_elements *= src_current_dim_len;
0,
block_src_elements *= src_current_dim_len;
dim_len
src_choose_idx--;
});
} else if (static_cast<double>(dst_current_dim_len) < dst_num_per_block) {
// 目标维度可以完全添加到block
block_elements *= num_per_block;
block_dim_choose[dst_idx] = true;
chosen_block_dims++;
block_elements *= dst_current_dim_len;
}
block_dst_elements *= dst_current_dim_len;
break;
dst_choose_idx--;
}
} else {
}
// 需要分割源和目标维度
size_t src_num_per_block_int = static_cast<size_t>(std::floor(src_num_per_block));
size_t dst_num_per_block_int = static_cast<size_t>(std::floor(dst_num_per_block));
// 计算网格尺寸
size_t src_num_per_grid = (src_current_dim_len + src_num_per_block_int - 1) / src_num_per_block_int; // 向上取整
size_t dst_num_per_grid = (dst_current_dim_len + dst_num_per_block_int - 1) / dst_num_per_block_int; // 向上取整
// 处理源维度
if (src_num_per_block_int > 1) {
if (src_num_per_grid == 1) {
// 可以完全放入块
block_dim_choose[src_idx] = true;
block_elements *= src_current_dim_len;
block_src_elements *= src_current_dim_len;
src_choose_idx--;
} else {
// 需要分割
SplitDim split_dim = {
src_idx, // choose_idx
src_num_per_block_int, // num_per_block
src_num_per_grid, // num_per_grid
0, // array_struct_idx_block (待更新)
0, // array_struct_idx_grid (待更新)
src_current_dim_len // 原始维度长度
};
split_dims.push_back(split_dim);
}
}
// 处理目标维度
if (dst_num_per_block_int > 1) {
if (dst_num_per_grid == 1) {
// 可以完全放入块
block_dim_choose[dst_idx] = true;
block_elements *= dst_current_dim_len;
block_dst_elements *= dst_current_dim_len;
dst_choose_idx--;
} else {
// 需要分割
SplitDim split_dim = {
dst_idx, // choose_idx
dst_num_per_block_int, // num_per_block
dst_num_per_grid, // num_per_grid
0, // array_struct_idx_block (待更新)
0, // array_struct_idx_grid (待更新)
dst_current_dim_len // 原始维度长度
};
split_dims.push_back(split_dim);
}
}
break;
if (block_elements == 1 && ndim > 0) {
}
size_t dim_idx = dim_order[0];
size_t dim_len = shape[dim_idx];
if (dim_len <= (size_t)max_threads) {
block_dim_choose[dim_idx] = true;
block_elements = dim_len;
} else {
size_t num_per_block = std::min(dim_len, (size_t)max_threads);
size_t num_per_grid = (dim_len + num_per_block - 1) / num_per_block;
SplitDim split_dim = {
dim_idx,
num_per_block,
num_per_grid,
0,
0,
dim_len};
split_dims.push_back(split_dim);
block_elements = num_per_block;
}
}
}
}
// 准备block维度相关参数
size_t block_dim = 0;
size_t block_dim = 0;
size_t block_len_total =
1
;
size_t block_len_total =
block_elements
;
std::vector<ARRAY_TYPE_SIZE> block_len;
std::vector<ARRAY_TYPE_SIZE> block_len;
std::vector<ARRAY_TYPE_STRIDE> src_block_stride;
std::vector<ARRAY_TYPE_STRIDE> src_block_stride;
...
@@ -279,46 +204,40 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
...
@@ -279,46 +204,40 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
std::vector<ARRAY_TYPE_STRIDE> src_grid_stride;
std::vector<ARRAY_TYPE_STRIDE> src_grid_stride;
std::vector<ARRAY_TYPE_STRIDE> dst_grid_stride;
std::vector<ARRAY_TYPE_STRIDE> dst_grid_stride;
// 处理block维度,填充block_len和block_stride
for (size_t i = 0; i < ndim; ++i) {
for (size_t i = 0; i < ndim; ++i) {
if (block_dim_choose[i]) {
if (block_dim_choose[i]) {
block_len.push_back(shape[i]);
block_len.push_back(shape[i]);
src_block_stride.push_back(dims[i].src_stride);
src_block_stride.push_back(dims[i].src_stride);
dst_block_stride.push_back(dims[i].dst_stride);
dst_block_stride.push_back(dims[i].dst_stride);
block_dim += 1;
block_dim += 1;
block_len_total *= shape[i];
}
}
// 处理分割维度的block部分
for (size_t j = 0; j < split_dims.size(); ++j) {
for (size_t j = 0; j < split_dims.size(); ++j) {
if (i == split_dims[j].choose_idx) {
if (i == split_dims[j].choose_idx) {
block_len.push_back(split_dims[j].num_per_block);
block_len.push_back(split_dims[j].num_per_block);
src_block_stride.push_back(dims[i].src_stride);
src_block_stride.push_back(dims[i].src_stride);
dst_block_stride.push_back(dims[i].dst_stride);
dst_block_stride.push_back(dims[i].dst_stride);
split_dims[j].array_struct_idx_block = block_dim;
split_dims[j].array_struct_idx_block =
static_cast<int>(
block_dim
)
;
block_dim += 1;
block_dim += 1;
block_len_total *= split_dims[j].num_per_block;
}
}
}
}
}
}
// 处理grid维度,填充grid_len和grid_stride
for (size_t i = 0; i < ndim; ++i) {
for (size_t i = 0; i < ndim; ++i) {
if (!block_dim_choose[i]) {
if (!block_dim_choose[i]) {
bool is_split = false;
bool is_split = false;
// 检查是否是分割维度
for (size_t j = 0; j < split_dims.size(); ++j) {
for (size_t j = 0; j < split_dims.size(); ++j) {
if (i == split_dims[j].choose_idx) {
if (i == split_dims[j].choose_idx) {
is_split = true;
is_split = true;
grid_len.push_back(split_dims[j].num_per_grid);
grid_len.push_back(split_dims[j].num_per_grid);
src_grid_stride.push_back(dims[i].src_stride * split_dims[j].num_per_block);
src_grid_stride.push_back(dims[i].src_stride * split_dims[j].num_per_block);
dst_grid_stride.push_back(dims[i].dst_stride * split_dims[j].num_per_block);
dst_grid_stride.push_back(dims[i].dst_stride * split_dims[j].num_per_block);
split_dims[j].array_struct_idx_grid = grid_len.size() - 1;
split_dims[j].array_struct_idx_grid = static_cast<int>(grid_len.size() - 1);
break;
}
}
}
}
// 如果不是分割维度,则作为完整的grid维度
if (!is_split) {
if (!is_split) {
grid_len.push_back(shape[i]);
grid_len.push_back(shape[i]);
src_grid_stride.push_back(dims[i].src_stride);
src_grid_stride.push_back(dims[i].src_stride);
...
@@ -327,17 +246,14 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
...
@@ -327,17 +246,14 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
}
}
}
}
// 如果grid_len为空,添加一个默认值
if (grid_len.empty()) {
if (grid_len.empty()) {
grid_len.push_back(1);
grid_len.push_back(1);
src_grid_stride.push_back(0);
src_grid_stride.push_back(0);
dst_grid_stride.push_back(0);
dst_grid_stride.push_back(0);
}
}
// 处理约束条件 - 使用与Rust版本相似的逻辑
std::vector<Constraint<ARRAY_TYPE_SIZE>> constraints;
std::vector<Constraint<ARRAY_TYPE_SIZE>> constraints;
// 限制最多处理2个约束条件
for (size_t i = 0; i < split_dims.size(); ++i) {
for (size_t i = 0; i < split_dims.size(); ++i) {
if (split_dims[i].dim_len % split_dims[i].num_per_block == 0) {
if (split_dims[i].dim_len % split_dims[i].num_per_block == 0) {
continue;
continue;
...
@@ -348,9 +264,12 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
...
@@ -348,9 +264,12 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
constraint.grid_div_block = split_dims[i].num_per_block;
constraint.grid_div_block = split_dims[i].num_per_block;
constraint.total_len = split_dims[i].dim_len;
constraint.total_len = split_dims[i].dim_len;
constraints.push_back(constraint);
constraints.push_back(constraint);
if (constraints.size() >= 2) {
break;
}
}
}
// 设置参数
params.block_dim = block_dim;
params.block_dim = block_dim;
params.block_len_total = block_len_total;
params.block_len_total = block_len_total;
params.block_len = block_len;
params.block_len = block_len;
...
@@ -365,7 +284,6 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
...
@@ -365,7 +284,6 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
return utils::Result<RearrangeParams>(params);
return utils::Result<RearrangeParams>(params);
}
}
// 带约束的内核启动模板函数
template <unsigned int BLOCK_SIZE>
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
infiniStatus_t launchKernel(
void *y,
void *y,
...
@@ -375,30 +293,28 @@ infiniStatus_t launchKernel(
...
@@ -375,30 +293,28 @@ infiniStatus_t launchKernel(
size_t unit_size,
size_t unit_size,
musaStream_t stream) {
musaStream_t stream) {
// 获取内核函数
RearrangeParams params_copy = params;
RearrangeParams params_copy = params; // 创建一个非const副本
auto kernel_func_result = getRearrangeKernel(params_copy);
auto kernel_func_result = getRearrangeKernel(params_copy);
CHECK_RESULT(kernel_func_result);
CHECK_RESULT(kernel_func_result);
auto kernel_func = kernel_func_result.take();
auto kernel_func = kernel_func_result.take();
// 创建非const的临时变量
size_t block_dim = params.block_dim;
size_t block_dim = params.block_dim;
size_t block_len_total = params.block_len_total;
size_t block_len_total = params.block_len_total;
//
计算对齐后的线程块大小(B
lock
S
ize
)以适配 MUSA 架构的Warp特性
//
Calculate aligned thread b
lock
s
ize
to match MUSA architecture's Warp characteristics:
// - MUSA
架构以 32 线程为基本调度单位(1个
Warp
)
// - MUSA
architecture uses 32 threads as the fundamental scheduling unit (1
Warp
).
// -
通过向上取整到最近的 32 的倍数,确保线程块包含完整的
Warp
// -
Round up to the nearest multiple of 32 to ensure the block consists of full
Warp
s.
// - MUSA
似乎不支持非 32 整数倍的计算
// - MUSA
hardware/scheduler typically requires thread counts to be multiples of 32.
size_t aligned_block_size = ((block_len_total + 31) / 32) * 32;
size_t aligned_block_size = ((block_len_total + 31) / 32) * 32;
block_len_total = aligned_block_size;
//
block_len_total = aligned_block_size;
//
确保对齐后的线程块大小不超过硬件/模板限制
//
Ensure the aligned block size does not exceed hardware or template-defined limits.
if (aligned_block_size > BLOCK_SIZE) {
if (aligned_block_size > BLOCK_SIZE) {
aligned_block_size = BLOCK_SIZE;
// 降级到安全值
aligned_block_size = BLOCK_SIZE;
}
}
//
检查向量尺寸是否合理
//
Validate that vector dimensions are sufficient for the specified block dimension.
if (params.block_len.size() < block_dim || params.src_block_stride.size() < block_dim || params.dst_block_stride.size() < block_dim) {
if (params.block_len.size() < block_dim || params.src_block_stride.size() < block_dim || params.dst_block_stride.size() < block_dim) {
return INFINI_STATUS_BAD_PARAM;
return INFINI_STATUS_BAD_PARAM;
}
}
...
@@ -428,18 +344,18 @@ infiniStatus_t launchKernel(
...
@@ -428,18 +344,18 @@ infiniStatus_t launchKernel(
const_cast<void *>(static_cast<const void *>(params.dst_grid_stride.data())),
const_cast<void *>(static_cast<const void *>(params.dst_grid_stride.data())),
const_cast<void *>(static_cast<const void *>(constraints_data))};
const_cast<void *>(static_cast<const void *>(constraints_data))};
//
musaLaunchKernel 的 blockDim 似乎必须满足:
//
The blockDim for musaLaunchKernel must satisfy the following constraints:
// -
是32的整数倍(适配
MUSA
的
Warp
调度机制)
// -
Must be a multiple of 32 (aligned with
MUSA
's
Warp
scheduling mechanism).
// -
不小于实际需要处理的元素数(
block_len_total
)
// -
Must be greater than or equal to the number of elements to process (
block_len_total
).
// -
向上取整,数学等效:ceil(n / 32) * 32
// -
Math equivalent: ceil(n / 32) * 32 (rounding up to the nearest warp).
CHECK_OR_RETURN(musaLaunchKernel(
CHECK_OR_RETURN(musaLaunchKernel(
kernel_func,
kernel_func,
grid_size, aligned_block_size,
static_cast<unsigned int>(
grid_size
)
,
static_cast<unsigned int>(
aligned_block_size
)
,
args, 0, stream)
args, 0, stream)
== musaSuccess,
== musaSuccess,
INFINI_STATUS_INTERNAL_ERROR);
INFINI_STATUS_INTERNAL_ERROR);
//
设备同步,检查内核执行是否出错
//
Synchronize the device to detect potential asynchronous kernel execution errors.
musaError_t err = musaDeviceSynchronize();
musaError_t err = musaDeviceSynchronize();
if (err != musaSuccess) {
if (err != musaSuccess) {
std::cerr << "[ERROR] musaDeviceSynchronize failed: " << err << std::endl;
std::cerr << "[ERROR] musaDeviceSynchronize failed: " << err << std::endl;
...
@@ -456,38 +372,27 @@ infiniStatus_t Descriptor::calculate(
...
@@ -456,38 +372,27 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// 如果没有维度,直接进行内存拷贝
if (_meta.ndim() == 0) {
if (_meta.ndim() == 0) {
auto err = musaMemcpyAsync(y, x, _meta.unit(), musaMemcpyDeviceToDevice, musa_stream);
if (err != musaSuccess) {
return INFINI_STATUS_INTERNAL_ERROR;
}
CHECK_OR_RETURN(musaMemcpyAsync(y, x, _meta.unit(), musaMemcpyDeviceToDevice, musa_stream) == musaSuccess,
CHECK_OR_RETURN(musaMemcpyAsync(y, x, _meta.unit(), musaMemcpyDeviceToDevice, musa_stream) == musaSuccess,
INFINI_STATUS_INTERNAL_ERROR);
INFINI_STATUS_INTERNAL_ERROR);
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS;
}
}
// 获取设备属性
int max_threads = _opaque->internal->maxThreadsPerBlock();
int max_threads = _opaque->internal->maxThreadsPerBlock();
// 准备参数
auto params_result = prepareRearrangeParams(_meta, std::min(MOORE_BLOCK_SIZE_1024, max_threads));
auto params_result = prepareRearrangeParams(_meta, std::min(MOORE_BLOCK_SIZE_1024, max_threads));
CHECK_RESULT(params_result);
CHECK_RESULT(params_result);
auto params = params_result.take();
auto params = params_result.take();
// 计算grid大小
size_t grid_size = 1;
size_t grid_size = 1;
for (size_t i = 0; i < params.grid_len.size(); ++i) {
for (size_t i = 0; i < params.grid_len.size(); ++i) {
grid_size *= params.grid_len[i];
grid_size *= params.grid_len[i];
}
}
// 检查grid大小是否为0
if (grid_size == 0) {
if (grid_size == 0) {
return INFINI_STATUS_BAD_PARAM;
return INFINI_STATUS_BAD_PARAM;
}
}
// 根据设备属性选择合适的内核
infiniStatus_t status = INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
infiniStatus_t status = INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
size_t block_size = params.block_len_total;
size_t block_size = params.block_len_total;
...
@@ -497,7 +402,6 @@ infiniStatus_t Descriptor::calculate(
...
@@ -497,7 +402,6 @@ infiniStatus_t Descriptor::calculate(
} else if (block_size <= MOORE_BLOCK_SIZE_1024) {
} else if (block_size <= MOORE_BLOCK_SIZE_1024) {
status = launchKernel<MOORE_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), musa_stream);
status = launchKernel<MOORE_BLOCK_SIZE_1024>(y, x, grid_size, params, _meta.unit(), musa_stream);
} else {
} else {
std::cerr << "[ERROR] block_size=" << block_size << " exceeds max supported" << std::endl;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment