Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
abcac0da
Commit
abcac0da
authored
Jun 04, 2025
by
crapromer
Browse files
issue/238 - format rearrange_kernel.h
parent
8b60242d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
20 deletions
+20
-20
src/infiniop/ops/rearrange/maca/rearrange_kernel.h
src/infiniop/ops/rearrange/maca/rearrange_kernel.h
+20
-20
No files found.
src/infiniop/ops/rearrange/maca/rearrange_kernel.h
View file @
abcac0da
...
...
@@ -37,26 +37,26 @@ struct Constraint {
const size_t block_dim, \
const size_t block_len_total, \
const ArrayStruct<block_array_size, ARRAY_TYPE_SIZE> block_len, \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> src_block_stride,
/* 字节单位的步长 */
\
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> dst_block_stride,
/* 字节单位的步长 */
\
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> src_block_stride,
/* 字节单位的步长 */
\
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> dst_block_stride,
/* 字节单位的步长 */
\
const ArrayStruct<grid_array_size, ARRAY_TYPE_SIZE> grid_len, \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> src_grid_stride,
/* 字节单位的步长 */
\
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> dst_grid_stride
/* 字节单位的步长 */
\
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> src_grid_stride,
/* 字节单位的步长 */
\
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> dst_grid_stride
/* 字节单位的步长 */
\
IF_CONSTRAINT_##constraint_num) { \
size_t remaining = threadIdx.x; \
if (remaining >= block_len_total) { \
return; \
} \
\
/* 声明共享内存 */
\
/* 声明共享内存 */
\
__shared__ ptrdiff_t shared_src_offset; \
__shared__ ptrdiff_t shared_dst_offset; \
\
if (constraint_num > 0) { \
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
\
if (threadIdx.x == 0) {
/* 只让0号线程计算 */
\
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */
\
if (threadIdx.x == 0) {
/* 只让0号线程计算 */
\
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */
\
ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
...
...
@@ -78,7 +78,7 @@ struct Constraint {
} \
} \
\
/* 将结果存入共享内存 */
\
/* 将结果存入共享内存 */
\
shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \
for (ssize_t j = 0; j < constraint_num; j++) { \
...
...
@@ -86,10 +86,10 @@ struct Constraint {
} \
} \
\
/* 确保所有线程都能看到共享内存中的值 */
\
/* 确保所有线程都能看到共享内存中的值 */
\
__syncthreads(); \
\
/* 所有线程直接使用计算好的偏移值 */
\
/* 所有线程直接使用计算好的偏移值 */
\
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
...
...
@@ -100,7 +100,7 @@ struct Constraint {
for (ssize_t i = block_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \
/* 计算偏移量 */
\
/* 计算偏移量 */
\
src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
if (constraint_num > 0) { \
...
...
@@ -124,12 +124,12 @@ struct Constraint {
} \
} \
\
/* 执行数据拷贝,注意offset已经是字节偏移 */
\
/* 执行数据拷贝,注意offset已经是字节偏移 */
\
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
\
} else { \
if (threadIdx.x == 0) {
/* 只让0号线程计算 */
\
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */
\
if (threadIdx.x == 0) {
/* 只让0号线程计算 */
\
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */
\
ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \
size_t remaining = blockIdx.x; \
...
...
@@ -141,22 +141,22 @@ struct Constraint {
dst_offset += idx * dst_grid_stride.a[i]; \
} \
\
/* 将结果存入共享内存 */
\
/* 将结果存入共享内存 */
\
shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \
} \
\
/* 确保所有线程都能看到共享内存中的值 */
\
/* 确保所有线程都能看到共享内存中的值 */
\
__syncthreads(); \
\
/* 所有线程直接使用计算好的偏移值 */
\
/* 所有线程直接使用计算好的偏移值 */
\
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
\
for (ssize_t i = block_array_size - 1; i > 0; i--) { \
size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \
/* 计算偏移量 */
\
/* 计算偏移量 */
\
src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
} \
...
...
@@ -164,7 +164,7 @@ struct Constraint {
src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \
\
/* 执行数据拷贝,注意offset已经是字节偏移 */
\
/* 执行数据拷贝,注意offset已经是字节偏移 */
\
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
} \
}
...
...
@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams ¶ms) {
return
utils
::
Result
<
void
*>
(
kernel_func
);
}
#endif // __REARRANGE_
CUD
A_KERNEL_H__
#endif // __REARRANGE_
MAC
A_KERNEL_H__
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment