"vscode:/vscode.git/clone" did not exist on "5c19c8c5fcdf64247e408a96eac0590058ed5655"
Commit 96ee9571 authored by Chao Liu's avatar Chao Liu
Browse files

tuned implicit gemm v1 for 3x3 on AMD to 82%. Fixed a bug in 4d tensor blockwise copy.

parent edc89778
...@@ -78,7 +78,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -78,7 +78,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0 #if 0
// for 3x3, 34x34 // for 3x3, 34x34, Pascal
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4; constexpr index_t CPerBlock = 4;
...@@ -111,6 +111,39 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -111,6 +111,39 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 34x34, Vega 20
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 256;
#elif 0 #elif 0
// for 5x5, 36x36 // for 5x5, 36x36
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
...@@ -264,7 +297,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -264,7 +297,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1 #elif 0
// for 3x3, 28x28, v1, Pacal // for 3x3, 28x28, v1, Pacal
constexpr index_t NPerBlock = 32; constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
......
...@@ -409,13 +409,13 @@ int main(int argc, char* argv[]) ...@@ -409,13 +409,13 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 34; constexpr index_t HI = 34;
constexpr index_t WI = 34; constexpr index_t WI = 34;
constexpr index_t K = 64; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -511,7 +511,7 @@ int main(int argc, char* argv[]) ...@@ -511,7 +511,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 1; constexpr index_t HPad = 1;
constexpr index_t WPad = 1; constexpr index_t WPad = 1;
#elif 1 #elif 0
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -681,7 +681,7 @@ int main(int argc, char* argv[]) ...@@ -681,7 +681,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 0 #elif 1
device_implicit_gemm_convolution_2_chwn_cyxk_khwn device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif #endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
......
...@@ -646,6 +646,9 @@ struct Blockwise4dTensorCopy3 ...@@ -646,6 +646,9 @@ struct Blockwise4dTensorCopy3
constexpr index_t nloop_d2 = L2 / thread_per_d2; constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
#pragma unroll #pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{ {
...@@ -664,13 +667,10 @@ struct Blockwise4dTensorCopy3 ...@@ -664,13 +667,10 @@ struct Blockwise4dTensorCopy3
iloop_d2 * thread_per_d2, iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead); iloop_d3 * thread_per_d3 * DataPerRead);
const index_t dst_offset = const index_t clipboard_offset = clipboard_desc.Get1dIndex(
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_clipboard[dst_offset])) = *(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
*(reinterpret_cast<const vector_t*>( *(reinterpret_cast<const vector_t*>(
&p_src[src_offset + mSrcMyThreadOffset])); &p_src[src_offset + mSrcMyThreadOffset]));
} }
...@@ -713,6 +713,9 @@ struct Blockwise4dTensorCopy3 ...@@ -713,6 +713,9 @@ struct Blockwise4dTensorCopy3
constexpr index_t nloop_d2 = L2 / thread_per_d2; constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
#pragma unroll #pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{ {
...@@ -725,11 +728,8 @@ struct Blockwise4dTensorCopy3 ...@@ -725,11 +728,8 @@ struct Blockwise4dTensorCopy3
#pragma unroll #pragma unroll
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3) for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{ {
const index_t src_offset = const index_t clipboard_offset = clipboard_desc.Get1dIndex(
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
const index_t dst_offset = const index_t dst_offset =
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
...@@ -738,7 +738,7 @@ struct Blockwise4dTensorCopy3 ...@@ -738,7 +738,7 @@ struct Blockwise4dTensorCopy3
iloop_d3 * thread_per_d3 * DataPerRead); iloop_d3 * thread_per_d3 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) = *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[src_offset])); *(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
} }
} }
} }
......
...@@ -263,6 +263,94 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -263,6 +263,94 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
} }
} }
#if DEVICE_BACKEND_HIP
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n");
static_assert(
BlockMatrixStrideA == 0 && BatchPerThread == 1,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}
#endif
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC> template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const FloatC* __restrict__ p_c_block) const
......
...@@ -127,6 +127,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -127,6 +127,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
} }
#if DEVICE_BACKEND_HIP #if DEVICE_BACKEND_HIP
// TODO: this is not working correctly
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm(const FloatA* __restrict__ p_a_block, __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block, const FloatB* __restrict__ p_b_block,
......
...@@ -204,21 +204,36 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -204,21 +204,36 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads()) __syncthreads())
{ {
// input: global mem to LDS #if 1
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
// weight: global mem to LDS
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
#endif
__syncthreads(); __syncthreads();
// a series of batched GEMM #pragma unroll
for(index_t y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
#pragma unroll
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
blockwise_batch_gemm.Run(p_wei_block + #if 1
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), blockwise_batch_gemm.Run
#else
blockwise_batch_gemm.Run_asm
#endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread); p_out_thread);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment