Commit 57a8ccf3 authored by Jing Zhang
Browse files

in progress

parent 66d5e5b3
......@@ -368,8 +368,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
FloatA *p_a_thread = p_thread;
FloatB *p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
......@@ -381,6 +383,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#if 0
// copy A-sub to form A
#if 0
#pragma unroll
......@@ -406,13 +409,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
float4* reg = (float4 *)(p_a_thread + dst_index);
reg[0] = loc[0];
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
reg[1] = loc[16];
//reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
//asm volatile("\n \
//ds_read2_b64 %0, %2 offset1:1 \n \
//ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
//ds_read2_b64 %1, %2 offset0:32 offset1:33 \n \
//s_waitcnt lgkmcnt(0)"
//: "=v"(reg[0]), "=v"(reg[MPerThreadSubC/4])
//: "v"(__to_local((void *)&p_a_block[src_index]))
//: "=v"(reg[0]), "=v"(reg[1])
//: "v"(__to_local((void *)(loc)))
//);
}
#endif
......@@ -439,8 +443,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
float4* reg = (float4 *)(p_b_thread + dst_index);
reg[0] = loc[0];
reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4];
reg[1] = loc[8];
//reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4];
//asm volatile("\n \
//ds_read2_b64 %0, %2 offset1:1 \n \
//ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
//s_waitcnt lgkmcnt(0)"
//: "=v"(reg[0]), "=v"(reg[1])
//: "v"(__to_local((void *)(loc)))
//);
}
#endif
#else
auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
const float4* a_loc = (const float4 *)(p_a_block + a_src_index);
const float4* b_loc = (const float4 *)(p_b_block + b_src_index);
float4* reg = (float4 *)(p_a_thread + dst_index);
//reg[0] = a_loc[0];
//reg[1] = a_loc[16];
//reg[2] = b_loc[0];
//reg[3] = b_loc[8];
//s_waitcnt lgkmcnt(0) // 000000001398: BF8CC07F
asm volatile("\n \
ds_read2_b64 %0, %4 offset1:1 \n \
ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \
ds_read2_b64 %2, %5 offset1:1 \n \
ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \
s_waitcnt lgkmcnt(0)"
: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3])
: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc)))
);
#endif
// C = A * B
......@@ -495,7 +534,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
"v"(p_b_thread[bindex + 4]),
"v"(p_b_thread[bindex + 5]),
"v"(p_b_thread[bindex + 6]),
"v"(p_b_thread[bindex + 7])
"v"(p_b_thread[bindex + 7]),
"0"(p_c_thread[cindex + 0]),
"1"(p_c_thread[cindex + 1]),
"2"(p_c_thread[cindex + 2]),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment