Commit abcac0da authored by crapromer
Browse files

issue/238 - format rearrange_kernel.h

parent 8b60242d
...@@ -37,26 +37,26 @@ struct Constraint { ...@@ -37,26 +37,26 @@ struct Constraint {
const size_t block_dim, \ const size_t block_dim, \
const size_t block_len_total, \ const size_t block_len_total, \
const ArrayStruct<block_array_size, ARRAY_TYPE_SIZE> block_len, \ const ArrayStruct<block_array_size, ARRAY_TYPE_SIZE> block_len, \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> src_block_stride, /* 字节单位的步长 */ \ const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> src_block_stride, /* 字节单位的步长 */ \
const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> dst_block_stride, /* 字节单位的步长 */ \ const ArrayStruct<block_array_size, ARRAY_TYPE_STRIDE> dst_block_stride, /* 字节单位的步长 */ \
const ArrayStruct<grid_array_size, ARRAY_TYPE_SIZE> grid_len, \ const ArrayStruct<grid_array_size, ARRAY_TYPE_SIZE> grid_len, \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> src_grid_stride, /* 字节单位的步长 */ \ const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> src_grid_stride, /* 字节单位的步长 */ \
const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> dst_grid_stride /* 字节单位的步长 */ \ const ArrayStruct<grid_array_size, ARRAY_TYPE_STRIDE> dst_grid_stride /* 字节单位的步长 */ \
IF_CONSTRAINT_##constraint_num) { \ IF_CONSTRAINT_##constraint_num) { \
size_t remaining = threadIdx.x; \ size_t remaining = threadIdx.x; \
if (remaining >= block_len_total) { \ if (remaining >= block_len_total) { \
return; \ return; \
} \ } \
\ \
/* 声明共享内存 */ \ /* 声明共享内存 */ \
__shared__ ptrdiff_t shared_src_offset; \ __shared__ ptrdiff_t shared_src_offset; \
__shared__ ptrdiff_t shared_dst_offset; \ __shared__ ptrdiff_t shared_dst_offset; \
\ \
if (constraint_num > 0) { \ if (constraint_num > 0) { \
__shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \ __shared__ ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
\ \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \ if (threadIdx.x == 0) { /* 只让0号线程计算 */ \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \ /* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \
ptrdiff_t src_offset = 0; \ ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \ ptrdiff_t dst_offset = 0; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \ ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
...@@ -78,7 +78,7 @@ struct Constraint { ...@@ -78,7 +78,7 @@ struct Constraint {
} \ } \
} \ } \
\ \
/* 将结果存入共享内存 */ \ /* 将结果存入共享内存 */ \
shared_src_offset = src_offset; \ shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \ shared_dst_offset = dst_offset; \
for (ssize_t j = 0; j < constraint_num; j++) { \ for (ssize_t j = 0; j < constraint_num; j++) { \
...@@ -86,10 +86,10 @@ struct Constraint { ...@@ -86,10 +86,10 @@ struct Constraint {
} \ } \
} \ } \
\ \
/* 确保所有线程都能看到共享内存中的值 */ \ /* 确保所有线程都能看到共享内存中的值 */ \
__syncthreads(); \ __syncthreads(); \
\ \
/* 所有线程直接使用计算好的偏移值 */ \ /* 所有线程直接使用计算好的偏移值 */ \
ptrdiff_t src_offset = shared_src_offset; \ ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \ ptrdiff_t dst_offset = shared_dst_offset; \
ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \ ARRAY_TYPE_SIZE constraints_grid_idx_multiple[constraint_num > 0 ? constraint_num : 1]; \
...@@ -100,7 +100,7 @@ struct Constraint { ...@@ -100,7 +100,7 @@ struct Constraint {
for (ssize_t i = block_array_size - 1; i >= 0; i--) { \ for (ssize_t i = block_array_size - 1; i >= 0; i--) { \
size_t idx = remaining % block_len.a[i]; \ size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \ remaining /= block_len.a[i]; \
/* 计算偏移量 */ \ /* 计算偏移量 */ \
src_offset += idx * src_block_stride.a[i]; \ src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \ dst_offset += idx * dst_block_stride.a[i]; \
if (constraint_num > 0) { \ if (constraint_num > 0) { \
...@@ -124,12 +124,12 @@ struct Constraint { ...@@ -124,12 +124,12 @@ struct Constraint {
} \ } \
} \ } \
\ \
/* 执行数据拷贝,注意offset已经是字节偏移 */ \ /* 执行数据拷贝,注意offset已经是字节偏移 */ \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \ *reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
\ \
} else { \ } else { \
if (threadIdx.x == 0) { /* 只让0号线程计算 */ \ if (threadIdx.x == 0) { /* 只让0号线程计算 */ \
/* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \ /* 计算当前block处理的数据在src和dst中的基础偏移(bytes) */ \
ptrdiff_t src_offset = 0; \ ptrdiff_t src_offset = 0; \
ptrdiff_t dst_offset = 0; \ ptrdiff_t dst_offset = 0; \
size_t remaining = blockIdx.x; \ size_t remaining = blockIdx.x; \
...@@ -141,22 +141,22 @@ struct Constraint { ...@@ -141,22 +141,22 @@ struct Constraint {
dst_offset += idx * dst_grid_stride.a[i]; \ dst_offset += idx * dst_grid_stride.a[i]; \
} \ } \
\ \
/* 将结果存入共享内存 */ \ /* 将结果存入共享内存 */ \
shared_src_offset = src_offset; \ shared_src_offset = src_offset; \
shared_dst_offset = dst_offset; \ shared_dst_offset = dst_offset; \
} \ } \
\ \
/* 确保所有线程都能看到共享内存中的值 */ \ /* 确保所有线程都能看到共享内存中的值 */ \
__syncthreads(); \ __syncthreads(); \
\ \
/* 所有线程直接使用计算好的偏移值 */ \ /* 所有线程直接使用计算好的偏移值 */ \
ptrdiff_t src_offset = shared_src_offset; \ ptrdiff_t src_offset = shared_src_offset; \
ptrdiff_t dst_offset = shared_dst_offset; \ ptrdiff_t dst_offset = shared_dst_offset; \
\ \
for (ssize_t i = block_array_size - 1; i > 0; i--) { \ for (ssize_t i = block_array_size - 1; i > 0; i--) { \
size_t idx = remaining % block_len.a[i]; \ size_t idx = remaining % block_len.a[i]; \
remaining /= block_len.a[i]; \ remaining /= block_len.a[i]; \
/* 计算偏移量 */ \ /* 计算偏移量 */ \
src_offset += idx * src_block_stride.a[i]; \ src_offset += idx * src_block_stride.a[i]; \
dst_offset += idx * dst_block_stride.a[i]; \ dst_offset += idx * dst_block_stride.a[i]; \
} \ } \
...@@ -164,7 +164,7 @@ struct Constraint { ...@@ -164,7 +164,7 @@ struct Constraint {
src_offset += remaining * src_block_stride.a[0]; \ src_offset += remaining * src_block_stride.a[0]; \
dst_offset += remaining * dst_block_stride.a[0]; \ dst_offset += remaining * dst_block_stride.a[0]; \
\ \
/* 执行数据拷贝,注意offset已经是字节偏移 */ \ /* 执行数据拷贝,注意offset已经是字节偏移 */ \
*reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \ *reinterpret_cast<Tmem_type *>(reinterpret_cast<char *>(dst) + dst_offset) = *reinterpret_cast<const Tmem_type *>(reinterpret_cast<const char *>(src) + src_offset); \
} \ } \
} }
...@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) { ...@@ -328,4 +328,4 @@ utils::Result<void *> getRearrangeKernel(const RearrangeParams &params) {
return utils::Result<void *>(kernel_func); return utils::Result<void *>(kernel_func);
} }
#endif // __REARRANGE_CUDA_KERNEL_H__ #endif // __REARRANGE_MACA_KERNEL_H__
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment