Commit 05d7a087 authored by Chao Liu
Browse files

enable 128x128 block gemm

parent 6a3f3f95
...@@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
#elif 1 #elif 0
// 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
constexpr index_t BPerBlock = 64; constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -221,8 +221,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -221,8 +221,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1 #elif 1
// 1x1, 14x14, Vega 20, hack CPerBlock = 1 // 1x1, 14x14, Vega 20, try
constexpr index_t BPerBlock = 64; constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8; constexpr index_t CPerBlock = 8;
......
...@@ -377,13 +377,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -377,13 +377,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
//auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; // auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
//auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; // auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
Float4* reg_a = (Float4*)(p_a_thread); Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread); Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread); Float4* reg_c = (Float4*)(p_c_thread);
void* a_loc = (void *)(p_a_block + mMyThreadOffsetA); void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
void* b_loc = (void *)(p_b_block + mMyThreadOffsetB); void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
// loop over k // loop over k
int k_chunk = 2; int k_chunk = 2;
#pragma unroll #pragma unroll
...@@ -403,9 +403,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -403,9 +403,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else #else
int k = k_begin; int k = k_begin;
int lds_a_block_off = sizeof(Float) * M; int lds_a_block_off = sizeof(Float) * M;
int lds_b_block_off = sizeof(Float) * N; int lds_b_block_off = sizeof(Float) * N;
int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off); ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
......
...@@ -272,7 +272,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn ...@@ -272,7 +272,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t max_align = constexpr index_t max_align =
mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr index_t in_block_element_space = constexpr index_t in_block_element_space =
in_cb_block_desc.GetElementSpace(Number<max_align>{}); in_cb_block_desc.GetElementSpace(Number<max_align>{});
...@@ -297,7 +297,8 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn ...@@ -297,7 +297,8 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads()) p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
{ {
// load data // load data
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
......
...@@ -4,56 +4,68 @@ typedef float Float4 __attribute__((ext_vector_type(4))); ...@@ -4,56 +4,68 @@ typedef float Float4 __attribute__((ext_vector_type(4)));
extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
inline __device__ void lgkmcnt(int cnt){ inline __device__ void lgkmcnt(int cnt)
{
#if 1 #if 1
if(cnt == 0) { if(cnt == 0)
{
asm volatile("\n \ asm volatile("\n \
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
"::); " ::);
} }
else if(cnt == 1) { else if(cnt == 1)
{
asm volatile("\n \ asm volatile("\n \
s_waitcnt lgkmcnt(1) \n \ s_waitcnt lgkmcnt(1) \n \
"::); " ::);
} }
else if(cnt == 2) { else if(cnt == 2)
{
asm volatile("\n \ asm volatile("\n \
s_waitcnt lgkmcnt(2) \n \ s_waitcnt lgkmcnt(2) \n \
"::); " ::);
} }
else if(cnt == 3) { else if(cnt == 3)
{
asm volatile("\n \ asm volatile("\n \
s_waitcnt lgkmcnt(3) \n \ s_waitcnt lgkmcnt(3) \n \
"::); " ::);
} }
else if(cnt == 4) { else if(cnt == 4)
{
asm volatile("\n \ asm volatile("\n \
s_waitcnt lgkmcnt(4) \n \ s_waitcnt lgkmcnt(4) \n \
"::); " ::);
} }
else { else
{
assert(0); assert(0);
} }
#endif #endif
} }
inline __device__ void outerProduct1x4(const float *a, const float *b, float *c) { inline __device__ void outerProduct1x4(const float* a, const float* b, float* c)
{
asm volatile("\n \ asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \ v_mac_f32 %0, %4, %5 \n \
v_mac_f32 %1, %4, %6 \n \ v_mac_f32 %1, %4, %6 \n \
v_mac_f32 %2, %4, %7 \n \ v_mac_f32 %2, %4, %7 \n \
v_mac_f32 %3, %4, %8 \n \ v_mac_f32 %3, %4, %8 \n \
" "
: : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
"=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) : "v"(a[0]),
: "v"(b[0]),
"v"(a[0]), "v"(b[1]),
"v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]), "v"(b[2]),
"0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]) "v"(b[3]),
); "0"(c[0]),
"1"(c[1]),
"2"(c[2]),
"3"(c[3]));
} }
inline __device__ void outerProduct1x4(const float &a, const Float4 &b, Float4 &c) { inline __device__ void outerProduct1x4(const float& a, const Float4& b, Float4& c)
{
#if 0 #if 0
asm volatile( asm volatile(
"\n \ "\n \
...@@ -67,12 +79,13 @@ inline __device__ void outerProduct1x4(const float &a, const Float4 &b, Float4 & ...@@ -67,12 +79,13 @@ inline __device__ void outerProduct1x4(const float &a, const Float4 &b, Float4 &
"v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w) "v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
); );
#else #else
outerProduct1x4(&a, (float *)&b, (float *)&c); outerProduct1x4(&a, (float*)&b, (float*)&c);
#endif #endif
} }
inline __device__ void
inline __device__ void outerProduct4x4(const Float4 &a, const Float4 &b, Float4 &c0, Float4 &c1, Float4 &c2, Float4 &c3) { outerProduct4x4(const Float4& a, const Float4& b, Float4& c0, Float4& c1, Float4& c2, Float4& c3)
{
#if 0 #if 0
asm volatile( asm volatile(
"\n \ "\n \
...@@ -126,7 +139,7 @@ inline __device__ void outerProduct4x4(const Float4 &a, const Float4 &b, Float4 ...@@ -126,7 +139,7 @@ inline __device__ void outerProduct4x4(const Float4 &a, const Float4 &b, Float4
#endif #endif
} }
inline __device__ void outerProduct8x8(const Float4 *a, const Float4 *b, Float4 *c) inline __device__ void outerProduct8x8(const Float4* a, const Float4* b, Float4* c)
{ {
outerProduct4x4(a[0], b[0], c[0], c[2], c[4], c[6]); outerProduct4x4(a[0], b[0], c[0], c[2], c[4], c[6]);
outerProduct4x4(a[0], b[1], c[1], c[3], c[5], c[7]); outerProduct4x4(a[0], b[1], c[1], c[3], c[5], c[7]);
...@@ -134,250 +147,223 @@ inline __device__ void outerProduct8x8(const Float4 *a, const Float4 *b, Float4 ...@@ -134,250 +147,223 @@ inline __device__ void outerProduct8x8(const Float4 *a, const Float4 *b, Float4
outerProduct4x4(a[1], b[1], c[9], c[11], c[13], c[15]); outerProduct4x4(a[1], b[1], c[9], c[11], c[13], c[15]);
} }
inline __device__ void ds_read_b128(Float4 &r, void *lds, int offset = 0) inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0)
{ {
if(offset == 0) if(offset == 0)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:0 \n \ ds_read_b128 %0, %1 offset:0 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 128) else if(offset == 128)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:128 \n \ ds_read_b128 %0, %1 offset:128 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 256) else if(offset == 256)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:256 \n \ ds_read_b128 %0, %1 offset:256 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 384) else if(offset == 384)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:384 \n \ ds_read_b128 %0, %1 offset:384 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 512) else if(offset == 512)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:512 \n \ ds_read_b128 %0, %1 offset:512 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 640) else if(offset == 640)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:640 \n \ ds_read_b128 %0, %1 offset:640 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 768) else if(offset == 768)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:768 \n \ ds_read_b128 %0, %1 offset:768 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 896) else if(offset == 896)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:896 \n \ ds_read_b128 %0, %1 offset:896 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1024) else if(offset == 1024)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1024 \n \ ds_read_b128 %0, %1 offset:1024 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1152) else if(offset == 1152)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1152 \n \ ds_read_b128 %0, %1 offset:1152 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1280) else if(offset == 1280)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1280 \n \ ds_read_b128 %0, %1 offset:1280 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1408) else if(offset == 1408)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1408 \n \ ds_read_b128 %0, %1 offset:1408 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1536) else if(offset == 1536)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1536 \n \ ds_read_b128 %0, %1 offset:1536 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1664) else if(offset == 1664)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1664 \n \ ds_read_b128 %0, %1 offset:1664 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1792) else if(offset == 1792)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1792 \n \ ds_read_b128 %0, %1 offset:1792 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 1920) else if(offset == 1920)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:1920 \n \ ds_read_b128 %0, %1 offset:1920 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 2048) else if(offset == 2048)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:2048 \n \ ds_read_b128 %0, %1 offset:2048 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 2176) else if(offset == 2176)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:2176 \n \ ds_read_b128 %0, %1 offset:2176 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 2304) else if(offset == 2304)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:2304 \n \ ds_read_b128 %0, %1 offset:2304 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 2560) else if(offset == 2560)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:2560 \n \ ds_read_b128 %0, %1 offset:2560 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 2816) else if(offset == 2816)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:2816 \n \ ds_read_b128 %0, %1 offset:2816 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 3072) else if(offset == 3072)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:3072 \n \ ds_read_b128 %0, %1 offset:3072 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 3328) else if(offset == 3328)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:3328 \n \ ds_read_b128 %0, %1 offset:3328 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 3584) else if(offset == 3584)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:3584 \n \ ds_read_b128 %0, %1 offset:3584 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 3840) else if(offset == 3840)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:3840 \n \ ds_read_b128 %0, %1 offset:3840 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 4096) else if(offset == 4096)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:4096 \n \ ds_read_b128 %0, %1 offset:4096 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else if(offset == 4352) else if(offset == 4352)
{ {
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1 offset:4352 \n \ ds_read_b128 %0, %1 offset:4352 \n \
" "
: "=v"(r) : "=v"(r)
: "v"(__to_local(lds)) : "v"(__to_local(lds)));
);
} }
else else
{ {
......
...@@ -31,10 +31,10 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -31,10 +31,10 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
const index_t src_index = src_mtx.Get1dIndex(i, 0); const index_t src_index = src_mtx.Get1dIndex(i, 0);
const index_t dst_index = dst_mtx.Get1dIndex(i, 0); const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
Float4 *reg_p = (Float4 *)&p_dst[dst_index]; Float4* reg_p = (Float4*)&p_dst[dst_index];
Float4 *loc_p = (Float4 *)&p_src[src_index]; Float4* loc_p = (Float4*)&p_src[src_index];
ds_read_b128(reg_p[0], (void *)&loc_p[0]); ds_read_b128(reg_p[0], (void*)&loc_p[0]);
} }
#endif #endif
} }
...@@ -86,9 +86,9 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -86,9 +86,9 @@ __device__ void threadwise_gemm(MatrixA,
} }
} }
#else #else
const Float4 *a_vec = (const Float4 *)p_a_thread; const Float4* a_vec = (const Float4*)p_a_thread;
const Float4 *b_vec = (const Float4 *)p_b_thread; const Float4* b_vec = (const Float4*)p_b_thread;
Float4 *c_vec = (Float4 *)p_c_thread; Float4* c_vec = (Float4*)p_c_thread;
outerProduct8x8(a_vec, b_vec, c_vec); outerProduct8x8(a_vec, b_vec, c_vec);
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment