Commit c0ffe379 authored by Jing Zhang's avatar Jing Zhang
Browse files

add 2x2 pipeline

parent 40016f20
......@@ -36,9 +36,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// static constexpr index_t MPerBlock = M0 * M1; // A is transposed
// static constexpr index_t NPerBlock = N0 * N1;
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
......@@ -101,9 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
}
}
__device__ static CIndex CalculateCThreadOriginDataIndex(const index_t m_repeat_id,
const index_t n_repeat_id,
const index_t blk_i)
__device__ static CIndex
CalculateCThreadOriginDataIndex(const index_t m0, const index_t n0, const index_t blk_i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
......@@ -113,8 +109,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t row = m_repeat_id * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
const index_t col = n_repeat_id * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
return CIndex{row, col};
}
......@@ -148,40 +144,92 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
make_tuple(I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
make_tuple(I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<KPerWave, KPerBlock, KPerWave>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0),
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
make_tuple(I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0),
a_block_buf,
......@@ -189,18 +237,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
});
}
private:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment