"...composable_kernel.git" did not exist on "81b79a77af658629c9a05e52bac29347a6deb750"
Commit 62ebdfde authored by Jing Zhang's avatar Jing Zhang
Browse files

clean xdlops_gemm

parent cb35d6fc
......@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t KPerBlock = K0;
static constexpr index_t KPack = K1;
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
......@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto laneId = wave_idx[I2];
const auto blk_idx = xdlops_gemm.GetBlkIdx();
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(xdlops_gemm.IsKReduction)
{
return make_tuple(blk_id, 0, waveId_m, blk_td, 0);
}
else
{
return make_tuple(0, 0, waveId_m, laneId, 0);
}
return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto laneId = wave_idx[I2];
const auto blk_idx = xdlops_gemm.GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
if constexpr(xdlops_gemm.IsKReduction)
{
return make_tuple(blk_id, 0, waveId_n, blk_td, 0);
}
else
{
return make_tuple(0, 0, waveId_n, laneId, 0);
}
return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
......@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
......@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k0) {
// read A
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(k, I0, I0, I0, I0),
make_tuple(k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
......@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// read B
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
make_tuple(k, I0, I0, I0, I0),
make_tuple(k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf);
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_per_blk>::type;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
......
......@@ -3,7 +3,7 @@
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
//#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment