Commit d990eff6 authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 437c996a
......@@ -529,8 +529,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
......@@ -557,83 +556,83 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read A_sub_0
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
make_tuple(I0, Number<NPerLevel1Cluster>{}),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
make_tuple(I0, Number<NPerThreadSubC>{}),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
make_tuple(I0, Number<MPerLevel1Cluster>{}),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
make_tuple(I0, Number<MPerThreadSubC>{}),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
c_thread_buf,
make_tuple(Number<0>{}, Number<0>{}));
make_tuple(I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
make_tuple(I0, Number<NPerThreadSubC>{}));
// loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(k, Number<0>{}),
make_tuple(k, I0),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
make_tuple(Number<MPerThreadSubC>{}, I0));
// read B_sub_0
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(k, Number<0>{}),
make_tuple(k, I0),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
......@@ -642,7 +641,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
make_tuple(k, Number<NPerLevel1Cluster>{}),
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
make_tuple(I0, Number<NPerThreadSubC>{}),
b_thread_buf);
// read A_sub_1
......@@ -650,39 +649,39 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
make_tuple(k, Number<MPerLevel1Cluster>{}),
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
make_tuple(I0, Number<MPerThreadSubC>{}),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
c_thread_buf,
make_tuple(Number<0>{}, Number<0>{}));
make_tuple(I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
make_tuple(I0, Number<NPerThreadSubC>{}));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
make_tuple(I0, I0),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
make_tuple(Number<MPerThreadSubC>{}, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
make_tuple(I0, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
make_tuple(I0, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment