Commit 66d5e5b3 authored by Jing Zhang
Browse files

in progress

parent f7498d66
...@@ -402,18 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -402,18 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
//const float4* loc = (const float4 *)(p_a_block + src_index); const float4* loc = (const float4 *)(p_a_block + src_index);
float4* reg = (float4 *)(p_a_thread + dst_index); float4* reg = (float4 *)(p_a_thread + dst_index);
//reg[0] = loc[0]; reg[0] = loc[0];
//reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
asm volatile("\n \ //asm volatile("\n \
ds_read2_b64 %0, %2 offset1:1 \n \ //ds_read2_b64 %0, %2 offset1:1 \n \
ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \ //ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
s_waitcnt lgkmcnt(0)" //s_waitcnt lgkmcnt(0)"
: "=v"(reg[0]), "=v"(reg[1]) //: "=v"(reg[0]), "=v"(reg[MPerThreadSubC/4])
: "v"(__to_local((void *)&p_a_block[src_index])) //: "v"(__to_local((void *)&p_a_block[src_index]))
); //);
} }
#endif #endif
...@@ -459,16 +459,52 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -459,16 +459,52 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
for(index_t k = 0; k < 1; ++k) for(index_t k = 0; k < 1; ++k)
{ {
// M = 8 // M = 8
const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < 8; ++i) for(index_t i = 0; i < 8; ++i)
{ {
// N = 8 // N = 8
for(index_t j = 0; j < 8; ++j)
{
const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, j); const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t cindex = c_thread_mtx.Get1dIndex(i, j); //for(index_t j = 0; j < 8; ++j)
{
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; //p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
asm volatile("\n \
v_mac_f32 %0, %8, %9 \n \
v_mac_f32 %1, %8, %10 \n \
v_mac_f32 %2, %8, %11 \n \
v_mac_f32 %3, %8, %12 \n \
v_mac_f32 %4, %8, %13 \n \
v_mac_f32 %5, %8, %14 \n \
v_mac_f32 %6, %8, %15 \n \
v_mac_f32 %7, %8, %16 \n \
"
: "=v"(p_c_thread[cindex + 0]),
"=v"(p_c_thread[cindex + 1]),
"=v"(p_c_thread[cindex + 2]),
"=v"(p_c_thread[cindex + 3]),
"=v"(p_c_thread[cindex + 4]),
"=v"(p_c_thread[cindex + 5]),
"=v"(p_c_thread[cindex + 6]),
"=v"(p_c_thread[cindex + 7])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex + 0]),
"v"(p_b_thread[bindex + 1]),
"v"(p_b_thread[bindex + 2]),
"v"(p_b_thread[bindex + 3]),
"v"(p_b_thread[bindex + 4]),
"v"(p_b_thread[bindex + 5]),
"v"(p_b_thread[bindex + 6]),
"v"(p_b_thread[bindex + 7])
"0"(p_c_thread[cindex + 0]),
"1"(p_c_thread[cindex + 1]),
"2"(p_c_thread[cindex + 2]),
"3"(p_c_thread[cindex + 3]),
"4"(p_c_thread[cindex + 4]),
"5"(p_c_thread[cindex + 5]),
"6"(p_c_thread[cindex + 6]),
"7"(p_c_thread[cindex + 7])
);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment