Commit 3399ddaf authored by Jing Zhang's avatar Jing Zhang
Browse files

break vector type to blk_size

parent 59462dca
......@@ -111,7 +111,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
constexpr auto xdlops_gemm = XdlopsGemm<float, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto CLayout = xdlops_gemm.GetOutputLayout();
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
......
......@@ -42,17 +42,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const
{
return xdlops_gemm.GetOutputLayout().GetNumBlks();
}
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const
{
return xdlops_gemm.GetOutputLayout().GetBlkSize();
}
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
{
......@@ -98,13 +92,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
}
}
template <index_t m0, index_t n0, index_t blk_i>
__device__ static CIndex CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<blk_i>)
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
......@@ -240,17 +235,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const
{
return xdlops_gemm.GetOutputLayout().GetNumBlks();
}
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const
{
return xdlops_gemm.GetOutputLayout().GetBlkSize();
}
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
{
......
......@@ -310,10 +310,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
MPerWave,
NPerWave,
KPerWave>{};
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
// constexpr auto c_mr_nr_nb_bk_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number<MRepeat>{},
......@@ -338,7 +339,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
StaticBuffer<AddressSpace::Vgpr, vector_type<float, NumBlks * BlkSize>, MRepeat * NRepeat>
StaticBuffer<AddressSpace::Vgpr,
vector_type<float, NumBlks * BlkSize>,
MRepeat * NRepeat * NumXdlops>
c_thread_buf;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
......@@ -471,9 +474,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory
{
constexpr index_t M0 = OutputLayout.M1();
constexpr index_t M1 = OutputLayout.N1();
constexpr index_t M2 = OutputLayout.M0();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
......@@ -483,17 +486,20 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) =
c_thread_buf[Number<mr_i * NRepeat + nr_i>{}]
c_thread_buf[Number<(mr_i * NRepeat + nr_i) * NumXdlops +
xdlops_i>{}]
.template AsType<float>()[Number<blk_i * BlkSize + j>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(mr_i, nr_i, blk_i);
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t k_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
......@@ -529,6 +535,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
});
});
});
});
}
}
......
......@@ -50,10 +50,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -74,10 +79,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -528,7 +538,7 @@ struct xdlops_info
static constexpr index_t MRepeats = MRepeats_;
static constexpr index_t NRepeats = NRepeats_;
static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }
// static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }
static constexpr bool IsKReduction()
{
......@@ -743,9 +753,11 @@ struct XdlopsGemm
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlksPerXdlops()
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
__host__ __device__ constexpr XdlopsGemm()
......@@ -791,42 +803,27 @@ struct XdlopsGemm
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0));
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0));
mfma_type.template run<MPerXdlops, NPerXdlops>(p_a_wave[Number<a_offset>{}],
p_b_wave[Number<b_offset>{}],
p_c_thread(Number<c_offset>{}));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread);
});
}
__device__ static CIndex GetBeginOfThreadBlk(index_t i)
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const index_t xdlops_i = i / GetNumBlksPerXdlops();
const index_t j = i % GetNumBlksPerXdlops();
const index_t m_i = xdlops_i / NRepeats;
const index_t n_i = xdlops_i % NRepeats;
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t col_blk = j % mfma_type.num_output_blks;
index_t row_blk = j / mfma_type.num_output_blks;
index_t n_offset = blk_i * mfma_type.n + blk_td;
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
static_if<!IsABroadcast>{}([&](auto) {
col_blk = j / mfma_type.num_output_blks;
row_blk = j % mfma_type.num_output_blks;
});
index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
return CIndex{row, col};
return CIndex{m_offset, n_offset};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
......@@ -835,7 +832,7 @@ struct XdlopsGemm
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
// static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr auto GetBlkId(const index_t lane_id)
......@@ -850,7 +847,7 @@ struct XdlopsGemm
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
struct OutputLayout
struct CLayout
{
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
......@@ -859,13 +856,16 @@ struct XdlopsGemm
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
__device__ static constexpr index_t GetNumBlks()
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return GetNumBlksPerXdlops() * MRepeats * NRepeats;
return MPerXdlops * NPerXdlops /
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
};
__host__ __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
};
} // namespace ck
......
......@@ -198,7 +198,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
template <index_t MPerWave, index_t NPerWave>
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x1f32;
// template <index_t AStride, index_t BStride>
......@@ -237,16 +237,28 @@ struct intrin_mfma_f32_32x32x1f32;
//}
//};
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
}
};
......@@ -272,17 +284,23 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
//}
//};
template <index_t MPerWave, index_t NPerWave>
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x2f32;
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
template <index_t COffset>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
}
};
......
......@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#else
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment