Commit b93d2e1b authored by Chao Liu's avatar Chao Liu
Browse files

fix batch gemm asm bug

parent 46a0aec1
...@@ -201,7 +201,7 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -201,7 +201,7 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
if(offset == 0) if(offset == 0)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:0 \n \ ds_read_b128 %0, %1 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
...@@ -350,6 +350,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -350,6 +350,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 2432)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:2432 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 2560) else if(offset == 2560)
{ {
asm volatile("\n \ asm volatile("\n \
...@@ -358,6 +366,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -358,6 +366,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 2688)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:2688 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 2816) else if(offset == 2816)
{ {
asm volatile("\n \ asm volatile("\n \
...@@ -366,6 +382,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -366,6 +382,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 2944)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:2944 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 3072) else if(offset == 3072)
{ {
asm volatile("\n \ asm volatile("\n \
...@@ -374,6 +398,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -374,6 +398,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 3200)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:3200 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 3328) else if(offset == 3328)
{ {
asm volatile("\n \ asm volatile("\n \
...@@ -382,6 +414,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -382,6 +414,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 3456)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:3456 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 3584) else if(offset == 3584)
{ {
asm volatile("\n \ asm volatile("\n \
...@@ -390,6 +430,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -390,6 +430,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 3712)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:3712 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 3840) else if(offset == 3840)
{ {
asm volatile("\n \ asm volatile("\n \
...@@ -398,6 +446,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in ...@@ -398,6 +446,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds))); : "v"(__to_local(lds)));
} }
else if(offset == 3968)
{
asm volatile("\n \
ds_read_b128 %0, %1 offset:3968 \n \
"
: "=v"(r)
: "v"(__to_local(lds)));
}
else if(offset == 4096) else if(offset == 4096)
{ {
asm volatile("\n \ asm volatile("\n \
......
...@@ -293,8 +293,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -293,8 +293,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow(); // A is transposed constexpr index_t K = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
...@@ -344,24 +342,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -344,24 +342,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]); reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]); reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] = reg_b[1] = *reinterpret_cast<const Float4*>(
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]); &p_b_block[b_block_mtx.Get1dIndex(0, NPerLevel1Cluster) + mMyThreadOffsetB]);
reg_a[1] = reg_a[1] = *reinterpret_cast<const Float4*>(
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]); &p_a_block[a_block_mtx.Get1dIndex(0, MPerLevel1Cluster) + mMyThreadOffsetA]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll #pragma unroll
for(index_t k = 1; k < K; ++k) for(index_t k = 1; k < K; ++k)
{ {
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]); reg_a[0] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.Get1dIndex(k, 0) + mMyThreadOffsetA]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]); reg_b[0] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.Get1dIndex(k, 0) + mMyThreadOffsetB]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>( reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]); &p_b_block[b_block_mtx.Get1dIndex(k, NPerLevel1Cluster) + mMyThreadOffsetB]);
reg_a[1] = *reinterpret_cast<const Float4*>( reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]); &p_a_block[a_block_mtx.Get1dIndex(k, MPerLevel1Cluster) + mMyThreadOffsetA]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
} }
...@@ -430,10 +430,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -430,10 +430,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA); void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB); void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
constexpr index_t a_lds_row_stride = sizeof(Float) * M; constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
constexpr index_t b_lds_row_stride = sizeof(Float) * N; constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();
constexpr index_t a_lds_cluster_col_stride = sizeof(Float) * MPerLevel1Cluster; constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
constexpr index_t b_lds_cluster_col_stride = sizeof(Float) * NPerLevel1Cluster; constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;
ds_read_b128(reg_a[0], a_lds_loc, 0); ds_read_b128(reg_a[0], a_lds_loc, 0);
ds_read_b128(reg_b[0], b_lds_loc, 0); ds_read_b128(reg_b[0], b_lds_loc, 0);
......
...@@ -213,7 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -213,7 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// set threadwise output tensor to 0 // set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
#if 0 #if 1
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_c_h_w_n_global_desc.Get1dIndex( in_c_h_w_n_global_desc.Get1dIndex(
...@@ -241,7 +241,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -241,7 +241,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
__syncthreads(); __syncthreads();
#if 1
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
#elif 0
blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread);
#elif 1
blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread);
#endif
__syncthreads(); __syncthreads();
} }
...@@ -277,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -277,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
#elif 0 #elif 0
blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread); blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread);
#elif 0 #elif 1
blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread); blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread);
#endif #endif
......
...@@ -293,8 +293,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -293,8 +293,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); #if 1
blockwise_batch_gemm.Run
#elif 0
blockwise_batch_gemm.Run_asm
#else
blockwise_batch_gemm.Run_asm_v2
#endif
(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
...@@ -321,8 +328,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -321,8 +328,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_clipboard);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); #if 1
blockwise_batch_gemm.Run
#elif 0
blockwise_batch_gemm.Run_asm
#else
blockwise_batch_gemm.Run_asm_v2
#endif
(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
...@@ -333,8 +347,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -333,8 +347,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space, #if 1
blockwise_batch_gemm.Run
#elif 0
blockwise_batch_gemm.Run_asm
#else
blockwise_batch_gemm.Run_asm_v2
#endif
(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space, p_in_block_double + in_block_space,
p_out_thread); p_out_thread);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment