Commit abcac0da authored by crapromer's avatar crapromer
Browse files

issue/238 - format rearrange_kernel.h

parent 8b60242d
......@@ -37,26 +37,26 @@ struct Constraint {
const size_t block_dim, \
const size_t block_len_total, \
const ArrayStruct<block_array_size, ARRAY_TYPE_SIZE> block_len, \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> src_block_stride, /* 字节单位的步长 */ \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> dst_block_stride, /* 字节单位的步长 */ \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> src_block_stride, /* 字节单位的步长 */ \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> dst_block_stride, /* 字节单位的步长 */ \
const ArrayStruct<grid_array_size, ARRAY_TYPE_SIZE> grid_len, \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> src_grid_stride, /* 字节单位的步长 */ \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> dst_grid_stride /* 字节单位的步长 */ \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> src_grid_stride, /* 字节单位的步长 */ \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> dst_grid_stride /* 字节单位的步长 */ \
IF_CONSTRAINT_##constraint_num) { \
size_t remaining = threadIdx.x; \
if (remaining >= block_len_total) { \
return; \
} \
\
/* 声明共享内存 */ \
/* 声明共享内存 */ \
__shared__ ptrdiff_t shared_src_offset; \
__shared__ ptrdiff_t shared_dst_offset; \
\
if (constraint_num > 0) { \
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
\
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \
ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
......@@ -78,7 +78,7 @@ struct Constraint {
} \
} \
\
/* 将结果存入共享内存 */ \
/* 将结果存入共享内存 */ \
shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \
for (ssize_t j = 0; j < constraint_num; j++) { \
......@@ -86,10 +86,10 @@ struct Constraint {
} \
} \
\
/* 确保所有线程都能看到共享内存中的值 */ \
/* 确保所有线程都能看到共享内存中的值 */ \
__syncthreads(); \
\
/* 所有线程直接使用计算好的偏移值 */ \
/* 所有线程直接使用计算好的偏移值 */ \
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
......@@ -100,7 +100,7 @@ struct Constraint {
for (ssize_t i = block_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \
/* 计算偏移量 */ \
/* 计算偏移量 */ \
src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
if (constraint_num > 0) { \
......@@ -124,12 +124,12 @@ struct Constraint {
} \
} \
\
/* 执行数据拷贝,注意offset已经是字节偏移 */ \
/* 执行数据拷贝,注意offset已经是字节偏移 */ \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
\
} else { \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \
ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \
size_t remaining = blockIdx.x; \
......@@ -141,22 +141,22 @@ struct Constraint {
dst_offset += idx * dst_grid_stride.a[i]; \
} \
\
/* 将结果存入共享内存 */ \
/* 将结果存入共享内存 */ \
shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \
} \
\
/* 确保所有线程都能看到共享内存中的值 */ \
/* 确保所有线程都能看到共享内存中的值 */ \
__syncthreads(); \
\
/* 所有线程直接使用计算好的偏移值 */ \
/* 所有线程直接使用计算好的偏移值 */ \
ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \
\
for (ssize_t i = block_array_size - 1; i > 0; i--) { \
size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \
/* 计算偏移量 */ \
/* 计算偏移量 */ \
src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \
} \
......@@ -164,7 +164,7 @@ struct Constraint {
src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \
\
/* 执行数据拷贝,注意offset已经是字节偏移 */ \
/* 执行数据拷贝,注意offset已经是字节偏移 */ \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
} \
}
......@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
return utils::Result<void *>(kernel_func);
}
#endif // __REARRANGE_CUDA_KERNEL_H__
#endif // __REARRANGE_MACA_KERNEL_H__
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment