Commit b3cc22a3 authored by aska-0096

tempsave

parent d16063db
...@@ -30,12 +30,14 @@ template <index_t BlockSize, ...@@ -30,12 +30,14 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack>
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 // MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
...@@ -85,8 +87,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -85,8 +87,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
const auto waveId_m = wave_idx[I0]; const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|Mwave |MLane |KPack
return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]); return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
} }
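GetWaveIdx is not shown in this hunk; as a hedged illustration, a flat thread id is typically split into a wave id and a lane, and the wave id into (waveId_m, waveId_n). The wave size of 32 and the 2x2 wave grid below are assumptions for illustration, not values taken from this commit:

#include <cstdio>

int main()
{
    // assumed illustration values: a 2x2 grid of 32-lane waves -> 128 threads per block
    constexpr int WaveSize = 32;
    constexpr int NWaves   = 2;

    const int tids[] = {0, 31, 32, 95, 127};
    for(int tid : tids)
    {
        const int wave_id  = tid / WaveSize;   // which wave inside the block
        const int lane     = tid % WaveSize;   // lane inside the wave
        const int waveId_m = wave_id / NWaves; // row of the wave grid
        const int waveId_n = wave_id % NWaves; // column of the wave grid
        std::printf("tid %3d -> waveId_m %d, waveId_n %d, lane %2d\n",
                    tid, waveId_m, waveId_n, lane);
    }
    return 0;
}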
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
...@@ -96,20 +98,20 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -96,20 +98,20 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
const auto waveId_n = wave_idx[I1]; const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]); return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
} }
template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i> template <index_t m0, index_t n0>
__device__ static auto __device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>) CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{ {
const auto wave_idx = GetWaveIdx(); const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0]; const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1]; const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i); const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
...@@ -129,27 +131,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -129,27 +131,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
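The single-stage unmerge adaptor above boils down to nested multiply-add when flattening (MRepeat, MWaves, MPerWMMA) back into the block-local row index m. A standalone sketch; the extents are illustration values only, not taken from this commit:

// flatten (m_repeat, m_wave, m_in_wmma) back to the block-local row index m,
// mirroring make_unmerge_transform(MRepeat, MWaves, MPerWMMA)
constexpr int flatten_m(int m_repeat, int m_wave, int m_in_wmma, int MWaves, int MPerWMMA)
{
    return (m_repeat * MWaves + m_wave) * MPerWMMA + m_in_wmma;
}

int main()
{
    // assumed illustration values
    constexpr int MRepeat = 2, MWaves = 2, MPerWMMA = 16;
    static_assert(flatten_m(0, 0, 0, MWaves, MPerWMMA) == 0, "origin");
    static_assert(flatten_m(0, 1, 3, MWaves, MPerWMMA) == 19, "1*16 + 3");
    static_assert(flatten_m(1, 0, 0, MWaves, MPerWMMA) == 32, "one full MRepeat step");
    static_assert(flatten_m(MRepeat - 1, MWaves - 1, MPerWMMA - 1, MWaves, MPerWMMA)
                      == MRepeat * MWaves * MPerWMMA - 1,
                  "last element maps to last row");
    return 0;
}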
template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3()
{ {
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
...@@ -162,59 +143,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -162,59 +143,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0, static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!"); "wrong!");
} }
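A worked instance of the divisibility checks above; the tile parameters are assumed for illustration and are not taken from this commit:

int main()
{
    // assumed illustration values
    constexpr int MPerBlock = 128, MPerWMMA = 16, MRepeat = 4;
    constexpr int NPerBlock = 64,  NPerWMMA = 16, NRepeat = 2;

    static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0, "M tile must split evenly");
    static_assert(NPerBlock % (NPerWMMA * NRepeat) == 0, "N tile must split evenly");

    // the remaining factor is the number of waves along each dimension
    constexpr int MWaves = MPerBlock / (MPerWMMA * MRepeat); // 128 / 64 = 2
    constexpr int NWaves = NPerBlock / (NPerWMMA * NRepeat); // 64 / 32 = 2
    static_assert(MWaves == 2 && NWaves == 2, "two waves per dimension in this example");
    return 0;
}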
// Thread level, register descriptor.
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{ {
constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths(); constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed( return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N)); // |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, MSubGroup, Number<NRepeat>{}, I1, NThreadPerSubGroup, MAccVgprs));
} }
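One way to read the new dimension names, stated here as an assumption rather than something this commit documents: on a 32-lane wave a 16x16 WMMA accumulator tile holds 256 values, so each lane keeps MAccVgprs = 8 of them, the wave splits into MSubGroup = 2 groups of NThreadPerSubGroup = 16 lanes, and MSubGroup * MAccVgprs covers the 16 rows of the tile. A consistency check of that reading:

int main()
{
    // assumed illustration values for a 16x16x16 WMMA on a 32-lane wave
    constexpr int MPerWMMA = 16, NPerWMMA = 16, WaveSize = 32;
    constexpr int MSubGroup = 2, NThreadPerSubGroup = 16, MAccVgprs = 8;

    static_assert(MSubGroup * NThreadPerSubGroup == WaveSize, "lane decomposition");
    static_assert(MSubGroup * MAccVgprs == MPerWMMA, "rows covered per lane group");
    static_assert(NThreadPerSubGroup == NPerWMMA, "one lane per output column");
    static_assert(MPerWMMA * NPerWMMA == WaveSize * MAccVgprs, "256 accumulators per wave");
    return 0;
}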
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{ {
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{}, Number<MWaves>{},
Number<NWaves>{},
Number<MPerWMMA>{}, Number<MPerWMMA>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{}, Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{}, Number<NWaves>{},
Number<MPerWMMA>{},
Number<NPerWMMA>{})); Number<NPerWMMA>{}));
return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
c_block_desc_g_m0_n0_m1_n1_m2_n2);
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
...@@ -234,32 +187,46 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -234,32 +187,46 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
} }
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() __host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack()
{ {
return transform_tensor_descriptor( constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor(
AK0MK1BlockDesc{}, AK0MK1BlockDesc{},
make_tuple( make_tuple(
make_pass_through_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})), make_merge_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform( make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
a_block_desc_temp_km0m1m2,
make_tuple(
make_unmerge_transform(make_tuple(Number<A_K0*A_K1/KPack>{}, Number<KPack>{})),
make_pass_through_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
} }
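The two-stage transform above first merges (A_K0, A_K1) into a single K coordinate and unmerges M into (MRepeat, MWaves, MPerWMMA), then splits K into (K / KPack, K % KPack). A standalone sketch of the resulting index algebra, with assumed extents:

#include <cassert>
#include <cstdio>

int main()
{
    // assumed illustration values
    constexpr int A_K1 = 8, KPack = 16;
    constexpr int MWaves = 2, MPerWMMA = 16;

    // pick one source coordinate in the (K0, M, K1) block layout
    const int k0 = 5, m = 37, k1 = 3;

    // stage 1: merge K0 and K1, unmerge M
    const int k  = k0 * A_K1 + k1;          // 5*8 + 3 = 43
    const int m0 = m / (MWaves * MPerWMMA); // MRepeat index
    const int m1 = (m / MPerWMMA) % MWaves; // wave index
    const int m2 = m % MPerWMMA;            // row inside the WMMA tile

    // stage 2: split merged K into (KRepeat, KPack)
    const int krepeat = k / KPack;          // 43 / 16 = 2
    const int kpack   = k % KPack;          // 43 % 16 = 11

    assert(krepeat * KPack + kpack == k0 * A_K1 + k1);
    assert((m0 * MWaves + m1) * MPerWMMA + m2 == m);
    std::printf("(k0=%d,m=%d,k1=%d) -> (krepeat=%d, m0=%d, m1=%d, m2=%d, kpack=%d)\n",
                k0, m, k1, krepeat, m0, m1, m2, kpack);
    return 0;
}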
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() __host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack()
{ {
return transform_tensor_descriptor( constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor(
BK0NK1BlockDesc{}, BK0NK1BlockDesc{},
make_tuple( make_tuple(
make_pass_through_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})), make_merge_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform( make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
b_block_desc_temp_kn0n1n2,
make_tuple(
make_unmerge_transform(make_tuple(Number<B_K0*B_K1/KPack>{}, Number<KPack>{})),
make_pass_through_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
} }
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf, __device__ void Run(const ABlockBuffer& a_block_buf,
...@@ -298,7 +265,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -298,7 +265,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
b_thread_vec.template AsType<wmma_input_type>(), b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0), make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
...@@ -328,7 +295,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -328,7 +295,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
b_thread_vec.template AsType<wmma_input_type>(), b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop+RepeatDiff, I0, I0, I0), make_tuple(Number<iWmmaK>{}, WmmaInnerloop+RepeatDiff, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
...@@ -355,7 +322,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -355,7 +322,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
b_thread_vec.template AsType<wmma_input_type>(), b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0), make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
...@@ -380,7 +347,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -380,7 +347,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_block_desc_krepeat_m0_m1_m2_kpack),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, WmmaK>, Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
...@@ -390,7 +357,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 ...@@ -390,7 +357,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_block_desc_krepeat_n0_n1_n2_kpack),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, WmmaK>, Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
...@@ -413,11 +380,11 @@ template <index_t BlockSize, ...@@ -413,11 +380,11 @@ template <index_t BlockSize,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
LoopScheduler LoopSched> LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector() constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_Selector()
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
{ {
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2<BlockSize, return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
AK0MK1BlockDesc, AK0MK1BlockDesc,
......
...@@ -72,9 +72,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -72,9 +72,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
static constexpr auto M1Number = Number<M1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{ {
...@@ -87,10 +86,12 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -87,10 +86,12 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{ {
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
} }
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
} }
#endif
}(); }();
if constexpr(GemmSpec == GemmSpecialization::MNPadding) if constexpr(GemmSpec == GemmSpecialization::MNPadding)
...@@ -154,12 +155,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -154,12 +155,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
} }
} }
static auto MakeCGridDescriptor_M0_N_M1(index_t M, index_t N, index_t StrideC) static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{ {
assert(M % M1 == 0);
const index_t M0 = M / M1;
const auto c_grid_desc_m_n = [&]() { const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
...@@ -173,8 +170,6 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -173,8 +170,6 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
if constexpr(GemmSpec == GemmSpecialization::MNPadding) if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{ {
static_assert(false, "Padding Gemm Not implemented");
/* Not implemented yet.
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
...@@ -183,26 +178,25 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -183,26 +178,25 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
*/
} }
else else
{ {
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M1Number)), make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
} }
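For reference, the MNPadding branch's right-pad arithmetic rounds each extent up to the next tile multiple; a small sketch with assumed problem and tile sizes:

#include <cstdio>

constexpr int right_pad(int len, int tile) { return (tile - len % tile) % tile; }

int main()
{
    // assumed illustration values
    constexpr int MPerBlock = 128, NPerBlock = 128;
    constexpr int M = 1000, N = 2048;

    constexpr int PadM = right_pad(M, MPerBlock); // (128 - 1000 % 128) % 128 = 24
    constexpr int PadN = right_pad(N, NPerBlock); // N is already a multiple -> 0

    static_assert((M + PadM) % MPerBlock == 0, "padded M must be tile aligned");
    static_assert(PadN == 0, "no padding needed when already aligned");
    std::printf("M %d -> %d, N %d -> %d\n", M, M + PadM, N, N + PadN);
    return 0;
}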
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M0_N_M1 = decltype(MakeCGridDescriptor_M0_N_M1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
...@@ -210,7 +204,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -210,7 +204,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M0_N_M1, CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -238,15 +232,16 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -238,15 +232,16 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
#endif
NumPrefetch, NumPrefetch,
LoopSched,
PipelineVer>; PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
...@@ -267,8 +262,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -267,8 +262,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n_k1_{},
c_grid_desc_m0_n_m1_{}, c_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
...@@ -278,18 +273,18 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -278,18 +273,18 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m0_n_m1_ = DeviceGemmWmma::MakeCGridDescriptor_M0_N_M1(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmWmma::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m0_n_m1_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_, b_grid_desc_k0_n_k1_,
c_grid_desc_m0_n_m1_, c_grid_desc_m_n_,
block_2_ctile_map_)) block_2_ctile_map_))
{ {
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n_m1_); GridwiseGemm::MakeCGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow(c_grid_desc_m_n_);
} }
} }
...@@ -299,9 +294,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -299,9 +294,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
CDataType* p_c_grid_; CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M0_N_M1 c_grid_desc_m0_n_m1_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -327,15 +322,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -327,15 +322,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m0_n_m1_{ " << arg.c_grid_desc_m0_n_m1_.GetLength(I0) std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m0_n_m1_.GetLength(I1) << ", " << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m0_n_m1_.GetLength(I2) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n_m1_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_ctile_map_))
{ {
throw std::runtime_error( throw std::runtime_error(
...@@ -343,7 +338,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -343,7 +338,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m0_n_m1_); arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K = const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
...@@ -358,7 +353,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -358,7 +353,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -375,7 +370,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -375,7 +370,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -389,7 +384,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -389,7 +384,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -406,7 +401,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -406,7 +401,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -447,7 +442,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -447,7 +442,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n_m1_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#define DISABLE_C_SHUFFLE
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -23,11 +22,7 @@ template <typename GridwiseGemm, ...@@ -23,11 +22,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma, typename CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs,
#ifndef DISABLE_C_SHUFFLE
typename C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma,
typename C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma,
#endif
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
...@@ -41,18 +36,10 @@ __global__ void ...@@ -41,18 +36,10 @@ __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_grid,
const FloatC* __restrict__ p_c1_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
#ifndef DISABLE_C_SHUFFLE
const C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
const C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
#endif
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
...@@ -65,16 +52,10 @@ __global__ void ...@@ -65,16 +52,10 @@ __global__ void
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid,
p_c1_grid,
p_shared, p_shared,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
#ifndef DISABLE_C_SHUFFLE
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
#endif
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
...@@ -83,13 +64,9 @@ __global__ void ...@@ -83,13 +64,9 @@ __global__ void
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_c0_grid;
ignore = p_c1_grid;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; ignore = c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs;
ignore = c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma;
ignore = c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
...@@ -106,8 +83,6 @@ template < ...@@ -106,8 +83,6 @@ template <
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
...@@ -135,10 +110,9 @@ template < ...@@ -135,10 +110,9 @@ template <
index_t BBlockTransferDstScalarPerVector_K1, index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t CShuffleMWmmaPerWavePerShuffle, typename CThreadTransferSrcDstAccessOrder,
index_t CShuffleNWmmaPerWavePerShuffle, index_t CThreadTransferSrcDstVectorDim,
typename CBlockTransferClusterLengths_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma, index_t CThreadTransferDstScalarPerVector,
index_t CBlockTransferScalarPerVector_NWaveNPerWmma,
index_t NumGemmKPrefetchStage = 1, index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
...@@ -160,84 +134,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -160,84 +134,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
constexpr auto inst_max_size = 16 / sizeof(FloatAB); constexpr auto max_lds_align = K1;
constexpr auto k1perinst = (K1 <inst_max_size) ? K1 : inst_max_size;
constexpr auto K10 = K1 / k1perinst;
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_k10_m_k1perinst = [&]() { constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
// May have static err if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, K10, Number<MPerBlock>{}, k1perinst), k1perinst); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}(); }();
return a_block_desc_k0_k10_m_k1perinst; return a_block_desc_k0perblock_mperblock_k1;
} }
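The ABlockLdsExtraM branch gives each K0 slice a row stride of (MPerBlock + 1) * K1 instead of MPerBlock * K1; the extra K1 vector per slice is commonly used to stagger accesses across LDS banks (that rationale is an assumption, not stated in the commit). A minimal offset comparison with assumed sizes:

#include <cstdio>

// assumed illustration values
constexpr int K0PerBlock = 4, MPerBlock = 128, K1 = 8;

// offset of element (k0, m, k1) in the packed layout vs. the padded layout
constexpr int packed_offset(int k0, int m, int k1) { return (k0 * MPerBlock + m) * K1 + k1; }
constexpr int padded_offset(int k0, int m, int k1) { return k0 * (MPerBlock + 1) * K1 + m * K1 + k1; }

int main()
{
    std::printf("element (1, 0, 0): packed %d, padded %d\n",
                packed_offset(1, 0, 0), padded_offset(1, 0, 0)); // 1024 vs 1032
    std::printf("padded tile: %d elements (packed: %d)\n",
                K0PerBlock * (MPerBlock + 1) * K1, K0PerBlock * MPerBlock * K1);
    return 0;
}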
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst() __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{ {
constexpr auto inst_max_size = 16 / sizeof(FloatAB); constexpr auto max_lds_align = K1;
constexpr auto k1perinst = (K1 <inst_max_size) ? K1 : inst_max_size;
constexpr auto K10 = K1 / k1perinst;
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_k10_n_k1perinst = [&]() { constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
return make_naive_tensor_descriptor_aligned( if constexpr(BBlockLdsExtraN)
make_tuple(Number<K0PerBlock>{}, K10, Number<NPerBlock>{}, k1perinst), k1perinst); {
}(); return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
return b_block_desc_k0_k10_n_k1perinst; make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
} }
else
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma()
{ {
constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); return make_naive_tensor_descriptor_aligned(
constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto return b_block_desc_k0perblock_nperblock_k1;
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMWmmaPerWavePerShuffle>{},
Number<MWave * MPerWmma>{},
I1,
Number<CShuffleNWmmaPerWavePerShuffle>{},
Number<NWave * NPerWmma>{}));
return c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma;
} }
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst(); constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst(); constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3); constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size = 0; return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
#ifndef DISABLE_C_SHUFFLE
// LDS allocation for C shuffle in LDS
constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma =
GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma();
constexpr auto c_block_size =
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetElementSpaceSize();
#endif
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatC));
} }
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
...@@ -293,7 +249,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -293,7 +249,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
template <typename CGridDesc_M_N_> template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma( MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N_& c_grid_desc_m_n) const CGridDesc_M_N_& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
...@@ -305,17 +261,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -305,17 +261,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma);
constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma);
const auto c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = constexpr index_t MLaneHigh = 2;
constexpr index_t MLaneLow = MPerWmma / MLaneHigh;
constexpr index_t NLane = NPerWmma;
const auto c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
transform_tensor_descriptor( transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple( make_tuple(make_unmerge_transform(make_tuple(
MBlock, Number<MWmmaPerWave>{}, Number<MWave * MPerWmma>{})), MBlock, Number<MWmmaPerWave>{}, Number<MWave>{}, Number<MLaneHigh>{}, Number<MLaneLow>{})),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
NBlock, Number<NWmmaPerWave>{}, Number<NWave * NPerWmma>{}))), NBlock, Number<NWmmaPerWave>{}, Number<NWave>{}, Number<NLane>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); make_tuple(Sequence<0, 1, 2, 3, 8>{}, Sequence<4, 5, 6, 7>{}));
return c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; return c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs;
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
...@@ -325,21 +285,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -325,21 +285,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = using CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs =
remove_cvref_t<decltype( remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma( MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
CGridDesc_M_N{}))>; CGridDesc_M_N{}))>;
#ifndef DISABLE_C_SHUFFLE
using C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(
C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(
C1GridDesc_M_N{}))>;
#endif
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
...@@ -348,74 +297,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -348,74 +297,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_grid,
const FloatC* __restrict__ p_c1_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs&
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
#ifndef DISABLE_C_SHUFFLE
const C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma&
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
const C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma&
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
#endif
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, p_c_grid, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize());
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetElementSpaceSize());
#ifndef DISABLE_C_SHUFFLE
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid,
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid,
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetElementSpaceSize());
#endif
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex( if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx, block_work_idx,
make_tuple( make_tuple(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0),
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4))))
.GetLength(I0), { return; }
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst();
// B matrix in LDS memory, dst of blockwise copy // Store BlockId into SGPR
constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst(); const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment /*******************************************************************************/
constexpr auto max_lds_align = a_block_desc_k0_m_k10_k11.GetLength(I3); // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
...@@ -429,7 +349,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -429,7 +349,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* typename SrcData, */ FloatAB, /* typename SrcData, */ FloatAB,
/* typename DstData, */ FloatAB, /* typename DstData, */ FloatAB,
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0_k10_m_k1perinst), /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, /* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
...@@ -443,7 +363,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -443,7 +363,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op, a_element_op,
a_block_desc_k0_k10_m_k1perinst, a_block_desc_k0perblock_mperblock_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
...@@ -459,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -459,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_grid_desc_k0_n_k1), decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_k10_n_k1perinst), decltype(b_block_desc_k0perblock_nperblock_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
...@@ -473,43 +393,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -473,43 +393,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op, b_element_op,
b_block_desc_k0_k10_n_k1perinst, b_block_desc_k0perblock_nperblock_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
/*******************************************************************************/
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += a_mtx * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS // a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS // b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
// register
// sanity check constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
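integer_least_multiple(K1, WmmaK) is used here to round K1 up to a WmmaK multiple; the helper below re-implements that contract purely for illustration, assuming ck::math::integer_least_multiple returns the smallest multiple of its second argument that is >= the first:

#include <cstdio>

// assumed contract of ck::math::integer_least_multiple
constexpr int integer_least_multiple(int x, int align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    constexpr int WmmaK = 16;
    static_assert(integer_least_multiple(8, WmmaK) == 16, "K1 = 8  -> KPack = 16");
    static_assert(integer_least_multiple(16, WmmaK) == 16, "K1 = 16 -> KPack = 16");
    static_assert(integer_least_multiple(24, WmmaK) == 32, "K1 = 24 -> KPack = 32");
    std::printf("KPack(K1=8) = %d\n", integer_least_multiple(8, WmmaK));
    return 0;
}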
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmWmmaops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_k10_m_k1perinst), decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0_k10_n_k1perinst), decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma, MPerWmma,
NPerWmma, NPerWmma,
MWmmaPerWave, MWmmaPerWave,
NWmmaPerWave, NWmmaPerWave,
K1>{}; KPack>{};
// Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
/*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatAB*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align); auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize());
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
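A and B share the single p_shared allocation: A at offset 0, B immediately after the aligned A region, which is what GetSharedMemoryNumberOfByte sizes for. A host-side stand-in (sizes assumed, matching the earlier LDS example):

#include <cstdio>
#include <vector>

int main()
{
    using FloatAB = unsigned short; // stand-in for fp16, 2 bytes

    // assumed illustration values
    constexpr int a_elems_aligned = 4096;
    constexpr int b_elems = 4096;

    std::vector<FloatAB> p_shared(a_elems_aligned + b_elems); // stand-in for the LDS pool

    FloatAB* a_block = p_shared.data();                   // A tile at offset 0
    FloatAB* b_block = p_shared.data() + a_elems_aligned; // B tile right after the aligned A region

    std::printf("A at byte offset %zu, B at byte offset %zu\n",
                (a_block - p_shared.data()) * sizeof(FloatAB),
                (b_block - p_shared.data()) * sizeof(FloatAB));
    return 0;
}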
...@@ -517,13 +436,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -517,13 +436,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1, GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0_k10_m_k1perinst, a_block_desc_k0perblock_mperblock_k1,
a_blockwise_copy, a_blockwise_copy,
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
b_block_desc_k0_k10_n_k1perinst, b_block_desc_k0perblock_nperblock_k1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
...@@ -531,270 +450,79 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -531,270 +450,79 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
K0BlockMainLoop); K0BlockMainLoop);
#ifndef DISABLE_C_SHUFFLE // NO C-shuffle, direct write
// shuffle C and write out
{ {
static_assert(MWmmaPerWave % CShuffleMWmmaPerWavePerShuffle == 0 && constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
NWmmaPerWave % CShuffleNWmmaPerWavePerShuffle == 0, blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
"wrong!"); constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma);
constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); constexpr auto MRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0);
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
// TODO: hacky, fix it! constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto NRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I3);
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); constexpr auto Nwave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
// TODO: hacky, fix it! constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = // Mapping
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma =
GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_pass_through_transform(
Number<CShuffleMWmmaPerWavePerShuffle>{}), // M0 (MWmmaPerWave) per
// shuffle
make_unmerge_transform(
make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerWmma
make_freeze_transform(I0), // freeze nblock
make_pass_through_transform(
Number<CShuffleNWmmaPerWavePerShuffle>{}), // N0 (NWmmaPerWave) per
// shuffle
make_unmerge_transform(
make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerWmma
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<2, 4, 5, 6>{},
Sequence<>{},
Sequence<1>{},
Sequence<3, 7>{})
);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}), make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx = const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx = const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(m_thread_data_on_grid));
make_multi_index(n_thread_data_on_block));
const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup.CalculateBottomIndex(
// VGPR to LDS make_multi_index(n_thread_data_on_grid));
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC, auto c_thread_copy =
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), ThreadwiseTensorSliceTransfer_v1r3<
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), /* typename SrcData */ FloatAcc,
ck::tensor_operation::element_wise::PassThrough, /* typename DstData */ FloatC,
Sequence<CShuffleMWmmaPerWavePerShuffle, /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
CShuffleNWmmaPerWavePerShuffle, /* typename DstDesc */ decltype(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
I1, /* typename ElementwiseOperation */ CElementwiseOperation,
I1, /* typename SliceLengths */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
M2, /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder,
I1, /* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim,
M4, /* index_t DstScalarPerVector */ CThreadTransferDstScalarPerVector,
I1>, /* InMemoryDataOperationEnum DstInMemOp */ CGlobalMemoryDataOperation,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, /* index_t DstScalarStrideInVector */ 1,
7, /* bool DstResetCoordinateAfterRun */ true>
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMWmmaPerWavePerShuffle,
MWave * MPerWmma,
1,
CShuffleNWmmaPerWavePerShuffle,
NWave * NPerWmma>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma,
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
FloatC, // typename Src0Data,
FloatC, // typename Src1Data,
FloatC, // typename Src2Data,
FloatC, // typename DstData,
decltype(
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma),
decltype(
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma),
decltype(
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma),
decltype(
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerWmma, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
make_multi_index(0, 0, 0, 0, 0, 0),
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
c_element_op};
constexpr auto mwmmaperwave_forward_step =
make_multi_index(0, CShuffleMWmmaPerWavePerShuffle, 0, 0, 0, 0);
constexpr auto nwmmaperwave_forward_step =
make_multi_index(0, 0, 0, 0, CShuffleNWmmaPerWavePerShuffle, 0);
constexpr auto nwmmaperwave_backward_step =
make_multi_index(0, 0, 0, 0, -CShuffleNWmmaPerWavePerShuffle, 0);
static_for<0, MWmmaPerWave, CShuffleMWmmaPerWavePerShuffle>{}([&](auto mwmmaperwave_iter) {
constexpr auto mwmmaperwave = mwmmaperwave_iter;
static_for<0,
NWmmaPerWave,
CShuffleNWmmaPerWavePerShuffle>{}([&](auto nwmmaperwave_iter) {
constexpr bool nwmmaperwave_forward_sweep =
(mwmmaperwave % (2 * CShuffleMWmmaPerWavePerShuffle) == 0);
constexpr index_t nwmmaperwave_value =
nwmmaperwave_forward_sweep
? nwmmaperwave_iter
: (NWmmaPerWave - nwmmaperwave_iter - CShuffleNWmmaPerWavePerShuffle);
constexpr auto nwmmaperwave = Number<nwmmaperwave_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(mwmmaperwave, nwmmaperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
// make sure it's safe to do ds_read
block_sync_lds();
// LDS to global
c_block_copy_lds_to_global.Run(
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
c_block_buf,
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
c0_grid_buf,
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
c1_grid_buf,
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
c_grid_buf);
// move on nwmmaperwave dimension
if constexpr(nwmmaperwave_forward_sweep &&
(nwmmaperwave < NWmmaPerWave - CShuffleNWmmaPerWavePerShuffle))
{
c_block_copy_lds_to_global.MoveSrc1SliceWindow(
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
nwmmaperwave_forward_step);
c_block_copy_lds_to_global.MoveSrc2SliceWindow(
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
nwmmaperwave_forward_step);
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
nwmmaperwave_forward_step);
}
else if constexpr((!nwmmaperwave_forward_sweep) && (nwmmaperwave > 0))
{
c_block_copy_lds_to_global.MoveSrc1SliceWindow(
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
nwmmaperwave_backward_step);
c_block_copy_lds_to_global.MoveSrc2SliceWindow(
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
nwmmaperwave_backward_step);
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma,
nwmmaperwave_backward_step);
}
});
// move on mwmmaperwave dimension
if constexpr(mwmmaperwave < MWmmaPerWave - CShuffleMWmmaPerWavePerShuffle)
{ {
c_block_copy_lds_to_global.MoveSrc1SliceWindow( /* dst_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
mwmmaperwave_forward_step); m_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
c_block_copy_lds_to_global.MoveSrc2SliceWindow( n_thread_data_on_grid_idx[I0],
c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, n_thread_data_on_grid_idx[I1],
mwmmaperwave_forward_step); n_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3]),
c_block_copy_lds_to_global.MoveDstSliceWindow( /* element_op */ c_element_op
c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, };
mwmmaperwave_forward_step);
c_thread_copy.Run(
/* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_start point */ make_tuple(I0, I0, I0, I0, I0, I0, I0),
/* c_buffer */ c_thread_buf,
/* c_grid_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_grid_buf */ c_grid_buf);
} }
}); // clang-format on
}
#endif
} }
}; };
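In the direct-write path above, each thread hands ThreadwiseTensorSliceTransfer_v1r3 a slice of shape <MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>. For instance, assuming a tuning with MRepeat = NRepeat = 4 and the wave32 fp16 instruction (MAccVgprs = 8), that is 4 * 4 * 8 = 128 accumulator values copied from VGPRs straight to the C grid, with no LDS shuffle round-trip.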
......
...@@ -11,35 +11,107 @@ namespace ck { ...@@ -11,35 +11,107 @@ namespace ck {
enum struct WmmaInstr enum struct WmmaInstr
{ {
wmma_f32_16x16x16_f16_w32 = 0, wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16_w32 = 0, wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16_w32 = 0, wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16_w32 = 0, wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8_w32 = 0, wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4_w32 = 0 wmma_i32_16x16x16_iu4
}; };
template <WmmaInstr instr> /*
* WMMA Wave Tile Always MxNxK = 16x16x16
* WAVE32
-----------------------------------
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|RC1| | | | | | | | | | | | | | | |
|RC2| | | | | | | | | | | | | | | |
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC6| | | | | | | | | | | | | | | |
|RC7| | | | | | | | | | | | | | | |
-----------------------------------
| | | | | | | | | | | | | | | | | SubGroup 1
| | | | | | | | | | | | | | | | |
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
-----------------------------------
* WAVE64
-----------------------------------
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
| | | | | | | | | | | | | | | | |
-----------------------------------
* RC = Register for storing accumulated result
* T = Thread ID
*/
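The mapping sketched above can be checked off-device. Below is a minimal host-side sketch (illustrative only, not library code; the loop and printed format are assumptions) that mirrors the subgroup and lane arithmetic used by wmma_type further down:

// Host-side check of the subgroup / lane layout drawn in the diagrams above.
#include <cstdio>

int main()
{
    for(int wave_size : {32, 64})
    {
        const int num_thread_per_subgroups = 16;                  // lanes along the N direction
        const int num_subgroups            = wave_size / 16;      // subgroups stacked along M
        const int num_acc_vgprs_per_wave   = 16 * 16 / wave_size; // C rows (RC0..) held per lane

        for(int lane = 0; lane < wave_size; lane += 13)
        {
            const int subgroup = (lane / num_thread_per_subgroups) % num_subgroups;
            const int n        = lane % num_thread_per_subgroups;
            const int m_begin  = subgroup * num_acc_vgprs_per_wave; // first C row of this lane
            std::printf("wave%-2d lane %2d -> subgroup %d, C origin (m, n) = (%2d, %2d)\n",
                        wave_size, lane, subgroup, m_begin, n);
        }
    }
    return 0;
}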
template <WmmaInstr Instr,
index_t WaveSize,
typename enable_if<WaveSize == 32 || WaveSize == 64, bool>::type = false>
struct wmma_type; struct wmma_type;
template <> // A-swizzled
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32> template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, WaveSize>
{ {
// Properties independent of wave mode
// * Data pixels
static constexpr index_t m_per_wmma = 16; static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16; static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16; static constexpr index_t k_per_wmma = 16;
static constexpr index_t wave_size = 32; static constexpr index_t src_a_data_size = 2;
static constexpr index_t lane_size = 16; static constexpr index_t src_b_data_size = 2;
static constexpr index_t src_data_size = 2;
static constexpr index_t acc_data_size = 4; static constexpr index_t acc_data_size = 4;
static constexpr index_t num_srcregs_per_wmma = 8; // * Thread mapping inside wave, num_thread_per_subgroups is always along the N direction
static constexpr index_t num_accregs_per_wmma = 8; static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent properties
static constexpr index_t wave_size = WaveSize;
// * Fixed in Navi3x, will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave is along the M direction
// * num_subgroups is along the M direction
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC> template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{ {
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c); intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
} }
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
}; };
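As a quick check of the derived constants above for the fp16-in / fp32-out case: on wave32, num_acc_vgprs_per_wave = 16 * 16 * 4 / 32 / 4 = 8 and num_subgroups = 32 / 16 = 2, matching rows RC0..RC7 and the two subgroups of the WAVE32 diagram; the same formulas on wave64 give 4 accumulator VGPRs and 4 subgroups, matching RC0..RC3.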
template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma> template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
...@@ -51,54 +123,54 @@ struct WmmaSelector ...@@ -51,54 +123,54 @@ struct WmmaSelector
template <> template <>
static constexpr auto GetWmma<half_t, float, 16, 16>() static constexpr auto GetWmma<half_t, float, 16, 16>()
{ {
return WmmaInstr::wmma_f32_16x16x16_f16_w32; return WmmaInstr::wmma_f32_16x16x16_f16;
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, float, 16, 16>() static constexpr auto GetWmma<bhalf_t, float, 16, 16>()
{ {
return WmmaInstr::wmma_f32_16x16x16_bf16_w32; return WmmaInstr::wmma_f32_16x16x16_bf16;
} }
template <> template <>
static constexpr auto GetWmma<half_t, half_t, 16, 16>() static constexpr auto GetWmma<half_t, half_t, 16, 16>()
{ {
return WmmaInstr::wmma_f16_16x16x16_f16_w32; return WmmaInstr::wmma_f16_16x16x16_f16;
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, 16, 16>() static constexpr auto GetWmma<bhalf_t, bhalf_t, 16, 16>()
{ {
return WmmaInstr::wmma_bf16_16x16x16_bf16_w32; return WmmaInstr::wmma_bf16_16x16x16_bf16;
} }
template <> template <>
static constexpr auto GetWmma<int8_t, float, 16, 16>() static constexpr auto GetWmma<int8_t, float, 16, 16>()
{ {
return WmmaInstr::wmma_i32_16x16x16_iu8_w32; return WmmaInstr::wmma_i32_16x16x16_iu8;
} }
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
static constexpr auto GetWmma<int4_t, float, 16, 16>() static constexpr auto GetWmma<int4_t, float, 16, 16>()
{ {
return WmmaInstr::wmma_i32_16x16x16_iu4_w32; return WmmaInstr::wmma_i32_16x16x16_iu4;
} }
#endif #endif
static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>()>{}; static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), get_warp_size()>{};
__host__ __device__ constexpr WmmaSelector() __host__ __device__ constexpr WmmaSelector()
{ {
static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma, static_assert(selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to WMMA_N"); "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma, static_assert(selected_wmma.n_per_wmma == 16,
"WRONG! WMMA_M must equal to WMMA_K"); "WRONG! WMMA_N must equal to 16");
static_assert(selected_wmma.k_per_wmma == 16, static_assert(selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to WMMA_N"); "WRONG! WMMA_K must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma * selected_wmma.acc_data_size == static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4, selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Number of Accumulator Register"); "WRONG! Number of Accumulator Register");
...@@ -135,26 +207,26 @@ struct WmmaGemm ...@@ -135,26 +207,26 @@ struct WmmaGemm
} }
// XDL output supporting C = A * B // XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2 // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_M0_N0_M1_N1_M2_N2> template <typename CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA& c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma)
{ {
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); const auto MWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); const auto NWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2, c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(M0), make_tuple(make_pass_through_transform(MRepeat),
make_pass_through_transform(N0), make_pass_through_transform(MWave),
make_pass_through_transform(M1), make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
make_pass_through_transform(N1), Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_groups_per_blk>{}, make_pass_through_transform(NRepeat),
Number<wmma_instr.num_input_blks>{}, make_pass_through_transform(NWave),
Number<wmma_instr.group_size>{})), make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_pass_through_transform(Number<wmma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -163,91 +235,22 @@ struct WmmaGemm ...@@ -163,91 +235,22 @@ struct WmmaGemm
Sequence<5>{}), Sequence<5>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2, 6>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
}
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(Number<wmma_instr.num_threads_per_blk>{}),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_groups_per_blk>{},
Number<wmma_instr.num_input_blks>{},
Number<wmma_instr.group_size>{}))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{}, Sequence<3>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}));
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{}));
} }
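Concretely, with the wave32 fp16 instruction the MPerWMMA = 16 output rows of one WMMA tile are unmerged by the transform above into (num_subgroups, num_acc_vgprs_per_wave) = (2, 8), while the NPerWMMA = 16 columns pass through as num_thread_per_subgroups = 16, yielding the 7-dimensional ..._MSubGroup_..._NThreadPerSubGroup_MAccVgprs descriptor consumed by the direct-write path.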
template <typename CDesc_G_M0_N0_M1_N1_M2_N2> __device__ static constexpr index_t GetRegSizePerWmma()
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
{ {
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); return wmma_instr.num_acc_vgprs_per_wave;
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
return transform_tensor_descriptor(
c_desc_g_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(G),
make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(wmma_instr.num_groups_per_blk,
wmma_instr.num_input_blks,
wmma_instr.group_size)),
make_pass_through_transform(wmma_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{},
Sequence<8>{}));
} }
__device__ static constexpr index_t GetRegSizePerXdlops() __device__ static constexpr index_t GetWaveSize()
{ {
return MPerWmma * NPerWmma / wmma_instr.wave_size; return wmma_instr.wave_size;
} }
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
...@@ -272,67 +275,50 @@ struct WmmaGemm ...@@ -272,67 +275,50 @@ struct WmmaGemm
} }
} }
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } __device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % wmma_instr.wave_size;
}
__device__ static auto GetLaneIdHigh() __device__ static auto GetSubGroupId()
{ {
return GetLaneId() / 16; return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
} }
__device__ static auto GetLaneIdLow() __device__ static auto GetLaneIdUnderSubGroup()
{ {
return GetLaneId() % 16; return GetLaneId() % wmma_instr.num_thread_per_subgroups;
} }
__device__ static auto GetSwizzledLaneIdLow() __device__ static auto GetSwizzledLaneIdLow()
{ {
return ((GetLaneIdLow() & 1) << 3 ) | (GetLaneIdLow() >> 1); return ((GetLaneIdUnderSubGroup() & 1) << 3 ) | (GetLaneIdUnderSubGroup() >> 1);
} }
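A quick stand-alone illustration (assumed example, not library code) of what the A-operand lane swizzle above does: it sends even lanes to rows 0..7 and odd lanes to rows 8..15 of the 16-wide subgroup.

#include <cstdio>

int main()
{
    for(int lane = 0; lane < 16; ++lane)
    {
        const int swizzled = ((lane & 1) << 3) | (lane >> 1); // same formula as GetSwizzledLaneIdLow
        std::printf("lane %2d -> %2d\n", lane, swizzled);     // 0->0, 1->8, 2->1, 3->9, ...
    }
    return 0;
}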
__host__ __device__ static auto CalculateAThreadOriginDataIndex() __host__ __device__ static auto CalculateAThreadOriginDataIndex()
{ {
return make_tuple(0, GetSwizzledLaneIdLow()); return GetSwizzledLaneIdLow();
} }
__host__ __device__ static auto CalculateBThreadOriginDataIndex() __host__ __device__ static auto CalculateBThreadOriginDataIndex()
{ {
return make_tuple(0, GetLaneIdLow()); return GetLaneIdUnderSubGroup();
} }
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) __device__ static CIndex GetBeginOfThreadBlk()
{ {
const auto blk_idx = GetBlkIdx(); index_t n_offset = GetLaneIdUnderSubGroup();
index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
index_t n_offset = blk_i * wmma_instr.n_per_blk + blk_td;
index_t m_offset = xdlops_i * wmma_instr.m_per_blk + blk_id * wmma_instr.group_size;
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
} }
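For example, under the wave32 mapping above (num_thread_per_subgroups = 16, num_acc_vgprs_per_wave = 8), lane 20 has GetLaneIdUnderSubGroup() = 4 and GetSubGroupId() = 1, so its accumulator block begins at (m, n) = (8, 4), or (4, 8) when TransposeC is set.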
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{}; static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma; static constexpr auto wmma_instr = wmma.selected_wmma;
static constexpr auto KPerXdlops = wmma.GetKPerXdlops(); __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
{ {
return make_tuple( return make_tuple(
Number<wmma_instr.num_groups_per_blk>{}, I1, Number<wmma_instr.group_size>{}, I1); I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
} }
}; };
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
// TODO: Add arch limitation // TODO: Add arch limitation
namespace ck { namespace ck {
// wave32 only
// src: fp16, dst: fp32 // src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32; struct intrin_wmma_f32_16x16x16_f16_w32;
...@@ -24,6 +23,20 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> ...@@ -24,6 +23,20 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
} }
}; };
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32 // src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32; struct intrin_wmma_f32_16x16x16_bf16_w32;
......