Commit 8c84c0b1 authored by Jing Zhang's avatar Jing Zhang
Browse files

add KReduction

parent 02bf2be0
......@@ -55,7 +55,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
return make_tuple(0, waveId_m * MPerWave + laneId);
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
return make_tuple(k_offset, m_offset);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, m_offset);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -66,7 +77,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
return make_tuple(0, waveId_n * NPerWave + laneId);
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
return make_tuple(k_offset, n_offset);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, n_offset);
}
}
template <index_t AStride = MPerWave, index_t BStride = NPerWave>
......
......@@ -535,6 +535,11 @@ struct xdlops_info
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
}
static constexpr index_t GetKPerXdlops()
{
return mfma_type.k_base * (IsKReduction() ? mfma_type.num_input_blks : 1);
}
static constexpr auto OutputVecType = OutputVecType_{};
};
......@@ -571,7 +576,7 @@ struct XdlopsGemm
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
static_assert(mfma_type.k % mfma_type.k_base == 0, "k and k_base is inconsistent!");
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
}
__device__ static constexpr index_t GetRegSizePerXdlops()
......@@ -586,7 +591,9 @@ struct XdlopsGemm
is_same<data_type, ushort>::value,
"base data_type must be float, half, ushort!");
static_for<0, KPerWave, mfma_type.k_base>{}([&](auto k_i) {
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
static_for<0, KPerWave, KPerXdlops>{}([&](auto k_i) {
mfma_type.template run<MPerXdlops, NPerXdlops>(
p_a_wave[Number<k_i>{}], p_b_wave[Number<k_i>{}], p_c_thread);
});
......@@ -833,8 +840,19 @@ struct XdlopsGemm
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr auto GetBlkId(const index_t lane_id)
{
return lane_id / mfma_type.num_threads_blk;
}
static constexpr auto GetBlkTd(const index_t lane_id)
{
return lane_id % mfma_type.num_threads_blk;
}
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment