Commit fbc7817b authored by Chao Liu
Browse files

add asm into lds_double_buffer version

parent 155d7859
...@@ -271,7 +271,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -271,7 +271,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 1 #if 0
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
#else #else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
...@@ -306,7 +306,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -306,7 +306,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
float time = launch_kernel(gridwise_conv.Run, float time = launch_kernel(gridwise_conv.Run,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
gridwise_conv.GetSharedMemoryUsage(), gridwise_conv.GetDynamicSharedMemoryUsage(),
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()), static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()), static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer())); static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
......
...@@ -34,9 +34,8 @@ template <index_t GridSize, ...@@ -34,9 +34,8 @@ template <index_t GridSize,
index_t WeiBlockCopyThreadPerDim1, index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead> index_t WeiBlockCopyDataPerRead>
class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
{ {
public:
__host__ __device__ constexpr index_t GetInputBlockElementSpace() const __host__ __device__ constexpr index_t GetInputBlockElementSpace() const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -97,7 +96,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn ...@@ -97,7 +96,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
return wei_cyxk_block_desc.GetElementSpace(Number<max_align>{}); return wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
} }
__host__ __device__ constexpr index_t GetSharedMemoryUsage() const __host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const
{ {
return (GetInputBlockElementSpace() + GetWeightBlockElementSpace()) * sizeof(Float); return (GetInputBlockElementSpace() + GetWeightBlockElementSpace()) * sizeof(Float);
...@@ -300,22 +299,38 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn ...@@ -300,22 +299,38 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads()) __syncthreads())
{ {
// load data // load data
//blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); #if 0
//blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#elif 0
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
#elif 1
Float4 tmp_in, tmp_wei; Float4 tmp_in, tmp_wei;
Float4* glb_in_p = (Float4 *)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset); Float4* glb_in_p =
Float4* loc_in_p = (Float4 *)(p_in_block + blockwise_in_copy.mDstMyThreadOffset); (Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
Float4* loc_in_p = (Float4*)(p_in_block + blockwise_in_copy.mDstMyThreadOffset);
Float4* glb_wei_p = (Float4 *)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset); Float4* glb_wei_p =
Float4* loc_wei_p = (Float4 *)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset); (Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
Float4* loc_wei_p = (Float4*)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset);
global_load(tmp_in, glb_in_p); global_load(tmp_in, glb_in_p);
global_load(tmp_wei, glb_wei_p); global_load(tmp_wei, glb_wei_p);
vmcnt(0); vmcnt(0);
ds_write_b128(tmp_in, loc_in_p); ds_write_b128(tmp_in, loc_in_p);
ds_write_b128(tmp_wei, loc_wei_p); ds_write_b128(tmp_wei, loc_wei_p);
#endif
__syncthreads(); __syncthreads();
......
...@@ -34,9 +34,10 @@ template <index_t GridSize, ...@@ -34,9 +34,10 @@ template <index_t GridSize,
index_t WeiBlockCopyThreadPerDim1, index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead> index_t WeiBlockCopyDataPerRead>
class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
{ {
public: __host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const { return 0; }
__global__ static void Run(const Float* const __restrict__ p_in_global, __global__ static void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) Float* const __restrict__ p_out_global)
...@@ -239,9 +240,27 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer ...@@ -239,9 +240,27 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
// preload data into LDS // preload data into LDS
#if 0
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);
#else
Float4 tmp_in, tmp_wei;
Float4* glb_in_p =
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
Float4* glb_wei_p =
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
global_load(tmp_in, glb_in_p);
global_load(tmp_wei, glb_wei_p);
Float4* loc_in_p = (Float4*)(p_in_block_0 + blockwise_in_copy.mDstMyThreadOffset);
Float4* loc_wei_p = (Float4*)(p_wei_block_0 + blockwise_wei_copy.mDstMyThreadOffset);
vmcnt(0);
ds_write_b128(tmp_in, loc_in_p);
ds_write_b128(tmp_wei, loc_wei_p);
#endif
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
...@@ -270,9 +289,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer ...@@ -270,9 +289,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
// load next data // load next data
#if 0 #if 0
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif 1
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
...@@ -281,6 +297,15 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer ...@@ -281,6 +297,15 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_clipboard);
#elif 1
Float4 tmp_in, tmp_wei;
Float4* glb_in_p =
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
Float4* glb_wei_p =
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
global_load(tmp_in, glb_in_p);
global_load(tmp_wei, glb_wei_p);
#endif #endif
// compute on current data // compute on current data
...@@ -290,22 +315,31 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer ...@@ -290,22 +315,31 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#else #elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#elif 1
blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x, p_in_block_now + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
} }
#if 1 #if 0
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next); p_wei_block_next);
#elif 1
Float4* loc_in_p = (Float4*)(p_in_block_next + blockwise_in_copy.mDstMyThreadOffset);
Float4* loc_wei_p = (Float4*)(p_wei_block_next + blockwise_wei_copy.mDstMyThreadOffset);
vmcnt(0);
ds_write_b128(tmp_in, loc_in_p);
ds_write_b128(tmp_wei, loc_wei_p);
#endif #endif
} }
...@@ -321,15 +355,17 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer ...@@ -321,15 +355,17 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#else #elif 1
blockwise_gemm.Run_asm
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x, p_in_block_now + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
} }
} }
......
...@@ -4,28 +4,34 @@ typedef float Float4 __attribute__((ext_vector_type(4))); ...@@ -4,28 +4,34 @@ typedef float Float4 __attribute__((ext_vector_type(4)));
extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
// Emit the GCN `s_waitcnt vmcnt(N)` instruction for the requested N.
//
// The counter value is written literally into each asm string, so every
// supported value has its own statement. Supported values are 0, 1, 2 and 4;
// any other value falls through to assert(0).
//
// Note: the cnt == 4 case used to emit `s_waitcnt vmcnt(2)`, which waits for
// more outstanding operations to retire than the caller asked for. It now
// emits `s_waitcnt vmcnt(4)` to match the argument.
inline __device__ void vmcnt(int cnt)
{
    switch(cnt)
    {
    case 0:
        asm volatile("\n \
                s_waitcnt vmcnt(0) \n \
                " ::);
        break;
    case 1:
        asm volatile("\n \
                s_waitcnt vmcnt(1) \n \
                " ::);
        break;
    case 2:
        asm volatile("\n \
                s_waitcnt vmcnt(2) \n \
                " ::);
        break;
    case 4:
        asm volatile("\n \
                s_waitcnt vmcnt(4) \n \
                " ::);
        break;
    default:
        // Unsupported counter value.
        assert(0);
        break;
    }
}
...@@ -397,13 +403,13 @@ inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0) ...@@ -397,13 +403,13 @@ inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0)
} }
} }
// Issue a `global_load_dwordx4` from the address in `ptr` into `r`.
//
// r:   destination vector register operand (written by the instruction).
// ptr: source address, passed in a vector register; the `off` operand means
//      no additional scalar base is applied.
//
// NOTE(review): the call sites shown with this helper all call vmcnt(0)
// before using `r`; presumably the loaded value must not be consumed until
// that wait has executed -- confirm against the ISA documentation.
inline __device__ void global_load(Float4& r, Float4* ptr)
{
    // The asm dereferences `ptr`, but only the pointer value itself is listed
    // as an operand. The "memory" clobber tells the compiler that memory is
    // accessed, so it does not move stores to that location across this
    // statement.
    asm volatile("\n \
            global_load_dwordx4 %0, %1, off \n \
            "
                 : "=v"(r)
                 : "v"(ptr)
                 : "memory");
}
inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0) inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0)
...@@ -411,8 +417,6 @@ inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0) ...@@ -411,8 +417,6 @@ inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0)
asm volatile("\n \ asm volatile("\n \
ds_write_b128 %0, %1 \n \ ds_write_b128 %0, %1 \n \
" "
: :
: "v"(__to_local(lds)), "v"(r) : "v"(__to_local(lds)), "v"(r));
);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment