"examples/vscode:/vscode.git/clone" did not exist on "ce144d6dd05c4c588d9f2970301ace70eee16d5d"
Commit e6230689 authored by Jing Zhang's avatar Jing Zhang
Browse files

resolve conflict

parents 1e37e838 edc89778
...@@ -423,7 +423,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -423,7 +423,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
printf("Elapsed time : %f ms, %f TFlop/s\n", printf("Elapsed time : %f ms, %f TFlop/s\n",
time, time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1024) * 1024 * 1024 * 1024) / (time / 1000)); (std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000))); usleep(std::min(time * 1000, float(10000)));
} }
......
...@@ -211,52 +211,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -211,52 +211,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#pragma unroll #pragma unroll
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{ {
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
// copy B-sub to form B
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// loop over batch // loop over batch
#pragma unroll #pragma unroll
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib) for(index_t ib = 0; ib < BatchPerThread; ++ib)
{ {
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + ib * ThreadMatrixStrideC);
// read next batch of a, b // read next batch of a, b
if(BlockMatrixStrideA != 0) if(BlockMatrixStrideA != 0 or ib == 0)
{ {
#pragma unroll #pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
...@@ -265,7 +225,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -265,7 +225,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
a_block_mtx, a_block_mtx,
p_a_block + p_a_block +
a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA, ib * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx, a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(), a_thread_sub_mtx.GetLengths(),
...@@ -273,7 +233,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -273,7 +233,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
} }
} }
if(BlockMatrixStrideB != 0) if(BlockMatrixStrideB != 0 or ib == 0)
{ {
#pragma unroll #pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
...@@ -282,25 +242,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -282,25 +242,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
b_block_mtx, b_block_mtx,
p_b_block + p_b_block +
b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB, ib * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx, b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(), b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{}); Number<DataPerReadB>{});
} }
} }
}
// do last batch of gemm threadwise_gemm(a_thread_mtx,
threadwise_gemm(a_thread_mtx, True,
True, p_a_thread,
p_a_thread, b_thread_mtx,
b_thread_mtx, False,
False, p_b_thread,
p_b_thread, c_thread_mtx,
c_thread_mtx, False,
False, p_c_thread + ib * ThreadMatrixStrideC);
p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC); }
} }
} }
......
...@@ -276,9 +276,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -276,9 +276,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#elif 1 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x, p_in_block_now + y * Wi + x,
p_out_thread); p_out_thread);
} }
} }
...@@ -320,9 +320,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -320,9 +320,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#elif 1 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_double + y * Wi + x, p_in_block_double + y * Wi + x,
p_out_thread); p_out_thread);
} }
} }
...@@ -345,10 +345,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -345,10 +345,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#elif 1 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_double + wei_block_space + (p_wei_block_double + wei_block_space +
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_double + in_block_space + y * Wi + x, p_in_block_double + in_block_space + y * Wi + x,
p_out_thread); p_out_thread);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment