Commit 62ebdfde authored by Jing Zhang's avatar Jing Zhang
Browse files

clean xdlops_gemm

parent cb35d6fc
...@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t KPerBlock = K0; static constexpr index_t KPerBlock = K0;
static constexpr index_t KPack = K1;
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
...@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto wave_idx = GetWaveIdx(); const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0]; const auto waveId_m = wave_idx[I0];
const auto laneId = wave_idx[I2];
const auto blk_idx = xdlops_gemm.GetBlkIdx(); const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
const auto blk_id = blk_idx[I0]; return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
const auto blk_td = blk_idx[I1];
if constexpr(xdlops_gemm.IsKReduction)
{
return make_tuple(blk_id, 0, waveId_m, blk_td, 0);
}
else
{
return make_tuple(0, 0, waveId_m, laneId, 0);
}
} }
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
...@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto wave_idx = GetWaveIdx(); const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1]; const auto waveId_n = wave_idx[I1];
const auto laneId = wave_idx[I2];
const auto blk_idx = xdlops_gemm.GetBlkIdx();
const auto blk_id = blk_idx[I0]; const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
const auto blk_td = blk_idx[I1];
if constexpr(xdlops_gemm.IsKReduction) return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
{
return make_tuple(blk_id, 0, waveId_n, blk_td, 0);
}
else
{
return make_tuple(0, 0, waveId_n, laneId, 0);
}
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i> template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
...@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!"); "wrong!");
...@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<FloatAB, K1> b_thread_vec; vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k0) {
// read A // read A
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(k, I0, I0, I0, I0), make_tuple(k0, I0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
...@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// read B // read B
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
make_tuple(k, I0, I0, I0, I0), make_tuple(k0, I0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
using mfma_input_type = using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type; typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_per_blk>::type;
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> //#include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment