Commit e624df92 authored by Chao Liu's avatar Chao Liu
Browse files

enabled ds_read_b128 and ds_write_b128 on hip c++

parent 471830a0
......@@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256;
#elif 1
#elif 0
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
......@@ -219,7 +219,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
#elif 1
// 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
......
......@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -583,7 +583,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128;
constexpr index_t C = 2048;
......@@ -667,9 +667,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
......
#!/bin/bash
export KMDUMPISA=1
export KMDUMPLLVM=1
export KMOPTLLC=-mattr=+enable-ds128
make -j driver
/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.isa
/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.asm
......@@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
......@@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n");
using Float4 = vector_type<float, 4>::MemoryType;
float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
FloatA* p_a_thread = p_thread;
FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
int lds_a_block_off = sizeof(Float) * M;
int lds_b_block_off = sizeof(Float) * N;
int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
ds_read_b128(reg_a[0], a_loc, 0);
ds_read_b128(reg_b[0], b_loc, 0);
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
lgkmcnt(2);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
lgkmcnt(0);
#pragma unroll
for(int k_i = 1; k_i < K; k_i++)
for(index_t k = 1; k < K; ++k)
{
ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
lgkmcnt(2);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
lgkmcnt(0);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
......
......@@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
p_wei_block_double);
#endif
}
// register
......@@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
......@@ -271,11 +262,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
for(index_t x = 0; x < X; ++x)
{
#if 1
#if 0
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
#elif 1
blockwise_gemm.Run_asm
#endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
......@@ -284,18 +275,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
p_wei_block_next);
#endif
}
}
......@@ -320,11 +303,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
for(index_t x = 0; x < X; ++x)
{
#if 1
#if 0
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
#elif 1
blockwise_gemm.Run_asm
#endif
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
......@@ -333,19 +316,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
}
}
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);
#endif
// odd
__syncthreads();
......@@ -354,11 +328,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
for(index_t x = 0; x < X; ++x)
{
#if 1
#if 0
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
#elif 1
blockwise_gemm.Run_asm
#endif
(p_wei_block_double + wei_block_space +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment