Commit 0a808724 authored by aska-0096

Tidy up + format

parent 289f15de
@@ -23,8 +23,9 @@ template <index_t BlockSize,
/* A: K0PerBlock x MPerBlock x K1
 * B: K0PerBlock x NPerBlock x K1
 * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 * KPACK == WMMA_K = 16
 */
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -35,11 +36,13 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // Hardcode WaveSize to 32, since the current HIP runtime (5.4.0-10984) does not return the
    // correct value.
    static constexpr index_t WaveSize = 32;

    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t KPerBlock =
        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);

@@ -48,8 +51,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
    static constexpr auto wmma_gemm = WmmaGemm<FloatAB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
@@ -97,8 +98,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();

@@ -125,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        return make_tuple(c_thread_m, c_thread_n);
    }
    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -134,73 +134,103 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }
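For reference, the tiling arithmetic behind these two asserts, worked with hypothetical tile sizes (a standalone sketch, not code from this commit):

// Standalone sketch of the wave-tiling arithmetic (hypothetical sizes).
using index_t = int;
constexpr index_t MPerBlock = 128, NPerBlock = 128; // block tile
constexpr index_t MPerWMMA = 16, NPerWMMA = 16;     // per-instruction tile
constexpr index_t MRepeat = 4, NRepeat = 2;         // per-wave repeats
constexpr index_t WaveSize = 32;                    // hardcoded wave32

constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); // 128 / 64 = 2
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); // 128 / 32 = 4

// The thread block must then be launched with MWaves * NWaves * WaveSize threads.
static_assert(MWaves * NWaves * WaveSize == 256, "BlockSize must be 256 for this tile");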
    // Thread level, register descriptor. Vector-write
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto MSubGroup          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
        constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
        constexpr auto MAccVgprs          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            // |MRepeat |MWave |MSubGroup |NRepeat |NWave
            // |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{},
                       I1,
                       MSubGroup,
                       Number<NRepeat>{},
                       I1,
                       NThreadPerSubGroup,
                       MAccVgprs));
    }
    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(
                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
    }
    // Thread level, register descriptor. Per-pixel write
    __host__ __device__ static constexpr auto
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
    {
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto MSubGroup          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
        constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
        constexpr auto MAccVgprs          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave
            // |NThreadPerSubGroup
            make_tuple(Number<MRepeat>{},
                       I1,
                       MSubGroup,
                       MAccVgprs,
                       Number<NRepeat>{},
                       I1,
                       NThreadPerSubGroup));
    }
    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
        const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
            transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(
                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
                c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
    }
    // Provide dimension size
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
    {
        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
@@ -210,16 +240,18 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                                                           Number<NWaves>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm
            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
    }
    __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(make_pass_through_transform(Number<A_K0>{}),
                       make_unmerge_transform(
                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
                       make_pass_through_transform(Number<A_K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
@@ -229,14 +261,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(make_pass_through_transform(Number<B_K0>{}),
                       make_unmerge_transform(
                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
                       make_pass_through_transform(Number<B_K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
    }

    // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
    static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
    static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
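As a side note on the unmerge transforms above, here is a worked index split with hypothetical sizes (MRepeat, MWaves, MPerWMMA values are illustrative, not this commit's defaults):

// A flat M index in [0, MPerBlock) is unmerged into (M0, M1, M2) =
// (MRepeat, MWaves, MPerWMMA) coordinates, fastest-varying last.
constexpr int MWaves = 2, MPerWMMA = 16;    // with MRepeat = 4 -> MPerBlock = 128
constexpr int m  = 77;                      // flat index inside the block tile
constexpr int m0 = m / (MWaves * MPerWMMA); // repeat index = 2
constexpr int m1 = (m / MPerWMMA) % MWaves; // wave index   = 0
constexpr int m2 = m % MPerWMMA;            // row in wave  = 13
static_assert(m0 * (MWaves * MPerWMMA) + m1 * MPerWMMA + m2 == 77, "round trip");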
@@ -252,7 +285,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        // constexpr auto RepeatDiff = MRepeat - NRepeat;

        // debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset(
        // make_tuple(0, 0, 0, 0, 0))>{}], "Avalue ");

        /* First local prefetch, move out of blockwise operation.
        static_for<0, NRepeat, 1>{}([&](auto iN){
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
@@ -291,18 +325,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
                    // debug_hexprinter(0x3c003c00, a_thread_vec.template
                    // AsType<FloatAB>()(Number<0>{}));
                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                        b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(Number<iWmmaK / A_K1>{}, Number<iCut>{}, I0, I0,
                                              Number<iWmmaK % A_K1>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, Number<iCut>{}, I0, I0, I0),
                                   a_thread_buf);
            });

            // Run FIFO fashion loopover in Square
            static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
@@ -328,8 +360,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(Number<iWmmaK / A_K1>{},
                                              Number<WmmaInnerloop + RepeatDiff>{}, I0, I0,
                                              Number<iWmmaK % A_K1>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
                                   a_thread_buf);
@@ -355,11 +387,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   make_tuple(Number<iWmmaK / B_K1>{}, Number<WmmaInnerloop>{}, I0,
                                              I0, Number<iWmmaK % B_K1>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
                                   b_thread_buf);
            });
        });
        */
@@ -368,7 +398,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                // read A
                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0, I0),
@@ -377,7 +407,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    // read B
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, I0, I0, I0, I0),
@@ -386,14 +416,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                    vector_type<FloatAB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(i / A_K1, 0, 0, 0, i % A_K1))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(i / B_K1, 0, 0, 0, i % B_K1))>{}];
                    });

                    using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -405,34 +436,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                });
            });
        });
    }
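The i / A_K1 and i % A_K1 pattern in the read loop above packs a WmmaK-wide K slice out of the (K0, ..., K1)-shaped thread buffer; a small worked example (the K1 value here is hypothetical):

// With K1 = 8 and WmmaK = 16, vector element i = 11 comes from
// (k0, k1) = (11 / 8, 11 % 8) = (1, 3): second K0 slot, fourth element within it.
static_assert(16 / 8 == 2 && 11 / 8 == 1 && 11 % 8 == 3, "K-slice packing");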
protected:
    // A[M0, M1, M2, K0 = WmmaK]
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<WmmaK / A_K1>{}, I1, I1, I1, Number<A_K1>{}));

    // B[N0, N1, N2, K0 = WmmaK]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<WmmaK / B_K1>{}, I1, I1, I1, Number<B_K1>{}));

    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -442,7 +455,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                                         FloatAB,
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
                                         Sequence<3, 0, 1, 2, 4>,
                                         4,
                                         A_K1,
@@ -452,7 +465,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
                                         FloatAB,
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
                                         Sequence<3, 0, 1, 2, 4>,
                                         4,
                                         B_K1,
@@ -473,11 +486,11 @@ template <index_t BlockSize,
          index_t NRepeat,
          index_t KPack,
          LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_Selector()
{
    if constexpr(LoopSched == LoopScheduler::Default)
    {
        return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                     FloatAB,
                                                                     FloatAcc,
                                                                     AK0MK1BlockDesc,
...
@@ -38,8 +38,10 @@ __global__ void
        FloatC* __restrict__ p_c_grid,
        const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock,
        // const
        // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
        // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
@@ -49,8 +51,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared,
@@ -75,8 +76,7 @@ __global__ void
#endif // end of if (defined(__gfx1100__))
}
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatCShuffle,
@@ -202,17 +202,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }
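A worked example of this LDS budget, with hypothetical tile sizes and fp16 data:

// A tile: K0PerBlock x MPerBlock x K1 = 4 x 128 x 8 = 4096 elements,
// B tile: K0PerBlock x NPerBlock x K1 = 4 x 128 x 8 = 4096 elements.
// Both are already multiples of max_lds_align = K1 = 8, so
// integer_least_multiple changes nothing, and the block allocates
// (4096 + 4096) * sizeof(fp16) = 16384 bytes of LDS.
static_assert((4 * 128 * 8 + 4 * 128 * 8) * 2 == 16384, "16 KiB of LDS for this tile");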
@@ -308,7 +310,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
            BlockSize,
            FloatAB,
            FloatAcc,
            decltype(a_block_desc_k0perblock_mperblock_k1),
@@ -319,7 +322,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            NRepeat,
            KPack>;

        return BlockwiseGemm::
            MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
                c_grid_desc_m_n);
    }
    // Per pixel
@@ -362,7 +367,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
            BlockSize,
            FloatAB,
            FloatAcc,
            decltype(a_block_desc_k0perblock_mperblock_k1),
@@ -373,7 +379,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            NRepeat,
            KPack>;

        return BlockwiseGemm::
            MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
                c_grid_desc_m_n);
    }
    __host__ __device__ static constexpr auto
@@ -402,11 +410,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
            c_grid_desc_m_n);
    }

    // using
    // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
    // = remove_cvref_t<decltype(
    // MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
    // CGridDesc_M_N{}))>;
    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
@@ -420,14 +430,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               // const
                               // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
                               // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const Block2CTileMap& block_2_ctile_map)
    {
        // clang-format off
        /*******************************************************************************/
        // Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -453,12 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        /*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        constexpr auto max_lds_align = K1;
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
@@ -532,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

        auto blockwise_gemm =
            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                  FloatAB,
                                                                  FloatAcc,
                                                                  decltype(a_block_desc_k0perblock_mperblock_k1),
@@ -838,12 +846,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            if constexpr(access_id < num_access - 1)
            {
                constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                // move on C
                c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                    c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
...
@@ -12,11 +12,11 @@ namespace ck {
enum struct WmmaInstr
{
    wmma_f32_16x16x16_f16 = 0,
    wmma_f32_16x16x16_bf16,
    wmma_f16_16x16x16_f16,
    wmma_bf16_16x16x16_bf16,
    wmma_i32_16x16x16_iu8,
    wmma_i32_16x16x16_iu4
};
/*
@@ -70,18 +70,18 @@ enum struct WmmaInstr
 * T = Thread ID
 */
template <WmmaInstr Instr, index_t WaveSize, typename = void>
struct wmma_type
{
};
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Properties fixed by the instruction
    // * Data Pixel
    static constexpr index_t m_per_wmma = 16;
    static constexpr index_t n_per_wmma = 16;
@@ -92,14 +92,15 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
    // * Thread mapping inside wave; num_thread_per_subgroups is always along the N direction
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    // * Fixed in Navi3x, will be wave mode dependent on Navi4x
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    // * num_acc_vgprs_per_wave along M direction
    // * num_subgroups along M direction
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
@@ -116,6 +117,172 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
    }
};
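The accumulator-VGPR formula above works out as follows for fp32 accumulation (acc_data_size = 4):

// wave32: 16 * 16 * 4 / 32 / 4 = 8 acc VGPRs per lane (a float8 slice)
// wave64: 16 * 16 * 4 / 64 / 4 = 4 acc VGPRs per lane (a float4 slice)
// i.e. the 16x16 fp32 result tile is spread evenly over the wave's lanes.
static_assert(16 * 16 * 4 / 32 / 4 == 8 && 16 * 16 * 4 / 64 / 4 == 4, "acc VGPR count");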
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Properties fixed by the instruction
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Properties fixed by the instruction
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              index_t Opsel,
              class FloatA,
              class FloatB,
              class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
    }
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Properties fixed by the instruction
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              index_t Opsel,
              class FloatA,
              class FloatB,
              class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
    }
};
#endif
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Properties fixed by the instruction
    static constexpr index_t m_per_wmma = 16;
    static constexpr index_t n_per_wmma = 16;
    static constexpr index_t k_per_wmma = 16;
    // int8 sources are 1 byte each (16 int8 per lane pack into 4 VGPRs)
    static constexpr index_t src_a_data_size = 1;
    static constexpr index_t src_b_data_size = 1;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              bool neg_a,
              bool neg_b,
              bool clamp,
              class FloatA,
              class FloatB,
              class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
                a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
                a, b, reg_c);
        }
    }
};
template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
struct WmmaSelector
{
@@ -159,20 +326,19 @@ struct WmmaSelector
    }
#endif

    // get_warp_size does not return the correct wave size; hardcode to 32 as a workaround
    static constexpr auto selected_wmma =
        wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};

    __host__ __device__ constexpr WmmaSelector()
    {
        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.n_per_wmma == 16, "WRONG! WMMA_N must equal to 16");

        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_K must equal to 16");

        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
                              selected_wmma.acc_data_size ==
                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                      "WRONG! Invalid Number of Accumulator Register");
    }
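A minimal usage sketch of the selector (assuming the CK headers are included; the assert only reads compile-time properties of the selected instruction):

// fp16 sources with fp32 accumulation resolve to wmma_f32_16x16x16_f16,
// and the hardcoded Number<32>{} picks the wave32 specialization of wmma_type.
using selector = ck::WmmaSelector<ck::half_t, float, 16, 16>;
static_assert(selector::selected_wmma.k_per_wmma == 16, "fp16 -> fp32 WMMA selected");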
@@ -198,7 +364,7 @@ struct WmmaGemm
    __host__ __device__ constexpr WmmaGemm()
    {
        static_assert(NPerWmma == 16 && MPerWmma == 16,
                      "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");

        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
@@ -209,17 +375,23 @@ struct WmmaGemm
    // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
    __host__ __device__ static constexpr auto
    MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
    {
        const auto MBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
        const auto NBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
        const auto MWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
        const auto NWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
            make_tuple(
                make_pass_through_transform(MBlockxRepeat),
                make_pass_through_transform(MWave),
                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
@@ -243,17 +415,23 @@ struct WmmaGemm
    // Per-Pixel write
    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
    __host__ __device__ static constexpr auto
    MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
    {
        const auto MBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
        const auto NBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
        const auto MWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
        const auto NWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
            make_tuple(
                make_pass_through_transform(MBlockxRepeat),
                make_pass_through_transform(MWave),
                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
@@ -279,15 +457,13 @@ struct WmmaGemm
        return wmma_instr.num_acc_vgprs_per_wave;
    }

    __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(
            (is_same<src_type, half_t>::value && is_same<dst_type, float>::value) ||
                (is_same<src_type, bhalf_t>::value && is_same<dst_type, float>::value) ||
                (is_same<src_type, half_t>::value && is_same<dst_type, half_t>::value) ||
                (is_same<src_type, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) ||
@@ -295,23 +471,20 @@ struct WmmaGemm
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                || (is_same<src_type, int4_t>::value && is_same<dst_type, int32_t>::value)
#endif
            ,
            "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
            "(int8, int32) or (int4, int32)!");

        if constexpr(!TransposeC)
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
        }
        else
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
        }
    }
    __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }

    __device__ static auto GetSubGroupId()
    {
@@ -324,7 +497,7 @@ struct WmmaGemm
    }

    __device__ static auto GetSwizzledLaneIdLow()
    {
        return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
    }
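The swizzle above maps lane l of a 16-lane subgroup to ((l & 1) << 3) | (l >> 1): even lanes land on rows 0..7 and odd lanes on rows 8..15, presumably matching the swizzled A-operand layout noted earlier in this file. A standalone compile-time check:

constexpr int swizzled_lane(int l) { return ((l & 1) << 3) | (l >> 1); }
static_assert(swizzled_lane(0) == 0 && swizzled_lane(1) == 8, "lanes 0/1");
static_assert(swizzled_lane(2) == 1 && swizzled_lane(3) == 9, "lanes 2/3");
static_assert(swizzled_lane(14) == 7 && swizzled_lane(15) == 15, "lanes 14/15");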
    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
@@ -348,10 +521,10 @@ struct WmmaGemm
    static constexpr auto wmma       = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
    static constexpr auto wmma_instr = wmma.selected_wmma;

    __host__ __device__ static constexpr auto
    GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
    {
        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
    }
};
...
@@ -8,6 +8,8 @@
// TODO: Add arch limitation
namespace ck {

/********************************WAVE32 MODE***********************************************/

// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
@@ -23,20 +25,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
    }
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
@@ -111,5 +99,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
    }
};
/********************************WAVE64 MODE***********************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;

template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
    }
};

// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15]  = result
        // true : D0.[16:31] = result
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
    }
};

// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15]  = result
        // true : D0.[16:31] = result
        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x4_t>()[Number<0>{}],
                clamp);
    }
};
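A hypothetical device-side caller for one of these wrappers, to show the expected fragment types (the function and buffer setup are illustrative, not from this commit, and assume CK's vector_type provides the AsType view used above):

// Each wave32 lane feeds 16 half values of A and B and accumulates into a
// float8 slice of the 16x16 output tile.
__device__ void example_wmma_tile(const half16_t& a_frag,
                                  const half16_t& b_frag,
                                  vector_type<float, 8>& c_frag)
{
    // Issues v_wmma_f32_16x16x16_f16 via the wave32 wrapper defined above.
    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(a_frag, b_frag, c_frag);
}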
} // namespace ck
#endif