Commit b3cc22a3 authored by aska-0096

tempsave

parent d16063db
......@@ -30,12 +30,14 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -85,8 +87,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]);
// |KRepeat |MRepeat|Mwave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -96,20 +98,20 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]);
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
}
template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
template <index_t m0, index_t n0>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i);
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
......@@ -129,27 +131,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
......@@ -162,59 +143,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
// Thread level register descriptor.
__host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
// |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, MSubGroup, Number<NRepeat>{}, I1, NThreadPerSubGroup, MAccVgprs));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
__host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerWMMA>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerWMMA>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
template <typename CGridDesc_M_N>
......@@ -234,32 +187,46 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
__host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack()
{
return transform_tensor_descriptor(
constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_pass_through_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_merge_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
a_block_desc_temp_km0m1m2,
make_tuple(
    make_unmerge_transform(make_tuple(Number<A_K0 * A_K1 / KPack>{}, Number<KPack>{})),
    make_pass_through_transform(Number<MRepeat>{}),
    make_pass_through_transform(Number<MWaves>{}),
    make_pass_through_transform(Number<MPerWMMA>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
__host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack()
{
return transform_tensor_descriptor(
constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_pass_through_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_merge_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
b_block_desc_temp_kn0n1n2,
make_tuple(
    make_unmerge_transform(make_tuple(Number<B_K0 * B_K1 / KPack>{}, Number<KPack>{})),
    make_pass_through_transform(Number<NRepeat>{}),
    make_pass_through_transform(Number<NWaves>{}),
    make_pass_through_transform(Number<NPerWMMA>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack();
static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
......@@ -298,7 +265,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0),
a_block_buf,
a_thread_desc_,
......@@ -328,7 +295,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop+RepeatDiff, I0, I0, I0),
a_block_buf,
a_thread_desc_,
......@@ -355,7 +322,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0),
b_block_buf,
b_thread_desc_,
......@@ -380,7 +347,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_block_desc_krepeat_m0_m1_m2_kpack),
decltype(a_thread_desc_),
Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>,
......@@ -390,7 +357,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_block_desc_krepeat_n0_n1_n2_kpack),
decltype(b_thread_desc_),
Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>,
......@@ -413,11 +380,11 @@ template <index_t BlockSize,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector()
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2<BlockSize,
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
......
......@@ -72,9 +72,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// K1 = max number of elements per vector access
static constexpr auto K1Number = Number<K1>{};
static constexpr auto M1Number = Number<M1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
......@@ -87,10 +86,12 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
#endif
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
......@@ -154,12 +155,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
}
}
static auto MakeCGridDescriptor_M0_N_M1(index_t M, index_t N, index_t StrideC)
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
assert(M % M1 == 0);
const index_t M0 = M / M1;
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
......@@ -173,8 +170,6 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
static_assert(false, "Padding Gemm Not implemented");
/* Not implemented yet.
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
......@@ -183,26 +178,25 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
*/
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M1Number)),
make_pass_through_transform(N)),
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
// Gridwise descriptors, mapping to the whole given problem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M0_N_M1 = decltype(MakeCGridDescriptor_M0_N_M1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1<
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......@@ -210,7 +204,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M0_N_M1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -238,15 +232,16 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
#endif
NumPrefetch,
LoopSched,
PipelineVer>;
// Argument
struct Argument : public BaseArgument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
......@@ -267,8 +262,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m0_n_m1_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......@@ -278,18 +273,18 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{
a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m0_n_m1_ = DeviceGemmWmma::MakeCGridDescriptor_M0_N_M1(M, N, StrideC);
c_grid_desc_m_n_ = DeviceGemmWmma::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m0_n_m1_, M01, N01);
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m0_n_m1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n_m1_);
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow(c_grid_desc_m_n_);
}
}
......@@ -299,9 +294,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M0_N_M1 c_grid_desc_m0_n_m1_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -327,15 +322,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m0_n_m1_{ " << arg.c_grid_desc_m0_n_m1_.GetLength(I0)
<< ", " << arg.c_grid_desc_m0_n_m1_.GetLength(I1) << ", "
<< arg.c_grid_desc_m0_n_m1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n_m1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
......@@ -343,7 +338,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m0_n_m1_);
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -358,7 +353,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -375,7 +370,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -389,7 +384,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -406,7 +401,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -447,7 +442,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n_m1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
......
......@@ -11,35 +11,107 @@ namespace ck {
enum struct WmmaInstr
{
wmma_f32_16x16x16_f16_w32 = 0,
wmma_f32_16x16x16_bf16_w32 = 0,
wmma_f16_16x16x16_f16_w32 = 0,
wmma_bf16_16x16x16_bf16_w32 = 0,
wmma_i32_16x16x16_iu8_w32 = 0,
wmma_i32_16x16x16_iu4_w32 = 0
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4
};
template <WmmaInstr instr>
/*
* WMMA Wave Tile Always MxNxK = 16x16x16
* WAVE32
-----------------------------------
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|RC1| | | | | | | | | | | | | | | |
|RC2| | | | | | | | | | | | | | | |
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC6| | | | | | | | | | | | | | | |
|RC7| | | | | | | | | | | | | | | |
-----------------------------------
| | | | | | | | | | | | | | | | | SubGroup 1
| | | | | | | | | | | | | | | | |
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
-----------------------------------
* WAVE64
-----------------------------------
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
| | | | | | | | | | | | | | | | |
-----------------------------------
* RC = register for storing the accumulated result
* T = Thread ID
*/
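// In short, reading the diagrams above: each lane owns one 16-element column of the tile along
// N and a strip of accumulator rows along M of height 16*16/wave_size, i.e. 8 rows per lane
// (RC0..RC7, two subgroups) on WAVE32 and 4 rows per lane (RC0..RC3, four subgroups) on WAVE64.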
template <WmmaInstr Instr,
index_t WaveSize,
          typename enable_if<WaveSize == 32 || WaveSize == 64, bool>::type = false>
struct wmma_type;
template <>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32>
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, WaveSize>
{
// Properties fixed by the instruction
// * data tile shape and element sizes
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t wave_size = 32;
static constexpr index_t lane_size = 16;
static constexpr index_t src_data_size = 2;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_srcregs_per_wmma = 8;
static constexpr index_t num_accregs_per_wmma = 8;
// * Thread mapping inside a wave; num_thread_per_subgroups is always along the N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave-mode dependent properties
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed on Navi3x; will be wave-mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
// * num_acc_vgprs_per_wave is along the M direction
// * num_subgroups is along the M direction
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
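    // Sanity check of the arithmetic above (illustrative sketch): with f32 accumulation this
    // gives 8 acc VGPRs / 2 subgroups on wave32 and 4 acc VGPRs / 4 subgroups on wave64, so the
    // acc VGPRs of all lanes exactly tile the 16x16 C tile.
    static_assert(num_acc_vgprs_per_wave * num_subgroups * num_thread_per_subgroups ==
                      m_per_wmma * n_per_wmma,
                  "acc VGPRs of all lanes must exactly tile the 16x16 C tile");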
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
......@@ -51,54 +123,54 @@ struct WmmaSelector
template <>
static constexpr auto GetWmma<half_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_f16_w32;
return WmmaInstr::wmma_f32_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_bf16_w32;
return WmmaInstr::wmma_f32_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<half_t, half_t, 16, 16>()
{
return WmmaInstr::wmma_f16_16x16x16_f16_w32;
return WmmaInstr::wmma_f16_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, 16, 16>()
{
return WmmaInstr::wmma_bf16_16x16x16_bf16_w32;
return WmmaInstr::wmma_bf16_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<int8_t, float, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu8_w32;
return WmmaInstr::wmma_i32_16x16x16_iu8;
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, float, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4_w32;
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
#endif
static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>()>{};
static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), get_warp_size()>{};
__host__ __device__ constexpr WmmaSelector()
{
static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma,
"WRONG! WMMA_M must equal to WMMA_N");
static_assert(selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma,
"WRONG! WMMA_M must equal to WMMA_K");
static_assert(selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to WMMA_N");
"WRONG! WMMA_M must equal to 16");
static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma * selected_wmma.acc_data_size==
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Number of Accumulator Register");
......@@ -135,26 +207,26 @@ struct WmmaGemm
}
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
template <typename CDesc_M0_N0_M1_N1_M2_N2>
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA& c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0);
const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_groups_per_blk>{},
Number<wmma_instr.num_input_blks>{},
Number<wmma_instr.group_size>{})),
make_pass_through_transform(Number<wmma_instr.num_threads_per_blk>{})),
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NRepeat),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -163,91 +235,22 @@ struct WmmaGemm
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
}
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(Number<wmma_instr.num_threads_per_blk>{}),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_groups_per_blk>{},
Number<wmma_instr.num_input_blks>{},
Number<wmma_instr.group_size>{}))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<2, 6>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{}));
Sequence<5>{}));
}
template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
__device__ static constexpr index_t GetRegSizePerWmma()
{
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
return transform_tensor_descriptor(
c_desc_g_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(G),
make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(wmma_instr.num_groups_per_blk,
wmma_instr.num_input_blks,
wmma_instr.group_size)),
make_pass_through_transform(wmma_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{},
Sequence<8>{}));
return wmma_instr.num_acc_vgprs_per_wave;
}
__device__ static constexpr index_t GetRegSizePerXdlops()
__device__ static constexpr index_t GetWaveSize()
{
return MPerWmma * NPerWmma / wmma_instr.wave_size;
return wmma_instr.wave_size;
}
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
......@@ -272,67 +275,50 @@ struct WmmaGemm
}
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
__device__ static auto GetLaneId()
{
return get_thread_local_1d_id() % wmma_instr.wave_size;
}
__device__ static auto GetLaneIdHigh()
__device__ static auto GetSubGroupId()
{
return GetLaneId() / 16;
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
}
__device__ static auto GetLaneIdLow()
__device__ static auto GetLaneIdUnderSubGroup()
{
return GetLaneId() % 16;
return GetLaneId() % wmma_instr.num_thread_per_subgroups;
}
__device__ static auto GetSwizzledLaneIdLow()
{
return ((GetLaneIdLow() & 1) << 3 ) | (GetLaneIdLow() >> 1);
return ((GetLaneIdUnderSubGroup() & 1) << 3 ) | (GetLaneIdUnderSubGroup() >> 1);
}
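    // Worked example of the swizzle above (illustrative): ((l & 1) << 3) | (l >> 1)
    // de-interleaves the even/odd lanes of a sub-group, mapping lanes 0,1,2,3,...,14,15 to
    // 0,8,1,9,...,7,15 (even lanes to 0..7, odd lanes to 8..15).
    static_assert((((1 & 1) << 3) | (1 >> 1)) == 8 && (((2 & 1) << 3) | (2 >> 1)) == 1 &&
                      (((15 & 1) << 3) | (15 >> 1)) == 15,
                  "swizzle example: lane 1 -> 8, lane 2 -> 1, lane 15 -> 15");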
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return make_tuple(0, GetSwizzledLaneIdLow());
return GetSwizzledLaneIdLow();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
return make_tuple(0, GetLaneIdLow());
return GetLaneIdUnderSubGroup();
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
__device__ static CIndex GetBeginOfThreadBlk()
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
index_t n_offset = blk_i * wmma_instr.n_per_blk + blk_td;
index_t m_offset = xdlops_i * wmma_instr.m_per_blk + blk_id * wmma_instr.group_size;
index_t n_offset = GetLaneIdUnderSubGroup();
index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
}
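    // Worked example (illustrative, wave32 with f32 accumulation, non-transposed C): lane 20
    // sits in SubGroup 1 at position 4 within the sub-group, so m_offset = 1 * 8 = 8 and
    // n_offset = 4, i.e. its accumulators start at row 8, column 4 of the 16x16 C tile
    // (cf. the wave tile diagram in this file).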
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma;
static constexpr auto KPerXdlops = wmma.GetKPerXdlops();
static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
__host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(
Number<wmma_instr.num_groups_per_blk>{}, I1, Number<wmma_instr.group_size>{}, I1);
    I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
}
};
......
......@@ -8,7 +8,6 @@
// TODO: Add arch limitation
namespace ck {
// wave32 only
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
......@@ -24,6 +23,20 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
......