Commit c0ffe379 authored by Jing Zhang's avatar Jing Zhang
Browse files

add 2x2 pipeline

parent 40016f20
...@@ -36,9 +36,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -36,9 +36,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// static constexpr index_t MPerBlock = M0 * M1; // A is transposed
// static constexpr index_t NPerBlock = N0 * N1;
static constexpr index_t MWaves = M1 / MPerWave; static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave; static constexpr index_t NWaves = N1 / NPerWave;
...@@ -101,9 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -101,9 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
} }
} }
__device__ static CIndex CalculateCThreadOriginDataIndex(const index_t m_repeat_id, __device__ static CIndex
const index_t n_repeat_id, CalculateCThreadOriginDataIndex(const index_t m0, const index_t n0, const index_t blk_i)
const index_t blk_i)
{ {
const index_t waveId = get_thread_local_1d_id() / WaveSize; const index_t waveId = get_thread_local_1d_id() / WaveSize;
...@@ -113,8 +109,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -113,8 +109,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId_m = waveId / NWaves; const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves; const index_t waveId_n = waveId % NWaves;
const index_t row = m_repeat_id * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row; const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
const index_t col = n_repeat_id * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col; const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
return CIndex{row, col}; return CIndex{row, col};
} }
...@@ -148,7 +144,54 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -148,7 +144,54 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_for<0, KPerBlock, KPerWave>{}([&](auto k) { // read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<KPerWave, KPerBlock, KPerWave>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0), make_tuple(k, I0, I0),
a_block_buf, a_block_buf,
...@@ -156,6 +199,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -156,6 +199,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0), make_tuple(k, I0, I0),
b_block_buf, b_block_buf,
...@@ -163,12 +214,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -163,12 +214,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
b_thread_buf); b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
0, 1,
0>(a_thread_buf, b_thread_buf, c_thread_buf); 1>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0), make_tuple(k, I1, I0),
b_block_buf, b_block_buf,
...@@ -176,12 +229,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -176,12 +229,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0),
b_thread_buf); b_thread_buf);
xdlops_gemm.template Run2<decltype(a_thread_desc_), // read A_sub_1
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
a_thread_copy_.Run(ABlockDesc{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0), make_tuple(k, I1, I0),
a_block_buf, a_block_buf,
...@@ -189,18 +237,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -189,18 +237,34 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I1, I0), make_tuple(I0, I1, I0),
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
1, 0,
0>(a_thread_buf, b_thread_buf, c_thread_buf); 0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
1, 0,
1>(a_thread_buf, b_thread_buf, c_thread_buf); 1>(a_thread_buf, b_thread_buf, c_thread_buf);
}); });
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
} }
private: private:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment