Commit 3399ddaf authored by Jing Zhang

break vector type to blk_size

parent 59462dca
@@ -111,7 +111,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     constexpr auto xdlops_gemm = XdlopsGemm<float, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
 
-    constexpr auto CLayout = xdlops_gemm.GetOutputLayout();
+    constexpr auto CLayout = xdlops_gemm.GetCLayout();
 
     constexpr index_t M0 = CLayout.M1();
     constexpr index_t M1 = CLayout.N1();
@@ -42,17 +42,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
     static constexpr index_t MRepeat = M0;
     static constexpr index_t NRepeat = N0;
 
-    __device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
+    __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
 
-    __device__ constexpr auto GetNumBlks() const
-    {
-        return xdlops_gemm.GetOutputLayout().GetNumBlks();
-    }
+    __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
 
-    __device__ constexpr auto GetBlkSize() const
-    {
-        return xdlops_gemm.GetOutputLayout().GetBlkSize();
-    }
+    __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
 
     __device__ static auto CalculateAThreadOriginDataIndex()
     {
@@ -98,13 +92,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         }
     }
 
-    template <index_t m0, index_t n0, index_t blk_i>
-    __device__ static CIndex CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<blk_i>)
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static CIndex
+    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
     {
         const index_t waveId = get_thread_local_1d_id() / WaveSize;
 
-        const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
+        const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
 
         const index_t waveId_m = waveId / NWaves;
         const index_t waveId_n = waveId % NWaves;
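Call sites (see the gridwise copy-out loop later in this diff) now pass the xdlops instance and the output block within it as two separate compile-time indices; a hypothetical call looks like:

```cpp
// Hypothetical values: repeat indices (1, 2), second xdlops instance,
// first output block within that instance.
const auto c_origin = blockwise_gemm.CalculateCThreadOriginDataIndex(
    Number<1>{}, Number<2>{}, Number<1>{}, Number<0>{});
```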
@@ -240,17 +235,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
     static constexpr index_t MRepeat = M0;
     static constexpr index_t NRepeat = N0;
 
-    __device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
+    __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
 
-    __device__ constexpr auto GetNumBlks() const
-    {
-        return xdlops_gemm.GetOutputLayout().GetNumBlks();
-    }
+    __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
 
-    __device__ constexpr auto GetBlkSize() const
-    {
-        return xdlops_gemm.GetOutputLayout().GetBlkSize();
-    }
+    __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
 
     __device__ static auto CalculateAThreadOriginDataIndex()
     {
@@ -310,10 +310,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             MPerWave,
             NPerWave,
             KPerWave>{};
 
-    constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
-    constexpr index_t BlkSize = OutputLayout.GetBlkSize();
-    constexpr index_t NumBlks = OutputLayout.GetNumBlks();
+    constexpr auto CLayout = blockwise_gemm.GetCLayout();
+    constexpr index_t BlkSize = CLayout.GetBlkSize();
+    constexpr index_t NumBlks = CLayout.GetNumBlks();
+    constexpr index_t NumXdlops = CLayout.GetNumXdlops();
 
     // constexpr auto c_mr_nr_nb_bk_thread_desc =
     //    make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number<MRepeat>{},
@@ -338,7 +339,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     // Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
     //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
 
-    StaticBuffer<AddressSpace::Vgpr, vector_type<float, NumBlks * BlkSize>, MRepeat * NRepeat>
+    StaticBuffer<AddressSpace::Vgpr,
+                 vector_type<float, NumBlks * BlkSize>,
+                 MRepeat * NRepeat * NumXdlops>
         c_thread_buf;
 
     constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
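Each StaticBuffer element is now one `vector_type<float, NumBlks * BlkSize>` covering a single xdlops invocation, and the buffer holds `MRepeat * NRepeat * NumXdlops` of them; an element is addressed as `(mr_i * NRepeat + nr_i) * NumXdlops + xdlops_i` and a scalar within it as `blk_i * BlkSize + j`. A standalone sketch of that flattening, assuming the 32x32-per-wave mfma_f32_32x32x2f32 case this commit tunes for (NumXdlops = 1, NumBlks = 1, BlkSize = 16; all names hypothetical):

```cpp
#include <cstdio>

// Hypothetical standalone model of the new c_thread_buf indexing.
constexpr int MRepeat = 4, NRepeat = 4;   // repeats per wave (from this commit)
constexpr int NumXdlops = 1, NumBlks = 1; // mfma_f32_32x32x2f32, 32x32 per wave
constexpr int BlkSize = 16;               // num_regs_blk

constexpr int element_index(int mr_i, int nr_i, int xdlops_i)
{
    return (mr_i * NRepeat + nr_i) * NumXdlops + xdlops_i;
}

constexpr int scalar_index(int blk_i, int j) { return blk_i * BlkSize + j; }

int main()
{
    // repeat (2, 3), xdlops 0, block 0, register 5 -> element 11, scalar 5
    std::printf("element %d, scalar %d\n", element_index(2, 3, 0), scalar_index(0, 5));
    // total accumulator floats per thread: 4 * 4 * 1 * 1 * 16 = 256
    std::printf("total %d\n", MRepeat * NRepeat * NumXdlops * NumBlks * BlkSize);
    return 0;
}
```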
@@ -471,9 +474,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     // output: register to global memory
     {
-        constexpr index_t M0 = OutputLayout.M1();
-        constexpr index_t M1 = OutputLayout.N1();
-        constexpr index_t M2 = OutputLayout.M0();
+        constexpr index_t M0 = CLayout.M1();
+        constexpr index_t M1 = CLayout.N1();
+        constexpr index_t M2 = CLayout.M0();
 
         constexpr auto c_m0_m1_m2_n_thread_desc =
             make_dynamic_naive_tensor_descriptor_packed_v2(
@@ -483,49 +486,53 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         static_for<0, MRepeat, 1>{}([&](auto mr_i) {
             static_for<0, NRepeat, 1>{}([&](auto nr_i) {
-                static_for<0, NumBlks, 1>{}([&](auto blk_i) {
-                    static_for<0, BlkSize, 1>{}([&](auto j) {
-                        c_blk_buf_(j) =
-                            c_thread_buf[Number<mr_i * NRepeat + nr_i>{}]
-                                .template AsType<float>()[Number<blk_i * BlkSize + j>{}];
-                    });
+                static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
+                    static_for<0, NumBlks, 1>{}([&](auto blk_i) {
+                        static_for<0, BlkSize, 1>{}([&](auto j) {
+                            c_blk_buf_(j) =
+                                c_thread_buf[Number<(mr_i * NRepeat + nr_i) * NumXdlops +
+                                                    xdlops_i>{}]
+                                    .template AsType<float>()[Number<blk_i * BlkSize + j>{}];
+                        });
 
-                    // calculate origin of thread output tensor on global memory
-                    //     blockwise GEMM c matrix starting index
-                    const auto c_thread_mtx_on_block =
-                        blockwise_gemm.CalculateCThreadOriginDataIndex(mr_i, nr_i, blk_i);
+                        // calculate origin of thread output tensor on global memory
+                        //     blockwise GEMM c matrix starting index
+                        const auto c_thread_mtx_on_block =
+                            blockwise_gemm.CalculateCThreadOriginDataIndex(
+                                mr_i, nr_i, xdlops_i, blk_i);
 
-                    const index_t k_thread_data_on_global =
-                        m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
+                        const index_t k_thread_data_on_global =
+                            m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
 
-                    const index_t b_thread_data_on_global =
-                        n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
+                        const index_t b_thread_data_on_global =
+                            n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
 
-                    constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
-                        CGlobalIteratorHacks{};
+                        constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
+                            CGlobalIteratorHacks{};
 
-                    ThreadwiseDynamicTensorSliceTransfer_v1r3<
-                        FloatAcc,
-                        FloatC,
-                        decltype(c_m0_m1_m2_n_thread_desc),
-                        decltype(c_m0_m1_m2_n_global_desc),
-                        Sequence<M0, 1, M2, 1>,
-                        Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
-                        3,                    // CThreadTransferSrcDstVectorDim,
-                        1,                    // CThreadTransferDstScalarPerVector,
-                        CGlobalMemoryDataOperation,
-                        1,
-                        true>{c_m0_m1_m2_n_global_desc,
-                              make_multi_index(k_thread_data_on_global / (M2 * M1),
-                                               k_thread_data_on_global % (M2 * M1) / M2,
-                                               k_thread_data_on_global % M2,
-                                               b_thread_data_on_global)}
-                        .Run(c_m0_m1_m2_n_thread_desc,
-                             make_tuple(I0, I0, I0, I0),
-                             c_blk_buf_,
-                             c_m0_m1_m2_n_global_desc,
-                             c_global_buf,
-                             c_m0_m1_m2_n_global_tensor_iterator_hacks);
+                        ThreadwiseDynamicTensorSliceTransfer_v1r3<
+                            FloatAcc,
+                            FloatC,
+                            decltype(c_m0_m1_m2_n_thread_desc),
+                            decltype(c_m0_m1_m2_n_global_desc),
+                            Sequence<M0, 1, M2, 1>,
+                            Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
+                            3,                    // CThreadTransferSrcDstVectorDim,
+                            1,                    // CThreadTransferDstScalarPerVector,
+                            CGlobalMemoryDataOperation,
+                            1,
+                            true>{c_m0_m1_m2_n_global_desc,
+                                  make_multi_index(k_thread_data_on_global / (M2 * M1),
+                                                   k_thread_data_on_global % (M2 * M1) / M2,
+                                                   k_thread_data_on_global % M2,
+                                                   b_thread_data_on_global)}
+                            .Run(c_m0_m1_m2_n_thread_desc,
+                                 make_tuple(I0, I0, I0, I0),
+                                 c_blk_buf_,
+                                 c_m0_m1_m2_n_global_desc,
+                                 c_global_buf,
+                                 c_m0_m1_m2_n_global_tensor_iterator_hacks);
+                    });
                 });
             });
         });
@@ -50,10 +50,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
     static constexpr index_t cycles = 64;
     static constexpr index_t k_base = 1;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t COffset,
+              class FloatA,
+              class FloatB,
+              class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
     }
 };
@@ -74,10 +79,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
     static constexpr index_t cycles = 64;
     static constexpr index_t k_base = 1;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t COffset,
+              class FloatA,
+              class FloatB,
+              class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+        return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
     }
 };
@@ -528,7 +538,7 @@ struct xdlops_info
     static constexpr index_t MRepeats = MRepeats_;
     static constexpr index_t NRepeats = NRepeats_;
 
-    static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }
+    // static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }
 
     static constexpr bool IsKReduction()
     {
@@ -743,9 +753,11 @@ struct XdlopsGemm
     using CIndex = MultiIndex<2>;
 
-    __device__ static constexpr index_t GetNumBlksPerXdlops()
+    __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
+
+    __device__ static constexpr index_t GetNumXdlops()
     {
-        return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
+        return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
     }
 
     __host__ __device__ constexpr XdlopsGemm()
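Worked numbers for the split, assuming the mfma_f32_32x32x1xf32 parameters used at 64x64 per wave (m = n = 32, num_output_blks = 2): one instruction writes 2 output blocks, and covering the 64x64 tile takes 64*64 / (32*32*2) = 2 chained xdlops invocations. A compile-time sanity check under those assumed values:

```cpp
// Assumed mfma_f32_32x32x1xf32 at 64x64 per wave; illustrative constants,
// not taken verbatim from the library.
constexpr int m = 32, n = 32, num_output_blks = 2;
constexpr int MPerXdlops = 64, NPerXdlops = 64;

constexpr int NumBlks   = num_output_blks;                                     // GetNumBlks()
constexpr int NumXdlops = MPerXdlops * NPerXdlops / (m * n * num_output_blks); // GetNumXdlops()

// every element of the 64x64 tile is covered exactly once
static_assert(NumXdlops * NumBlks * m * n == MPerXdlops * NPerXdlops, "tile coverage");
```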
@@ -791,42 +803,27 @@ struct XdlopsGemm
         static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
 
-        constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0));
+        constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
 
         static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
             constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0));
             constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0));
 
-            mfma_type.template run<MPerXdlops, NPerXdlops>(p_a_wave[Number<a_offset>{}],
-                                                           p_b_wave[Number<b_offset>{}],
-                                                           p_c_thread(Number<c_offset>{}));
+            mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
+                p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread);
         });
     }
 
-    __device__ static CIndex GetBeginOfThreadBlk(index_t i)
+    __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
     {
-        const index_t xdlops_i = i / GetNumBlksPerXdlops();
-        const index_t j        = i % GetNumBlksPerXdlops();
-
-        const index_t m_i = xdlops_i / NRepeats;
-        const index_t n_i = xdlops_i % NRepeats;
-
         const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
         const index_t blk_id = laneId / mfma_type.num_threads_blk;
         const index_t blk_td = laneId % mfma_type.num_threads_blk;
 
-        index_t col_blk = j % mfma_type.num_output_blks;
-        index_t row_blk = j / mfma_type.num_output_blks;
-
-        static_if<!IsABroadcast>{}([&](auto) {
-            col_blk = j / mfma_type.num_output_blks;
-            row_blk = j % mfma_type.num_output_blks;
-        });
-
-        index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
-        index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
+        index_t n_offset = blk_i * mfma_type.n + blk_td;
+        index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
 
-        return CIndex{row, col};
+        return CIndex{m_offset, n_offset};
     }
 
     static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
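With the flat block index and the IsABroadcast branch gone, the per-thread origin reduces to two multiply-adds. A standalone host model of the new mapping, assuming mfma_f32_32x32x1xf32 numbers (wave_size = 64, num_threads_blk = 32, group_size = 4, m = n = 32); an illustration only, not the device function:

```cpp
#include <cstdio>

constexpr int wave_size = 64, num_threads_blk = 32, group_size = 4;
constexpr int m = 32, n = 32;

struct CIndex { int row, col; };

// models the simplified GetBeginOfThreadBlk(xdlops_i, blk_i)
CIndex begin_of_thread_blk(int lane_id, int xdlops_i, int blk_i)
{
    const int blk_id = lane_id / num_threads_blk; // row-group set within the blk
    const int blk_td = lane_id % num_threads_blk; // column within the blk

    return {xdlops_i * m + blk_id * group_size, // m_offset
            blk_i * n + blk_td};                // n_offset
}

int main()
{
    const CIndex idx = begin_of_thread_blk(/*lane_id=*/37, /*xdlops_i=*/1, /*blk_i=*/0);
    std::printf("row %d, col %d\n", idx.row, idx.col); // row 36, col 5
    return 0;
}
```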
@@ -834,8 +831,8 @@ struct XdlopsGemm
     static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
     static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
     static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
 
     static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
-    static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
+    // static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
 
     static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
 
     static constexpr auto GetBlkId(const index_t lane_id)
@@ -850,7 +847,7 @@ struct XdlopsGemm
     static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
 
-    struct OutputLayout
+    struct CLayout
     {
         __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
         __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
@@ -859,13 +856,16 @@ struct XdlopsGemm
         __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
 
-        __device__ static constexpr index_t GetNumBlks()
+        __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
+
+        __device__ static constexpr index_t GetNumXdlops()
         {
-            return GetNumBlksPerXdlops() * MRepeats * NRepeats;
+            return MPerXdlops * NPerXdlops /
+                   (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
         }
     };
 
-    __host__ __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
+    __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
 };
 
 } // namespace ck
@@ -198,7 +198,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
     ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
 
-template <index_t MPerWave, index_t NPerWave>
+template <index_t MPerWave, index_t NPerWave, index_t COffset>
 struct intrin_mfma_f32_32x32x1f32;
 
 // template <index_t AStride, index_t BStride>
@@ -237,16 +237,28 @@ struct intrin_mfma_f32_32x32x1f32;
 //}
 //};
 
-template <>
-struct intrin_mfma_f32_32x32x1f32<64, 64>
+template <index_t COffset>
+struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
 {
     template <class FloatA, class FloatB, class FloatC>
     __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
-        reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
+        reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+                reg_a,
+                reg_b,
+                reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
+                1,
+                0,
+                0);
+        reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+                reg_a,
+                reg_b,
+                reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
+                1,
+                1,
+                0);
     }
 };
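For reference, the two intrinsic calls in the 64x64 specialization differ only in the abid block-select modifier; cbsz = 1 broadcasts the A operand across two blocks. A minimal HIP sketch of that call shape, assuming clang's `__builtin_amdgcn_mfma_f32_32x32x1f32` builtin for gfx908 (the wrapper name and plain vector accumulators are illustrative, not the library's API):

```cpp
#include <hip/hip_runtime.h>

typedef float float32_t __attribute__((ext_vector_type(32)));

// Illustrative wrapper only: cbsz = 1 broadcasts the A operand over 2 blocks,
// abid selects which block's A feeds each accumulator half, blgp = 0 leaves B as-is.
__device__ void mfma_64x64_step(float a, float b, float32_t& acc0, float32_t& acc1)
{
    acc0 = __builtin_amdgcn_mfma_f32_32x32x1f32(a, b, acc0, /*cbsz=*/1, /*abid=*/0, /*blgp=*/0);
    acc1 = __builtin_amdgcn_mfma_f32_32x32x1f32(a, b, acc1, /*cbsz=*/1, /*abid=*/1, /*blgp=*/0);
}
```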
@@ -272,17 +284,23 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
 //}
 //};
 
-template <index_t MPerWave, index_t NPerWave>
+template <index_t MPerWave, index_t NPerWave, index_t COffset>
 struct intrin_mfma_f32_32x32x2f32;
 
-template <>
-struct intrin_mfma_f32_32x32x2f32<32, 32>
+template <index_t COffset>
+struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
 {
     template <class FloatA, class FloatB, class FloatC>
     __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
+            llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
+                reg_a,
+                reg_b,
+                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
+                0,
+                0,
+                0);
     }
 };
@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
 #else
     constexpr index_t BlockSize = 256;
 
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmMPerBlock = 256;
+    constexpr index_t GemmNPerBlock = 256;
     constexpr index_t GemmKPerBlock = 16;
 
-    constexpr index_t GemmMPerWave = 64;
-    constexpr index_t GemmNPerWave = 64;
+    constexpr index_t GemmMPerWave = 32;
+    constexpr index_t GemmNPerWave = 32;
     constexpr index_t GemmKPerWave = 4;
 
-    constexpr index_t MRepeat = 1;
-    constexpr index_t NRepeat = 1;
+    constexpr index_t MRepeat = 4;
+    constexpr index_t NRepeat = 4;
 
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 4>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
 
     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
 
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
 
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
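A quick consistency check on the new tuning point (a sketch assuming WaveSize = 64): each wave now computes a 128x128 tile (32*4 by 32*4), so the four waves of a 256-thread block tile the 256x256 macro-tile 2x2.

```cpp
// Values from this diff; WaveSize = 64 is an assumption (gfx908 wavefront).
constexpr int BlockSize = 256, WaveSize = 64;
constexpr int GemmMPerBlock = 256, GemmNPerBlock = 256;
constexpr int GemmMPerWave = 32, GemmNPerWave = 32;
constexpr int MRepeat = 4, NRepeat = 4;

constexpr int MWaves = GemmMPerBlock / (GemmMPerWave * MRepeat); // 2
constexpr int NWaves = GemmNPerBlock / (GemmNPerWave * NRepeat); // 2

static_assert(MWaves * NWaves == BlockSize / WaveSize, "2x2 waves tile the block");
```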