"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "b6571d2295bf4b4d5c6ae159b1d8b4dc36e300a5"
Commit 0d6aa311 authored by Jing Zhang

inline asm

parent 753b98b5
@@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t WeiBlockCopyDataPerRead = 4;
     constexpr index_t BlockSize = 256;
-#elif 0
+#elif 1
     // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
     constexpr index_t BPerBlock = 64;
     constexpr index_t KPerBlock = 128;
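The only functional change in this hunk is flipping the guard from #elif 0 to #elif 1, which selects the "1x1, 14x14, Vega 20" tuning branch (LDS double buffering off, register double buffering on). A rough sketch of the selection pattern this line sits in is below; the neighbouring branch and the trailing parameters are placeholders, not the file's actual contents.

// Sketch only: everything except the lines visible in the hunk above is a placeholder.
#if 0
    // previously selected problem-size / GPU configuration
#elif 1
    // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
    constexpr index_t BPerBlock = 64;   // block-level GEMM tile extent along B
    constexpr index_t KPerBlock = 128;  // block-level GEMM tile extent along K
    // ... remaining block-copy / thread-tile parameters for this configuration
#endif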
@@ -332,12 +332,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             n_repeat * NPerLevel1Cluster + n_in_sub_c};
     }
 
-    template <class FloatA, class FloatB, class FloatC, class Accumulator, index_t block_off>
+    template <class FloatA, class FloatB, class FloatC, class Accumulator>
     __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                             const FloatB* __restrict__ p_b_block,
                             FloatC* __restrict__ p_c_thread,
-                            Accumulator f_accum,
-                            Number<block_off>) const
+                            Accumulator f_accum) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
@@ -378,45 +377,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
 
+        //auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
+        //auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
+        Float4* reg_a = (Float4*)(p_a_thread);
+        Float4* reg_b = (Float4*)(p_b_thread);
+        Float4* reg_c = (Float4*)(p_c_thread);
+        void* a_loc = (void *)(p_a_block + mMyThreadOffsetA);
+        void* b_loc = (void *)(p_b_block + mMyThreadOffsetB);
+
 #pragma unroll
         // loop over k
         for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
-            auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
-            auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
-            Float4* reg_a = (Float4*)(p_a_thread);
-            Float4* reg_b = (Float4*)(p_b_thread);
-            void* a_loc = (void *)(p_a_block + a_src_index);
-            void* b_loc = (void *)(p_b_block + b_src_index);
+#if 0
+            //asm volatile("\n \
+            //ds_read_b128 %0, %2 \n \
+            //ds_read_b128 %1, %2 offset:256\n \
+            //"
+            //: "=v"(reg_a[0]), "=v"(reg_a[1])
+            //: "v"(__to_local(a_loc))
+            //);
             ds_read_b128(reg_a[0], a_loc, 0);
             ds_read_b128(reg_a[1], a_loc, 256);
             ds_read_b128(reg_b[0], b_loc, 0);
             ds_read_b128(reg_b[1], b_loc, 128);
             lgkmcnt(0);
-            threadwise_gemm(a_thread_mtx,
-                            True,
-                            p_a_thread,
-                            b_thread_mtx,
-                            False,
-                            p_b_thread,
-                            c_thread_mtx,
-                            False,
-                            p_c_thread,
-                            f_accum);
+            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+#else
+            ds_read_b128(reg_a[0], a_loc, k_begin * 512);
+            ds_read_b128(reg_b[0], b_loc, k_begin * 256);
+            ds_read_b128(reg_b[1], b_loc, 128 + k_begin * 256);
+            ds_read_b128(reg_a[1], a_loc, 256 + k_begin * 512);
+            lgkmcnt(2);
+            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+            lgkmcnt(1);
+            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+            lgkmcnt(0);
+            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+#endif
        }
    }
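In the enabled (#else) path the four 128-bit LDS reads for a k step are issued back to back, and the lgkmcnt() calls then count down (2, 1, 0) so that each outerProduct4x4 only waits for the operands it actually needs; assuming LDS reads complete in issue order, as is usual on GCN, the later reads overlap with the first multiplies. The diff does not show how ds_read_b128 and lgkmcnt are defined, but the commented-out asm volatile block suggests they wrap raw GCN instructions. Below is a minimal sketch under that assumption; the Float4 typedef, the macro form, and the operand constraints are guesses rather than the repository's actual helpers (__to_local is used the same way as in the commented-out block).

// Hypothetical wrappers, not the repository's definitions.
// ds_read_b128(dst, lds_ptr, byte_off): read 128 bits from LDS at lds_ptr + byte_off into dst.
// lgkmcnt(n): block until at most n LDS/GDS ("lgkm") operations are still outstanding.
// Macros keep the offset/count immediates as compile-time literals once the
// surrounding #pragma unroll loop has been fully unrolled.
typedef float Float4 __attribute__((ext_vector_type(4))); // assumed 4 x f32 register group

#define ds_read_b128(dst, lds_ptr, byte_off)                     \
    asm volatile("ds_read_b128 %0, %1 offset:%2"                 \
                 : "=v"(dst)                                     \
                 : "v"(__to_local(lds_ptr)), "i"(byte_off))

#define lgkmcnt(n) asm volatile("s_waitcnt lgkmcnt(" #n ")" ::: "memory")

The per-step strides (512 bytes for A, 256 bytes for B) are consistent with rows of 128 and 64 floats at 4 bytes each, i.e. the KPerBlock = 128 / BPerBlock = 64 tile widths selected in the first hunk, assuming KPerThreadLoop advances one row of the LDS tiles at a time.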
@@ -323,7 +323,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
                         (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                          p_in_block + y * Wi + x,
                          p_out_thread,
-                         f_accum, Number<in_block_element_space>());
+                         f_accum);
                 }
             }
         }
@@ -12,7 +12,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};
 
-#if 0
+#if 1
     for(index_t i = 0; i < NRow; ++i)
     {
         for(index_t j = 0; j < NCol; ++j)
@@ -72,6 +72,7 @@ __device__ void threadwise_gemm(MatrixA,
        for(index_t k = 0; k < K; ++k)
        {
+#if 1
            for(index_t i = 0; i < M; i+=4)
            {
                const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
@@ -88,6 +89,13 @@ __device__ void threadwise_gemm(MatrixA,
                    outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]);
                }
            }
+#else
+            const Float4 *a_vec = (const Float4 *)p_a_thread;
+            const Float4 *b_vec = (const Float4 *)p_b_thread;
+            Float4 *c_vec = (Float4 *)p_c_thread;
+            outerProduct8x8(a_vec, b_vec, c_vec);
+#endif
        }
    }
    else
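Both Run_asm and the threadwise_gemm path accumulate through outerProduct4x4, which takes one Float4 of A, one Float4 of B and four Float4 rows of C; its definition is not part of this diff. Below is a plain-C++ sketch of the operation it presumably performs, a rank-1 update of a 4x4 tile of C (C[i][j] += a[i] * b[j]); the Float4 typedef and the scalar-times-vector form are assumptions, and the real helper may well be implemented with v_mac/v_fma inline asm instead.

// Assumed semantics of outerProduct4x4, not the repository's implementation:
// each call is a 4x4 rank-1 update, c_row_i[j] += a[i] * b[j], with every row
// of the C tile held in one Float4 register.
typedef float Float4 __attribute__((ext_vector_type(4)));

__device__ inline void outerProduct4x4(
    const Float4& a, const Float4& b, Float4& c0, Float4& c1, Float4& c2, Float4& c3)
{
    c0 += a.x * b; // row 0 of the 4x4 tile
    c1 += a.y * b; // row 1
    c2 += a.z * b; // row 2
    c3 += a.w * b; // row 3
}

Under this reading, the four calls in Run_asm (touching reg_c[0] through reg_c[15], i.e. 16 Float4 = an 8x8 float tile) update the full per-thread C tile, which also appears to be what the single outerProduct8x8(a_vec, b_vec, c_vec) call in the currently disabled #else branch would cover in one helper.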