Commit e624df92 authored by Chao Liu's avatar Chao Liu
Browse files

enabled ds_read_b128 and ds_write_b128 on hip c++

parent 471830a0
...@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
#elif 1 #elif 0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64; constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -219,7 +219,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -219,7 +219,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer // 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr index_t BPerBlock = 128; constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
......
...@@ -409,7 +409,7 @@ int main(int argc, char* argv[]) ...@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -583,7 +583,7 @@ int main(int argc, char* argv[]) ...@@ -583,7 +583,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 1x1 filter, 14x14 image, C = 2048 // 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 2048; constexpr index_t C = 2048;
...@@ -667,9 +667,9 @@ int main(int argc, char* argv[]) ...@@ -667,9 +667,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0 #elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 0 #elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1
device_implicit_gemm_convolution_2_chwn_cyxk_khwn device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif #endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
......
#!/bin/bash #!/bin/bash
export KMDUMPISA=1 export KMDUMPISA=1
export KMDUMPLLVM=1 export KMDUMPLLVM=1
export KMOPTLLC=-mattr=+enable-ds128
make -j driver make -j driver
/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.isa /opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.asm
...@@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
const FloatB* __restrict__ p_b_block, const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const FloatC* __restrict__ p_c_thread) const
{ {
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{}; constexpr auto False = integral_constant<bool, false>{};
...@@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{}); Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 && static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8, MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n"); "Run_asm cannot deal with this GEMM shape yet\n");
using Float4 = vector_type<float, 4>::MemoryType; using Float4 = vector_type<float, 4>::MemoryType;
float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
FloatA* p_a_thread = p_thread;
FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Float4* reg_a = (Float4*)(p_a_thread); Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread); Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread); Float4* reg_c = (Float4*)(p_c_thread);
void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
void* b_loc = (void*)(p_b_block + mMyThreadOffsetB); reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
int lds_a_block_off = sizeof(Float) * M; reg_b[1] =
int lds_b_block_off = sizeof(Float) * N; *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); reg_a[1] =
int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
ds_read_b128(reg_a[0], a_loc, 0);
ds_read_b128(reg_b[0], b_loc, 0);
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
lgkmcnt(0);
#pragma unroll #pragma unroll
for(int k_i = 1; k_i < K; k_i++) for(index_t k = 1; k < K; ++k)
{ {
ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off); reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off); reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off); reg_b[1] = *reinterpret_cast<const Float4*>(
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off); &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
lgkmcnt(2); reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
lgkmcnt(0);
} }
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
......
...@@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_clipboard);
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double); blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double); p_wei_block_double);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
p_wei_block_double);
#endif
} }
// register // register
...@@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard); p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_clipboard);
...@@ -271,11 +262,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -271,11 +262,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{ {
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
#if 1 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#elif 0 #elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
...@@ -284,18 +275,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -284,18 +275,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
} }
} }
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next); p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next); p_wei_block_next);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
p_wei_block_next);
#endif
} }
} }
...@@ -320,11 +303,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -320,11 +303,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{ {
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
#if 1 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#elif 0 #elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
...@@ -333,19 +316,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -333,19 +316,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
} }
} }
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
#endif
// odd // odd
__syncthreads(); __syncthreads();
...@@ -354,11 +328,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -354,11 +328,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{ {
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
#if 1 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#elif 0 #elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_double + wei_block_space + (p_wei_block_double + wei_block_space +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment