"git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "2337da18dde074890d69b5670d7030b89c2a71b5"
Commit 66edb259 authored by Chao Liu

Merge branch 'inline_asm_v2' of github.com:asroy/modular_convolution into inline_asm_v2

parents 19b41797 62c4d5df
@@ -361,10 +361,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         // thread A-sub, B-sub for copy
         constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
 
         constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
 
         float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
@@ -377,66 +377,42 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
 
-        // auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
-        // auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
-
         Float4* reg_a = (Float4*)(p_a_thread);
         Float4* reg_b = (Float4*)(p_b_thread);
         Float4* reg_c = (Float4*)(p_c_thread);
 
         void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
         void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
 
-        // loop over k
-        int k_chunk = K;
-        // for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
-        index_t k_begin = 0;
-        {
-#if 0
-            ds_read_b128(reg_a[0], a_loc, 0);
-            ds_read_b128(reg_a[1], a_loc, 256);
-            ds_read_b128(reg_b[0], b_loc, 0);
-            ds_read_b128(reg_b[1], b_loc, 128);
-
-            lgkmcnt(0);
-
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-#else
-            int k = k_begin;
-            int lds_a_block_off = sizeof(Float) * M;
-            int lds_b_block_off = sizeof(Float) * N;
-            int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
-            int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
-
-            ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
-            ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off);
-            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off);
-            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off);
-
-            lgkmcnt(2);
-            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            lgkmcnt(1);
-            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-            lgkmcnt(0);
-
-#pragma unroll
-            for(int i = 0; i < k_chunk - 1; i++)
-            {
-                k = k + 1;
-                ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
-                outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-                ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off);
-                outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-                ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off);
-                ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off);
-
-                lgkmcnt(2);
-                outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-                lgkmcnt(1);
-                outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-                lgkmcnt(0);
-            }
-
-            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-#endif
-        }
+        int lds_a_block_off = sizeof(Float) * M;
+        int lds_b_block_off = sizeof(Float) * N;
+        int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
+        int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
+
+        ds_read_b128(reg_a[0], a_loc, 0);
+        ds_read_b128(reg_b[0], b_loc, 0);
+        ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
+        ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
+
+        lgkmcnt(2);
+        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+        lgkmcnt(1);
+        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+        lgkmcnt(0);
+
+#pragma unroll
+        for(int k_i = 1; k_i < K; k_i++)
+        {
+            ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
+            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+            ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
+            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
+            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
+
+            lgkmcnt(2);
+            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+            lgkmcnt(1);
+            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+            lgkmcnt(0);
+        }
+
+        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
     }
 
     template <class FloatA, class FloatB, class FloatC, class Accumulator>
...
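For reference, the added loop software-pipelines LDS traffic against the FMA work: each iteration issues the ds_read_b128 loads for k-slice k_i while the outer products still consume the registers filled for slice k_i - 1, and the lgkmcnt(2)/lgkmcnt(1)/lgkmcnt(0) waits only block until the operands needed next have landed. The sketch below is a host-side C++ model of that issue order, not the repository's device code: load4() and outer4x4() are hypothetical stand-ins for ds_read_b128 and outerProduct4x4, the lgkmcnt waits disappear because the copies are synchronous, and the tile sizes are toy values. Its only purpose is to check that the schedule accumulates every k-slice exactly once into the 2x2 grid of 4x4 C sub-tiles.

// Hypothetical host-side model of the pipelined loop above -- not the
// repository's device code. load4() stands in for ds_read_b128 and
// outer4x4() for outerProduct4x4; only the issue order is modelled.
#include <array>
#include <cmath>
#include <cstdio>

using Vec4 = std::array<float, 4>;

// Stand-in for ds_read_b128: copy four consecutive floats from "LDS".
static void load4(Vec4& reg, const float* lds, int offset)
{
    for(int i = 0; i < 4; ++i)
        reg[i] = lds[offset + i];
}

// Stand-in for outerProduct4x4: row i of the 4x4 tile gets c_i[j] += a[i] * b[j].
static void outer4x4(const Vec4& a, const Vec4& b, Vec4& c0, Vec4& c1, Vec4& c2, Vec4& c3)
{
    Vec4* rows[4] = {&c0, &c1, &c2, &c3};
    for(int i = 0; i < 4; ++i)
        for(int j = 0; j < 4; ++j)
            (*rows[i])[j] += a[i] * b[j];
}

int main()
{
    constexpr int K = 8, M = 8, N = 8; // toy sizes: 8x8 C tile, two 4-wide sub-tiles per side

    float a_lds[K * M], b_lds[K * N];  // A and B blocks stored k-slice by k-slice, as in LDS
    for(int i = 0; i < K * M; ++i) a_lds[i] = 0.01f * i;
    for(int i = 0; i < K * N; ++i) b_lds[i] = 0.02f * i;

    Vec4 a[2], b[2], c[16] = {};       // same register naming as reg_a / reg_b / reg_c

    // Prologue: issue the four loads for slice 0, then start the a[0] products
    // (the points where the device code waits with lgkmcnt(2) / lgkmcnt(1)).
    load4(a[0], a_lds, 0);
    load4(b[0], b_lds, 0);
    load4(b[1], b_lds, 4); // second 4-wide sub-tile of this k-slice
    load4(a[1], a_lds, 4);
    outer4x4(a[0], b[0], c[0], c[2], c[4], c[6]);
    outer4x4(a[0], b[1], c[1], c[3], c[5], c[7]);

    // Steady state: prefetch slice k_i while finishing the a[1] products of slice k_i - 1.
    for(int k_i = 1; k_i < K; ++k_i)
    {
        load4(a[0], a_lds, k_i * M);
        outer4x4(a[1], b[0], c[8], c[10], c[12], c[14]); // still slice k_i - 1
        load4(b[0], b_lds, k_i * N);
        outer4x4(a[1], b[1], c[9], c[11], c[13], c[15]); // still slice k_i - 1
        load4(b[1], b_lds, 4 + k_i * N);
        load4(a[1], a_lds, 4 + k_i * M);
        outer4x4(a[0], b[0], c[0], c[2], c[4], c[6]);    // slice k_i
        outer4x4(a[0], b[1], c[1], c[3], c[5], c[7]);    // slice k_i
    }

    // Epilogue: drain the a[1] products of the last slice.
    outer4x4(a[1], b[0], c[8], c[10], c[12], c[14]);
    outer4x4(a[1], b[1], c[9], c[11], c[13], c[15]);

    // Check against a straightforward K-M-N reference: every k-slice must be
    // accumulated exactly once despite the staggered issue order.
    int mismatches = 0;
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float ref = 0.f;
            for(int k = 0; k < K; ++k)
                ref += a_lds[k * M + m] * b_lds[k * N + n];

            int idx = (m < 4 ? 0 : 8) + (n < 4 ? 0 : 1) + 2 * (m % 4);
            if(std::fabs(c[idx][n % 4] - ref) > 1e-3f)
                ++mismatches;
        }

    std::printf("mismatches: %d\n", mismatches);
    return 0;
}

Compared with the disabled #if 0 path, which issued all four reads and then waited with a single lgkmcnt(0) before any math, the staged lgkmcnt(2)/lgkmcnt(1) waits in the retained path appear intended to let the first outer product start as soon as its two operands are resident rather than after all four loads complete.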