Commit d990eff6 authored by Chao Liu

clean

parent 437c996a
@@ -529,8 +529,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
 
-        static_assert(MRepeat == 2 && NRepeat == 2,
-                      "wrong! inline asm cannot deal with this GEMM config yet");
+        static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");
 
         // thread A-sub, B-sub
         constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
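MRepeat and NRepeat above are the number of per-thread register sub-tiles along M and N, and the rewritten assert pins them to 2, which is exactly the 2x2 read/compute interleaving spelled out in the hunks below. A minimal worked example, using hypothetical per-thread tile sizes that are not taken from this diff:

    #include <cstdio>
    using index_t = int; // stand-in for the library's index type

    int main()
    {
        // Hypothetical per-thread tile sizes, chosen only to illustrate the 2x2 requirement.
        constexpr index_t MPerThread = 8, MPerThreadSubC = 4;
        constexpr index_t NPerThread = 8, NPerThreadSubC = 4;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC; // 8 / 4 = 2
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC; // 8 / 4 = 2
        static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");

        // Each thread therefore owns four C sub-tiles: C_00, C_01, C_10, C_11,
        // which the blockwise GEMM body updates explicitly.
        std::printf("MRepeat = %d, NRepeat = %d\n", MRepeat, NRepeat);
    }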
@@ -557,83 +556,83 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read A_sub_0
         a_thread_copy_.Run(BlockMatrixA{},
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            a_block_buf,
                            a_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            a_thread_buf);
 
         // read B_sub_0
         b_thread_copy_.Run(BlockMatrixB{},
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            b_block_buf,
                            b_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<0>{}),
+                           make_tuple(I0, I0),
                            b_thread_buf);
 
         // read B_sub_1
         b_thread_copy_.Run(BlockMatrixB{},
-                           make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
+                           make_tuple(I0, Number<NPerLevel1Cluster>{}),
                            b_block_buf,
                            b_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                           make_tuple(I0, Number<NPerThreadSubC>{}),
                            b_thread_buf);
 
         // read A_sub_1
         a_thread_copy_.Run(BlockMatrixA{},
-                           make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
+                           make_tuple(I0, Number<MPerLevel1Cluster>{}),
                            a_block_buf,
                            a_thread_mtx_desc_,
-                           make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                           make_tuple(I0, Number<MPerThreadSubC>{}),
                            a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             c_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}));
+                            make_tuple(I0, I0));
 
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            make_tuple(I0, Number<NPerThreadSubC>{}),
                             c_thread_buf,
-                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
+                            make_tuple(I0, Number<NPerThreadSubC>{}));
 
         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
             a_thread_copy_.Run(BlockMatrixA{},
-                               make_tuple(k, Number<0>{}),
+                               make_tuple(k, I0),
                                a_block_buf,
                                a_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<0>{}),
+                               make_tuple(I0, I0),
                                a_thread_buf);
 
             // C_sub_10 += transpose(A_sub_1) * B_sub_0
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                make_tuple(I0, Number<MPerThreadSubC>{}),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 c_thread_buf,
-                                make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
+                                make_tuple(Number<MPerThreadSubC>{}, I0));
 
             // read B_sub_0
             b_thread_copy_.Run(BlockMatrixB{},
-                               make_tuple(k, Number<0>{}),
+                               make_tuple(k, I0),
                                b_block_buf,
                                b_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<0>{}),
+                               make_tuple(I0, I0),
                                b_thread_buf);
 
             // C_sub_11 += transpose(A_sub_1) * B_sub_1
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                make_tuple(I0, Number<MPerThreadSubC>{}),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                                make_tuple(I0, Number<NPerThreadSubC>{}),
                                 c_thread_buf,
                                 make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
@@ -642,7 +641,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(k, Number<NPerLevel1Cluster>{}),
                                b_block_buf,
                                b_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                               make_tuple(I0, Number<NPerThreadSubC>{}),
                                b_thread_buf);
 
             // read A_sub_1
@@ -650,39 +649,39 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(k, Number<MPerLevel1Cluster>{}),
                                a_block_buf,
                                a_thread_mtx_desc_,
-                               make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                               make_tuple(I0, Number<MPerThreadSubC>{}),
                                a_thread_buf);
 
             // C_sub_00 += transpose(A_sub_0) * B_sub_0
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 c_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}));
+                                make_tuple(I0, I0));
 
             // C_sub_01 += transpose(A_sub_0) * B_sub_1
             threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(Number<0>{}, Number<0>{}),
+                                make_tuple(I0, I0),
                                 b_thread_buf,
-                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                                make_tuple(I0, Number<NPerThreadSubC>{}),
                                 c_thread_buf,
-                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
+                                make_tuple(I0, Number<NPerThreadSubC>{}));
         });
 
         // C_sub_10 += transpose(A_sub_1) * B_sub_0
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            make_tuple(I0, Number<MPerThreadSubC>{}),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<0>{}),
+                            make_tuple(I0, I0),
                             c_thread_buf,
-                            make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
+                            make_tuple(Number<MPerThreadSubC>{}, I0));
 
         // C_sub_11 += transpose(A_sub_1) * B_sub_1
         threadwise_gemm.Run(a_thread_buf,
-                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            make_tuple(I0, Number<MPerThreadSubC>{}),
                             b_thread_buf,
-                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            make_tuple(I0, Number<NPerThreadSubC>{}),
                             c_thread_buf,
                             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }
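Two notes for reading this diff. First, the I0 that replaces Number<0>{} throughout is presumably a pre-instantiated compile-time index constant, along the lines of static constexpr auto I0 = Number<0>{}; that alias is not shown in this commit, so treat it as an assumption. Second, the "2x2 pipeline" named in the assert refers to how the body interleaves sub-tile reads with sub-GEMMs: two A sub-tiles, two B sub-tiles, and four C accumulators per thread. Below is a stripped-down, self-contained sketch of that interleaving using plain arrays and stand-in helpers (read_a, read_b, gemm); it mirrors only the structure of the code above, not the library's descriptors, buffers, or copy classes.

    #include <array>
    #include <cstdio>

    using index_t = int;

    // Per-thread problem shape (hypothetical numbers; the real sizes come from the tuning config).
    constexpr index_t K    = 8; // reduction length handled by this thread, with KPerThreadLoop = 1
    constexpr index_t MSub = 2; // stand-in for MPerThreadSubC
    constexpr index_t NSub = 2; // stand-in for NPerThreadSubC

    using ASub = std::array<float, MSub>;
    using BSub = std::array<float, NSub>;

    // Stand-in for a_thread_copy_.Run / b_thread_copy_.Run:
    // copy one k-slice of a sub-tile from the block tile into registers.
    ASub read_a(const float (&a)[K][2 * MSub], index_t k, index_t m_off)
    {
        ASub r{};
        for(index_t m = 0; m < MSub; ++m)
            r[m] = a[k][m_off + m];
        return r;
    }

    BSub read_b(const float (&b)[K][2 * NSub], index_t k, index_t n_off)
    {
        BSub r{};
        for(index_t n = 0; n < NSub; ++n)
            r[n] = b[k][n_off + n];
        return r;
    }

    // Stand-in for threadwise_gemm.Run: C_sub += transpose(A_sub) * B_sub for one k-slice.
    void gemm(const ASub& a, const BSub& b, float (&c)[2 * MSub][2 * NSub], index_t mo, index_t no)
    {
        for(index_t m = 0; m < MSub; ++m)
            for(index_t n = 0; n < NSub; ++n)
                c[mo + m][no + n] += a[m] * b[n];
    }

    int main()
    {
        float a[K][2 * MSub], b[K][2 * NSub], c[2 * MSub][2 * NSub] = {};
        for(index_t k = 0; k < K; ++k)
            for(index_t m = 0; m < 2 * MSub; ++m)
                a[k][m] = 0.1f * (k + m);
        for(index_t k = 0; k < K; ++k)
            for(index_t n = 0; n < 2 * NSub; ++n)
                b[k][n] = 0.2f * (k - n);

        // Prologue, k = 0: read A_sub_0, B_sub_0, B_sub_1, A_sub_1, then compute C_00 and C_01.
        ASub a0 = read_a(a, 0, 0), a1 = read_a(a, 0, MSub);
        BSub b0 = read_b(b, 0, 0), b1 = read_b(b, 0, NSub);
        gemm(a0, b0, c, 0, 0);
        gemm(a0, b1, c, 0, NSub);

        // Steady state: reads for the new k are interleaved with the remaining sub-GEMMs of the old k.
        for(index_t k = 1; k < K; ++k)
        {
            a0 = read_a(a, k, 0);        // read A_sub_0 (new k)
            gemm(a1, b0, c, MSub, 0);    // C_10 += A_sub_1^T * B_sub_0 (old k, still in registers)
            b0 = read_b(b, k, 0);        // read B_sub_0 (new k)
            gemm(a1, b1, c, MSub, NSub); // C_11 += A_sub_1^T * B_sub_1 (old k)
            b1 = read_b(b, k, NSub);     // read B_sub_1 (new k)
            a1 = read_a(a, k, MSub);     // read A_sub_1 (new k)
            gemm(a0, b0, c, 0, 0);       // C_00 (new k)
            gemm(a0, b1, c, 0, NSub);    // C_01 (new k)
        }

        // Epilogue: finish C_10 and C_11 for the last k.
        gemm(a1, b0, c, MSub, 0);
        gemm(a1, b1, c, MSub, NSub);

        std::printf("c[0][0] = %f\n", c[0][0]);
    }

The point of the ordering is that each read targets a register sub-tile that is not used by the sub-GEMMs issued immediately after it, so the loads for iteration k can overlap with the multiply-accumulates that still belong to iteration k-1.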