Commit b93d2e1b authored by Chao Liu

fix batch gemm asm bug

Add the missing ds_read_b128 branches for LDS offsets 2432 through 3968 (one per
128-byte step), drop the redundant offset:0 on the zero-offset read, index the
A/B blocks through the matrix descriptors' Get1dIndex() instead of hand-computed
k * M / k * N offsets, and make the Run / Run_asm / Run_asm_v2 choice switchable
in the gridwise convolution kernels.

parent 46a0aec1
@@ -201,7 +201,7 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
     if(offset == 0)
     {
         asm volatile("\n \
-            ds_read_b128 %0, %1 offset:0 \n \
+            ds_read_b128 %0, %1 \n \
             "
             : "=v"(r)
             : "v"(__to_local(lds)));
@@ -350,6 +350,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 2432)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:2432 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 2560)
     {
         asm volatile("\n \
@@ -358,6 +366,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 2688)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:2688 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 2816)
     {
         asm volatile("\n \
@@ -366,6 +382,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 2944)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:2944 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 3072)
     {
         asm volatile("\n \
@@ -374,6 +398,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 3200)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:3200 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 3328)
     {
         asm volatile("\n \
@@ -382,6 +414,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 3456)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:3456 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 3584)
     {
         asm volatile("\n \
@@ -390,6 +430,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 3712)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:3712 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 3840)
     {
         asm volatile("\n \
@@ -398,6 +446,14 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
             : "=v"(r)
             : "v"(__to_local(lds)));
     }
+    else if(offset == 3968)
+    {
+        asm volatile("\n \
+            ds_read_b128 %0, %1 offset:3968 \n \
+            "
+            : "=v"(r)
+            : "v"(__to_local(lds)));
+    }
     else if(offset == 4096)
    {
         asm volatile("\n \
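For context on the first file: `ds_read_b128` takes its LDS offset as a 16-bit immediate encoded in the instruction, so a runtime `offset` has to be dispatched to a branch whose `offset:` literal is a compile-time constant. Before this commit the branches for 2432, 2688, ..., 3968 were missing, so those offsets fell through without issuing a read, which is presumably the asm bug named in the title. A minimal sketch of the dispatch pattern (reusing the repo's `vector_type`, `__to_local`, and `index_t` helpers; not the exact code):

```cpp
// Minimal sketch of the offset-dispatch pattern, assuming HIP inline asm.
// ds_read_b128 encodes "offset:" as a 16-bit immediate, so each supported
// runtime offset needs its own branch with a literal offset. If a branch is
// missing, no ds_read is issued and the destination registers keep stale data.
__device__ void ds_read_b128_sketch(vector_type<float, 4>::MemoryType& r, void* lds, index_t offset)
{
    if(offset == 0)
    {
        asm volatile("ds_read_b128 %0, %1" : "=v"(r) : "v"(__to_local(lds)));
    }
    else if(offset == 2432)
    {
        asm volatile("ds_read_b128 %0, %1 offset:2432" : "=v"(r) : "v"(__to_local(lds)));
    }
    // ... one branch per 128-byte step, up to the largest offset the kernel uses ...
}
```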
@@ -293,8 +293,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr index_t M = a_block_mtx.NCol();
-        constexpr index_t N = b_block_mtx.NCol();
         constexpr index_t K = a_block_mtx.NRow(); // A is transposed

         constexpr index_t MPerThread = c_thread_mtx.NRow();
@@ -344,24 +342,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
         reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
-        reg_b[1] =
-            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
-        reg_a[1] =
-            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
+        reg_b[1] = *reinterpret_cast<const Float4*>(
+            &p_b_block[b_block_mtx.Get1dIndex(0, NPerLevel1Cluster) + mMyThreadOffsetB]);
+        reg_a[1] = *reinterpret_cast<const Float4*>(
+            &p_a_block[a_block_mtx.Get1dIndex(0, MPerLevel1Cluster) + mMyThreadOffsetA]);

         outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
         outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

 #pragma unroll
         for(index_t k = 1; k < K; ++k)
         {
-            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
+            reg_a[0] = *reinterpret_cast<const Float4*>(
+                &p_a_block[a_block_mtx.Get1dIndex(k, 0) + mMyThreadOffsetA]);
             outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
+            reg_b[0] = *reinterpret_cast<const Float4*>(
+                &p_b_block[b_block_mtx.Get1dIndex(k, 0) + mMyThreadOffsetB]);
             outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
             reg_b[1] = *reinterpret_cast<const Float4*>(
-                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
+                &p_b_block[b_block_mtx.Get1dIndex(k, NPerLevel1Cluster) + mMyThreadOffsetB]);
             reg_a[1] = *reinterpret_cast<const Float4*>(
-                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
+                &p_a_block[a_block_mtx.Get1dIndex(k, MPerLevel1Cluster) + mMyThreadOffsetA]);
             outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
             outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
         }
@@ -430,10 +430,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
         void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);

-        constexpr index_t a_lds_row_stride = sizeof(Float) * M;
-        constexpr index_t b_lds_row_stride = sizeof(Float) * N;
-        constexpr index_t a_lds_cluster_col_stride = sizeof(Float) * MPerLevel1Cluster;
-        constexpr index_t b_lds_cluster_col_stride = sizeof(Float) * NPerLevel1Cluster;
+        constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
+        constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();
+        constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
+        constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;

         ds_read_b128(reg_a[0], a_lds_loc, 0);
         ds_read_b128(reg_b[0], b_lds_loc, 0);
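The second file replaces hand-written offsets such as `k * M` and `k * N` with the block matrix descriptors' `Get1dIndex`. Below is a self-contained sketch of why that is safer, assuming a row-major descriptor where `Get1dIndex(row, col) = row * RowStride() + col`; `MatrixDesc` here is a stand-in for illustration, not the repo's descriptor type:

```cpp
#include <cstddef>

// Stand-in row-major matrix descriptor (illustrative, not the repo's API).
template <std::size_t NRow_, std::size_t NCol_, std::size_t RowStride_ = NCol_>
struct MatrixDesc
{
    static constexpr std::size_t NCol() { return NCol_; }
    static constexpr std::size_t RowStride() { return RowStride_; }
    static constexpr std::size_t Get1dIndex(std::size_t row, std::size_t col)
    {
        return row * RowStride() + col;
    }
};

// "k * M" only equals Get1dIndex(k, 0) when the row stride equals the column
// count; with a padded LDS layout the hand-written form silently reads the
// wrong elements, while the descriptor stays correct.
static_assert(MatrixDesc<8, 128>::Get1dIndex(2, 0) == 2 * 128, "unpadded: same as k * M");
static_assert(MatrixDesc<8, 128, 132>::Get1dIndex(2, 0) == 2 * 132, "padded: k * M would be wrong");

int main() {}
```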
@@ -213,7 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // set threadwise output tensor to 0
         threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

-#if 0
+#if 1
         const Float* p_in_global_block_offset =
             p_in_global +
             in_c_h_w_n_global_desc.Get1dIndex(
@@ -241,7 +241,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
             __syncthreads();

+#if 1
             blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
+#elif 0
+            blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread);
+#elif 1
+            blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread);
+#endif

             __syncthreads();
         }
@@ -277,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
             blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
 #elif 0
             blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread);
-#elif 0
+#elif 1
             blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread);
 #endif
@@ -293,8 +293,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                         p_wei_register_clipboard);

-            // LDS double buffer: GEMM on current data
-            blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
+            // LDS double buffer: GEMM on current data
+#if 1
+            blockwise_batch_gemm.Run
+#elif 0
+            blockwise_batch_gemm.Run_asm
+#else
+            blockwise_batch_gemm.Run_asm_v2
+#endif
+                (p_wei_block_now, p_in_block_now, p_out_thread);

             // LDS double buffer: store next data to LDS
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
@@ -321,8 +328,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                 blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                             p_wei_register_clipboard);

-                // LDS double buffer: GEMM on current data
-                blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+                // LDS double buffer: GEMM on current data
+#if 1
+                blockwise_batch_gemm.Run
+#elif 0
+                blockwise_batch_gemm.Run_asm
+#else
+                blockwise_batch_gemm.Run_asm_v2
+#endif
+                    (p_wei_block_double, p_in_block_double, p_out_thread);

                 // LDS double buffer: store next data to LDS
                 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
@@ -333,8 +347,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                 // odd iteration
                 __syncthreads();

-                // LDS double buffer: GEMM on current data
-                blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space,
+                // LDS double buffer: GEMM on current data
+#if 1
+                blockwise_batch_gemm.Run
+#elif 0
+                blockwise_batch_gemm.Run_asm
+#else
+                blockwise_batch_gemm.Run_asm_v2
+#endif
+                    (p_wei_block_double + wei_block_space,
                      p_in_block_double + in_block_space,
                      p_out_thread);
             }
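One note on the idiom introduced in the double-buffer kernel: the `#if`/`#elif`/`#else` chain selects which member function is *named*, and the shared argument list follows the `#endif`, so exactly one ordinary call survives preprocessing. A compilable sketch with stand-in names (`Gemm` and its methods are illustrative, not the repo's API):

```cpp
#include <cstdio>

// Stand-in for the blockwise GEMM object; only the call pattern matters here.
struct Gemm
{
    void Run(const float*, const float*, float*) { std::puts("Run"); }
    void Run_asm(const float*, const float*, float*) { std::puts("Run_asm"); }
    void Run_asm_v2(const float*, const float*, float*) { std::puts("Run_asm_v2"); }
};

int main()
{
    Gemm gemm;
    float a[1] = {}, b[1] = {}, c[1] = {};

    // The #if chain picks the callee; the argument list after #endif is shared,
    // so after preprocessing this is a single ordinary member-function call.
#if 1
    gemm.Run
#elif 0
    gemm.Run_asm
#else
    gemm.Run_asm_v2
#endif
        (a, b, c);
}
```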