Commit 0a808724 authored by aska-0096's avatar aska-0096
Browse files

Tidy up + format

parent 289f15de
...@@ -38,8 +38,10 @@ __global__ void ...@@ -38,8 +38,10 @@ __global__ void
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup c_grid_desc_mblock_mperblock_nblock_nperblock,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -49,18 +51,17 @@ __global__ void ...@@ -49,18 +51,17 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_a_grid, p_b_grid,
p_b_grid, p_c_grid,
p_c_grid, p_shared,
p_shared, a_grid_desc_k0_m_k1,
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1,
b_grid_desc_k0_n_k1, c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op,
a_element_op, b_element_op,
b_element_op, c_element_op,
c_element_op, block_2_ctile_map);
block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -75,50 +76,49 @@ __global__ void ...@@ -75,50 +76,49 @@ __global__ void
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx1100__))
} }
template < template <index_t BlockSize,
index_t BlockSize, typename FloatAB,
typename FloatAB, typename FloatAcc,
typename FloatAcc, typename FloatCShuffle,
typename FloatCShuffle, typename FloatC,
typename FloatC, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, typename AGridDesc_K0_M_K1,
typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1,
typename BGridDesc_K0_N_K1, typename CGridDesc_M_N,
typename CGridDesc_M_N, typename AElementwiseOperation,
typename AElementwiseOperation, typename BElementwiseOperation,
typename BElementwiseOperation, typename CElementwiseOperation,
typename CElementwiseOperation, index_t MPerBlock,
index_t MPerBlock, index_t NPerBlock,
index_t NPerBlock, index_t K0PerBlock,
index_t K0PerBlock, index_t MPerWmma,
index_t MPerWmma, index_t NPerWmma,
index_t NPerWmma, index_t K1Value,
index_t K1Value, index_t MRepeat,
index_t MRepeat, index_t NRepeat,
index_t NRepeat, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcAccessOrder, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_K1,
index_t ABlockTransferDstScalarPerVector_K1, bool AThreadTransferSrcResetCoordinateAfterRun,
bool AThreadTransferSrcResetCoordinateAfterRun, bool ABlockLdsExtraM,
bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcAccessOrder, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_K1,
index_t BBlockTransferDstScalarPerVector_K1, bool BThreadTransferSrcResetCoordinateAfterRun,
bool BThreadTransferSrcResetCoordinateAfterRun, bool BBlockLdsExtraN,
bool BBlockLdsExtraN, index_t CShuffleMRepeatPerShuffle,
index_t CShuffleMRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t NumGemmKPrefetchStage = 1,
index_t NumGemmKPrefetchStage = 1, LoopScheduler LoopSched = make_default_loop_scheduler(),
LoopScheduler LoopSched = make_default_loop_scheduler(), PipelineVersion PipelineVer = PipelineVersion::v1>
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -202,17 +202,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -202,17 +202,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto a_block_desc_k0perblock_mperblock_k1 =
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
math::integer_least_multiple(b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
} }
...@@ -308,18 +310,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -308,18 +310,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto WmmaK = 16; constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize, using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
FloatAB, BlockSize,
FloatAcc, FloatAB,
decltype(a_block_desc_k0perblock_mperblock_k1), FloatAcc,
decltype(b_block_desc_k0perblock_nperblock_k1), decltype(a_block_desc_k0perblock_mperblock_k1),
MPerWmma, decltype(b_block_desc_k0perblock_nperblock_k1),
NPerWmma, MPerWmma,
MRepeat, NPerWmma,
NRepeat, MRepeat,
KPack>; NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n);
return BlockwiseGemm::
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_m_n);
} }
// Per pixel // Per pixel
...@@ -362,18 +367,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -362,18 +367,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto WmmaK = 16; constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize, using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
FloatAB, BlockSize,
FloatAcc, FloatAB,
decltype(a_block_desc_k0perblock_mperblock_k1), FloatAcc,
decltype(b_block_desc_k0perblock_nperblock_k1), decltype(a_block_desc_k0perblock_mperblock_k1),
MPerWmma, decltype(b_block_desc_k0perblock_nperblock_k1),
NPerWmma, MPerWmma,
MRepeat, NPerWmma,
NRepeat, MRepeat,
KPack>; NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_m_n);
return BlockwiseGemm::
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
c_grid_desc_m_n);
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -402,11 +410,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -402,11 +410,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n); c_grid_desc_m_n);
} }
// using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup = remove_cvref_t<decltype( // using
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// = remove_cvref_t<decltype(
// MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( // MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
// CGridDesc_M_N{}))>; // CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
...@@ -419,15 +429,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -419,15 +429,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
// const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup& // const
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
// clang-format off // clang-format off
/*******************************************************************************/ /*******************************************************************************/
// Memory buffer zone. // Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -453,12 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -453,12 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/ /*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// printf("K0 = %d, M = %d, K1 = %d\n", K0, a_grid_desc_k0_m_k1.GetLength(I1), (a_grid_desc_k0_m_k1.GetLength(I2))());
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(),
// (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))());
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
...@@ -532,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -532,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize, BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1), decltype(a_block_desc_k0perblock_mperblock_k1),
...@@ -838,19 +846,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -838,19 +846,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// CONFIRMED
// printf("c_global_step = (%d, %d, %d, %d)\n",
// c_global_step[Number<0>{}],
// c_global_step[Number<1>{}],
// c_global_step[Number<2>{}],
// c_global_step[Number<3>{}]);
// move on C // move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
}); });
} }
// clang-format on // clang-format on
} }
}; };
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
// TODO: Add arch limitation // TODO: Add arch limitation
namespace ck { namespace ck {
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32 // src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32; struct intrin_wmma_f32_16x16x16_f16_w32;
...@@ -23,20 +25,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> ...@@ -23,20 +25,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
} }
}; };
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32 // src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32; struct intrin_wmma_f32_16x16x16_bf16_w32;
...@@ -111,5 +99,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> ...@@ -111,5 +99,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
} }
}; };
/********************************WAVE64 MODE***********************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;
template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment