Commit 0a808724 authored by aska-0096's avatar aska-0096
Browse files

Tidy up + format

parent 289f15de
......@@ -23,8 +23,9 @@ template <index_t BlockSize,
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -35,11 +36,13 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static constexpr index_t WaveSize = 32;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock = BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
......@@ -48,8 +51,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static constexpr auto wmma_gemm = WmmaGemm<FloatAB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......@@ -97,8 +98,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
}
template <index_t m0, index_t n0>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
......@@ -125,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3()
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
......@@ -134,73 +134,103 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// Thread level, register decriptor. Vector-write
__host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, MSubGroup, Number<NRepeat>{}, I1, NThreadPerSubGroup, MAccVgprs));
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
Number<NRepeat>{},
I1,
NThreadPerSubGroup,
MAccVgprs));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N& c_grid_desc_m_n)
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor(
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Thread level, register decriptor. Per-pixel write
__host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave |NThreadPerSubGroup
make_tuple(Number<MRepeat>{}, I1, MSubGroup, MAccVgprs, Number<NRepeat>{}, I1, NThreadPerSubGroup));
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave
// |NThreadPerSubGroup
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
MAccVgprs,
Number<NRepeat>{},
I1,
NThreadPerSubGroup));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(const CGridDesc_M_N& c_grid_desc_m_n)
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor(
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
......@@ -210,16 +240,18 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
......@@ -229,14 +261,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
......@@ -252,7 +285,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
// constexpr auto RepeatDiff = MRepeat - NRepeat;
// debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset( make_tuple(0, 0, 0, 0,0))>{}], "Avalue ");
// debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset(
// make_tuple(0, 0, 0, 0,0))>{}], "Avalue ");
/* First local prefetch, move out of blockwise operation.
static_for<0, NRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
......@@ -291,18 +325,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
// debug_hexprinter(0x3c003c00, a_thread_vec.template AsType<FloatAB>()(Number<0>{}));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
// debug_hexprinter(0x3c003c00, a_thread_vec.template
AsType<FloatAB>()(Number<0>{})); wmma_gemm.template Run( a_thread_vec.template
AsType<wmma_input_type>()(Number<0>{}), b_thread_vec.template
AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK/A_K1>{}, Number<iCut>{}, I0, I0, Number<iWmmaK%A_K1>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
a_thread_buf);
make_tuple(Number<iWmmaK/A_K1>{}, Number<iCut>{}, I0, I0,
Number<iWmmaK%A_K1>{}), a_block_buf, a_thread_desc_, make_tuple(I0, Number<iCut>{}, I0, I0,
I0), a_thread_buf);
});
// Run FIFO fashion loopover in Square
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
......@@ -328,8 +360,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, Number<iWmmaK%A_K1>{}),
a_block_buf,
make_tuple(Number<iWmmaK/A_K1>{},
Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, Number<iWmmaK%A_K1>{}), a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
a_thread_buf);
......@@ -355,11 +387,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, Number<iWmmaK%B_K1>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_buf);
make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0,
I0, Number<iWmmaK%B_K1>{}), b_block_buf, b_thread_desc_, make_tuple(I0,
Number<WmmaInnerloop>{}, I0, I0, I0), b_thread_buf);
});
});
*/
......@@ -368,7 +398,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k*WmmaK/A_K1>{}, m0, I0, I0, I0),
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
......@@ -377,7 +407,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k*WmmaK/B_K1>{}, n0, I0, I0, I0),
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
......@@ -386,14 +416,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(i/A_K1, 0, 0, 0, i%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(i/B_K1, 0, 0, 0, i%B_K1))>{}];
a_thread_vec.template AsType<FloatAB>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, 0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type =
typename vector_type<FloatAB, WmmaK>::type;
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
......@@ -405,34 +436,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
});
});
});
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'A';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, a_thread_buf[Number<i>{}], info);
// });
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'B';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, b_thread_buf[Number<i>{}], info);
// });
}
protected:
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, I1, I1, I1, Number<A_K1>{}));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<WmmaK / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, I1, I1, I1, Number<B_K1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<WmmaK / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
......@@ -442,7 +455,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK/A_K1, 1, 1, 1, A_K1>,
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
Sequence<3, 0, 1, 2, 4>,
4,
A_K1,
......@@ -452,7 +465,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK/B_K1, 1, 1, 1, B_K1>,
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
Sequence<3, 0, 1, 2, 4>,
4,
B_K1,
......@@ -473,11 +486,11 @@ template <index_t BlockSize,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_Selector()
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
......
......@@ -38,8 +38,10 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
// const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -49,8 +51,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid,
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
......@@ -75,8 +76,7 @@ __global__ void
#endif // end of if (defined(__gfx1100__))
}
template <
index_t BlockSize,
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatCShuffle,
......@@ -202,17 +202,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto a_block_desc_k0perblock_mperblock_k1 =
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
}
......@@ -308,7 +310,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
......@@ -319,7 +322,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n);
return BlockwiseGemm::
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_m_n);
}
// Per pixel
......@@ -362,7 +367,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
......@@ -373,7 +379,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_m_n);
return BlockwiseGemm::
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
c_grid_desc_m_n);
}
__host__ __device__ static constexpr auto
......@@ -402,11 +410,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
// using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup = remove_cvref_t<decltype(
// using
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// = remove_cvref_t<decltype(
// MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
// CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
......@@ -420,14 +430,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
// const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
// clang-format off
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -453,12 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// printf("K0 = %d, M = %d, K1 = %d\n", K0, a_grid_desc_k0_m_k1.GetLength(I1), (a_grid_desc_k0_m_k1.GetLength(I2))());
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(),
// (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))());
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
......@@ -532,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
......@@ -838,12 +846,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// CONFIRMED
// printf("c_global_step = (%d, %d, %d, %d)\n",
// c_global_step[Number<0>{}],
// c_global_step[Number<1>{}],
// c_global_step[Number<2>{}],
// c_global_step[Number<3>{}]);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
......
......@@ -12,11 +12,11 @@ namespace ck {
enum struct WmmaInstr
{
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16 = 0,
wmma_f16_16x16x16_f16 = 0,
wmma_bf16_16x16x16_bf16 = 0,
wmma_i32_16x16x16_iu8 = 0,
wmma_i32_16x16x16_iu4 = 0
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4
};
/*
......@@ -70,18 +70,18 @@ enum struct WmmaInstr
* T = Thread ID
*/
template <WmmaInstr Instr,
index_t WaveSize,
typename = void>
struct wmma_type{};
template <WmmaInstr Instr, index_t WaveSize, typename = void>
struct wmma_type
{
};
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 ||WaveSize == 64>>
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
// Absolute fixing property
// * Data Pixel
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
......@@ -92,14 +92,15 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
......@@ -116,6 +117,172 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
index_t Opsel,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
}
}
};
#endif
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
bool neg_a,
bool neg_b,
bool clamp,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
}
};
template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
struct WmmaSelector
{
......@@ -159,20 +326,19 @@ struct WmmaSelector
}
#endif
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
static constexpr auto selected_wmma =
wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
__host__ __device__ constexpr WmmaSelector()
{
static_assert(selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size==
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
}
......@@ -198,7 +364,7 @@ struct WmmaGemm
__host__ __device__ constexpr WmmaGemm()
{
static_assert(NPerWmma == 16 && MPerWmma == 16 ,
static_assert(NPerWmma == 16 && MPerWmma == 16,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
......@@ -209,17 +375,23 @@ struct WmmaGemm
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
const auto MBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
......@@ -243,17 +415,23 @@ struct WmmaGemm
// Per-Pixel write
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
const auto MBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
......@@ -279,15 +457,13 @@ struct WmmaGemm
return wmma_instr.num_acc_vgprs_per_wave;
}
__device__ static constexpr index_t GetWaveSize()
{
return wmma_instr.wave_size;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert((is_same<src_type, half_t>::value && is_same<dst_type, float>::value) ||
static_assert(
(is_same<src_type, half_t>::value && is_same<dst_type, float>::value) ||
(is_same<src_type, bhalf_t>::value && is_same<dst_type, float>::value) ||
(is_same<src_type, half_t>::value && is_same<dst_type, half_t>::value) ||
(is_same<src_type, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) ||
......@@ -295,23 +471,20 @@ struct WmmaGemm
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type, int4_t>::value && is_same<dst_type, int32_t>::value)
#endif
,"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), (int8, int32) or (int4, int32)!");
,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_a_wave, p_b_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_b_wave, p_a_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
}
}
__device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % wmma_instr.wave_size;
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
__device__ static auto GetSubGroupId()
{
......@@ -324,7 +497,7 @@ struct WmmaGemm
}
__device__ static auto GetSwizzledLaneIdLow()
{
return ((GetLaneIdUnderSubGroup() & 1) << 3 ) | (GetLaneIdUnderSubGroup() >> 1);
return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
......@@ -348,10 +521,10 @@ struct WmmaGemm
static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma;
__host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
__host__ __device__ static constexpr auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(
I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
}
};
......
......@@ -8,6 +8,8 @@
// TODO: Add arch limitation
namespace ck {
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
......@@ -23,20 +25,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
......@@ -111,5 +99,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
}
};
/********************************WAVE64 MODE***********************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment