Commit 62ebdfde authored by Jing Zhang
Browse files

clean xdlops_gemm

parent cb35d6fc
...@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t KPerBlock = K0; static constexpr index_t KPerBlock = K0;
static constexpr index_t KPack = K1;
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
...@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto wave_idx = GetWaveIdx(); const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0]; const auto waveId_m = wave_idx[I0];
const auto laneId = wave_idx[I2];
const auto blk_idx = xdlops_gemm.GetBlkIdx(); const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
const auto blk_id = blk_idx[I0]; return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
const auto blk_td = blk_idx[I1];
if constexpr(xdlops_gemm.IsKReduction)
{
return make_tuple(blk_id, 0, waveId_m, blk_td, 0);
}
else
{
return make_tuple(0, 0, waveId_m, laneId, 0);
}
} }
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
...@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto wave_idx = GetWaveIdx(); const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1]; const auto waveId_n = wave_idx[I1];
const auto laneId = wave_idx[I2];
const auto blk_idx = xdlops_gemm.GetBlkIdx();
const auto blk_id = blk_idx[I0]; const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
const auto blk_td = blk_idx[I1];
if constexpr(xdlops_gemm.IsKReduction) return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
{
return make_tuple(blk_id, 0, waveId_n, blk_td, 0);
}
else
{
return make_tuple(0, 0, waveId_n, laneId, 0);
}
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i> template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
...@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!"); "wrong!");
...@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<FloatAB, K1> b_thread_vec; vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k0) {
// read A // read A
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(k, I0, I0, I0, I0), make_tuple(k0, I0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
...@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// read B // read B
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
make_tuple(k, I0, I0, I0, I0), make_tuple(k0, I0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
using mfma_input_type = using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type; typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_per_blk>::type;
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
...@@ -9,19 +9,16 @@ namespace ck { ...@@ -9,19 +9,16 @@ namespace ck {
enum struct mfma_instr enum struct mfma_instr
{ {
/// fp32
mfma_f32_32x32x1xf32 = 0, mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32, mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32, mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction
/// fp16
mfma_f32_32x32x4f16, mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16, mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16, mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction mfma_f32_16x16x16f16, // k reduction
/// bfp16
mfma_f32_32x32x2bf16, mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16, mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16, mfma_f32_4x4x2bf16,
...@@ -36,18 +33,16 @@ template <> ...@@ -36,18 +33,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32> struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 1; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -65,18 +60,16 @@ template <> ...@@ -65,18 +60,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32> struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -94,18 +87,16 @@ template <> ...@@ -94,18 +87,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32> struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -123,18 +114,16 @@ template <> ...@@ -123,18 +114,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32> struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 1; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -153,18 +142,16 @@ template <> ...@@ -153,18 +142,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32> struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4; static constexpr index_t m_per_blk = 4;
static constexpr index_t m = 4; static constexpr index_t n_per_blk = 64;
static constexpr index_t n = 64; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 1; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -182,18 +169,16 @@ template <> ...@@ -182,18 +169,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -211,18 +196,16 @@ template <> ...@@ -211,18 +196,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16> struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 8; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -240,18 +223,16 @@ template <> ...@@ -240,18 +223,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16> struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 16; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -269,18 +250,16 @@ template <> ...@@ -269,18 +250,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16> struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -298,18 +277,16 @@ template <> ...@@ -298,18 +277,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16> struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4; static constexpr index_t m_per_blk = 4;
static constexpr index_t m = 4; static constexpr index_t n_per_blk = 64;
static constexpr index_t n = 64; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -328,18 +305,16 @@ template <> ...@@ -328,18 +305,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16> struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -362,18 +337,16 @@ template <> ...@@ -362,18 +337,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16> struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -395,18 +368,16 @@ template <> ...@@ -395,18 +368,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16> struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 8; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -428,18 +399,16 @@ template <> ...@@ -428,18 +399,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16> struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -461,18 +430,16 @@ template <> ...@@ -461,18 +430,16 @@ template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4; static constexpr index_t m_per_blk = 4;
static constexpr index_t m = 4; static constexpr index_t n_per_blk = 64;
static constexpr index_t n = 64; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -505,14 +472,9 @@ struct xdlops_info ...@@ -505,14 +472,9 @@ struct xdlops_info
return true; return true;
} }
static constexpr bool IsKReduction()
{
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
}
static constexpr index_t GetKPerXdlops() static constexpr index_t GetKPerXdlops()
{ {
return IsKReduction() ? mfma_type.num_input_blks : 1; return mfma_type.is_k_reduction ? mfma_type.num_input_blks : 1;
} }
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
...@@ -521,6 +483,13 @@ struct xdlops_info ...@@ -521,6 +483,13 @@ struct xdlops_info
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack> template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm struct XdlopsGemm
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
template <class base_type_ = base_type, template <class base_type_ = base_type,
index_t MPerWave_ = MPerWave, index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave> index_t NPerWave_ = NPerWave>
...@@ -684,7 +653,30 @@ struct XdlopsGemm ...@@ -684,7 +653,30 @@ struct XdlopsGemm
__device__ static constexpr index_t GetNumXdlops() __device__ static constexpr index_t GetNumXdlops()
{ {
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); return MPerXdlops * NPerXdlops /
(mfma_type.m_per_blk * mfma_type.n_per_blk * mfma_type.num_output_blks);
}
__host__ __device__ static void mfma_info_check()
{
static_assert(mfma_type.num_threads_per_blk == mfma_type.n_per_blk,
"n_per_blk != num_threads_per_blk");
static_assert(mfma_type.num_regs_per_blk * mfma_type.num_input_blks == mfma_type.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_per_blk * mfma_type.wave_size ==
mfma_type.m_per_blk * mfma_type.n_per_blk,
"num_regs_per_blk incorrect");
static_assert(mfma_type.is_k_reduction ||
(mfma_type.num_input_blks == mfma_type.num_output_blks),
"is_k_reduction wrong!");
} }
__host__ __device__ constexpr XdlopsGemm() __host__ __device__ constexpr XdlopsGemm()
...@@ -696,30 +688,12 @@ struct XdlopsGemm ...@@ -696,30 +688,12 @@ struct XdlopsGemm
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64, MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
} }
template <typename CM0N0M1N1M2N2GridDesc> template <typename CM0N0M1N1M2N2GridDesc>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc) MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc)
{ {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto M0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I0); constexpr auto M0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I1); constexpr auto N0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I2); constexpr auto M1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I2);
...@@ -727,9 +701,9 @@ struct XdlopsGemm ...@@ -727,9 +701,9 @@ struct XdlopsGemm
constexpr auto M2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I4); constexpr auto M2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I4);
constexpr auto N2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I5); constexpr auto N2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I5);
static_assert(N2 == mfma_type.num_threads_blk, ""); static_assert(N2 == mfma_type.num_threads_per_blk, "");
static_assert( static_assert(
M2 == (mfma_type.num_groups_blk * mfma_type.num_output_blks * mfma_type.group_size), M2 == (mfma_type.num_groups_per_blk * mfma_type.num_output_blks * mfma_type.group_size),
""); "");
return transform_dynamic_tensor_descriptor( return transform_dynamic_tensor_descriptor(
...@@ -738,10 +712,10 @@ struct XdlopsGemm ...@@ -738,10 +712,10 @@ struct XdlopsGemm
make_pass_through_transform(N0), make_pass_through_transform(N0),
make_pass_through_transform(M1), make_pass_through_transform(M1),
make_pass_through_transform(N1), make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_type.num_groups_blk, make_unmerge_transform(make_tuple(mfma_type.num_groups_per_blk,
mfma_type.num_input_blks, mfma_type.num_input_blks,
mfma_type.group_size)), mfma_type.group_size)),
make_pass_through_transform(mfma_type.num_threads_blk)), make_pass_through_transform(mfma_type.num_threads_per_blk)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -768,40 +742,79 @@ struct XdlopsGemm ...@@ -768,40 +742,79 @@ struct XdlopsGemm
is_same<base_type, ushort>::value, is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!"); "base base_type must be float, half, ushort!");
static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); static_assert(KPack % mfma_type.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) { static_for<0, KPack / mfma_type.k_per_blk, 1>{}([&](auto k) {
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>( mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[k], p_b_wave[k], p_c_thread); p_a_wave[k], p_b_wave[k], p_c_thread);
}); });
} }
// Lane index of the calling thread: its block-local 1-D thread id taken
// modulo mfma_type.wave_size.
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_type.wave_size; }

// Splits a lane id into the pair (blk_id, blk_td).
//
// laneId is pushed through a single-stage tensor adaptor whose only transform
// merges the lengths (1, mfma_type.num_input_blks, mfma_type.num_threads_per_blk).
// The resulting bottom index has three components; component 1 is returned as
// blk_id and component 2 as blk_td. Component 0 belongs to the length-1
// dimension and is dropped.
__device__ static auto GetBlkIdx(const index_t laneId)
{
    const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(
            make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_per_blk))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const auto blk_idx =
        threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));

    const auto blk_id = blk_idx[I1];
    const auto blk_td = blk_idx[I2];

    return make_tuple(blk_id, blk_td);
}
// Two-component origin index of the calling lane for the A operand.
//
// When mfma_type.is_k_reduction is set, the result is the (blk_id, blk_td)
// pair produced by GetBlkIdx for this lane; otherwise it is (0, laneId).
//
// NOTE(review): this is qualified __host__ __device__, but GetLaneId() and
// GetBlkIdx() are declared __device__ only -- confirm no host-side caller is
// intended.
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
    const auto lane = GetLaneId();

    if constexpr(mfma_type.is_k_reduction)
    {
        // Only the k-reduction layout needs the lane split into block
        // coordinates.
        const auto blkCoord = GetBlkIdx(lane);

        const auto blkId = blkCoord[I0];
        const auto blkTd = blkCoord[I1];

        return make_tuple(blkId, blkTd);
    }
    else
    {
        return make_tuple(0, lane);
    }
}
// Two-component origin index of the calling lane for the B operand.
//
// When mfma_type.is_k_reduction is set, the result is the (blk_id, blk_td)
// pair produced by GetBlkIdx for this lane; otherwise it is (0, laneId).
//
// NOTE(review): this is qualified __host__ __device__, but GetLaneId() and
// GetBlkIdx() are declared __device__ only -- confirm no host-side caller is
// intended.
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
    const auto lane = GetLaneId();

    if constexpr(mfma_type.is_k_reduction)
    {
        // Only the k-reduction layout needs the lane split into block
        // coordinates.
        const auto blkCoord = GetBlkIdx(lane);

        const auto blkId = blkCoord[I0];
        const auto blkTd = blkCoord[I1];

        return make_tuple(blkId, blkTd);
    }
    else
    {
        return make_tuple(0, lane);
    }
}
// Returns the (m, n) offset, as a CIndex, at which the calling lane's output
// block starts.
//
//   xdlops_i : multiplied by mfma_type.m_per_blk to form the base of m_offset
//   blk_i    : multiplied by mfma_type.n_per_blk to form the base of n_offset
//
// The lane's own contribution comes from GetBlkIdx(GetLaneId()): blk_td is
// added to the n offset, and blk_id * mfma_type.group_size to the m offset.
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
    const auto laneId  = GetLaneId();
    const auto blk_idx = GetBlkIdx(laneId);

    const auto blk_id = blk_idx[I0];
    const auto blk_td = blk_idx[I1];

    index_t n_offset = blk_i * mfma_type.n_per_blk + blk_td;
    index_t m_offset = xdlops_i * mfma_type.m_per_blk + blk_id * mfma_type.group_size;

    return CIndex{m_offset, n_offset};
}
...@@ -810,26 +823,25 @@ struct XdlopsGemm ...@@ -810,26 +823,25 @@ struct XdlopsGemm
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
struct CLayout struct CLayout
{ {
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_per_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_per_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_per_blk; }
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops() __device__ static constexpr index_t GetNumXdlops()
{ {
return MPerXdlops * NPerXdlops / return MPerXdlops * NPerXdlops /
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks); (mfma_type.m_per_blk * mfma_type.n_per_blk * mfma_type.num_output_blks);
} }
}; };
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> //#include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment