Unverified Commit 941d1f7c authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merging the gfx12 code into public repo. (#1362)

parent a32b1bc6
......@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported())
ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
......@@ -50,8 +50,9 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
......@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
......@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
? false
: true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
? false
: true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
......@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
......@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -48,8 +48,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
......@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -90,8 +90,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
......
......@@ -107,7 +107,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx11__))
defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()))
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
......
......@@ -582,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution;
// check device
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -39,8 +39,9 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
......
......@@ -61,7 +61,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
......@@ -166,6 +166,7 @@ __global__ void
ignore = O;
ignore = G0;
ignore = G1;
ignore = alpha;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx11__))
......@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -60,7 +60,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
......@@ -165,6 +165,7 @@ __global__ void
ignore = O;
ignore = G0;
ignore = G1;
ignore = alpha;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx11__))
......@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(B0EnableLds)
{
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_KRow = I2;
#else
constexpr auto B_KRow = I1;
#endif
return transform_tensor_descriptor(
B0BlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
......@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(B1EnableLds)
{
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_LRow = I2;
#else
constexpr auto B_LRow = I1;
#endif
return transform_tensor_descriptor(
B1BlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_L1>{})),
......
......@@ -50,7 +50,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto A_KRow = I2;
#else
constexpr auto A_KRow = I1;
#endif
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
......@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_KRow = I2;
#else
constexpr auto B_KRow = I1;
#endif
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
......
......@@ -54,7 +54,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -147,7 +147,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
......@@ -237,7 +237,7 @@ __global__ void
const CDEElementwiseOperation cde_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
}
else
{
constexpr auto A_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
......@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
}
else
{
constexpr auto B_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
......@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto A_KRow = I2;
#else
constexpr auto A_KRow = I1;
#endif
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
......@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_KRow = I2;
#else
constexpr auto B_KRow = I1;
#endif
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
......@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
......@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
......
......@@ -45,7 +45,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
}
else
{
constexpr auto A_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
......@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
}
else
{
constexpr auto B_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
......@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto A_KRow = I2;
#else
constexpr auto A_KRow = I1;
#endif
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
......@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
constexpr auto B_KRow = I2;
#else
constexpr auto B_KRow = I1;
#endif
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
......@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
......@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
b_block_space_size_aligned * sizeof(BDataType));
};
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
......
......@@ -35,8 +35,9 @@ __global__ void
const Block2ETileMap block_2_tile_map,
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
p_in_global,
out_grid_desc,
......
......@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
ElementwiseOperation element_op_;
};
// Specilized for WMMA
// Specilized for WMMA-Navi3
// A single Wave32 is composed by double row
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
......@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
ElementwiseOperation element_op_{};
};
// Specilized for WMMA-Navi4
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
bool IntraRowSwizzlePerm,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
ignore = src_idx;
}
template <typename SrcSliceOriginIdx,
typename DstSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! SliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
"wrong! Buffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
// scalar per access on each dim
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// src_desc error, non constexpr, caused by merge transform
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v_this_row;
// int type temp value due to intrinsic requirement
int temp = 0;
// apply element-wise operation
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// apply intra-row permute.
if constexpr(IntraRowSwizzlePerm)
{
temp = __builtin_amdgcn_permlane16(
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
v_this_row = type_convert_sp<SrcData>(temp);
}
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
});
});
}
ElementwiseOperation element_op_{};
};
} // namespace ck
......@@ -11,12 +11,17 @@ namespace ck {
enum struct WmmaInstr
{
// gfx11
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_i32_16x16x16_iu4
wmma_i32_16x16x16_iu4,
// gfx12
wmma_f32_16x16x16_f16_gfx12,
wmma_f32_16x16x16_bf16_gfx12,
wmma_i32_16x16x16_iu8_gfx12,
};
/*
......@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
}
};
// gfx12
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
// * Data Pixel
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma,
index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
}
};
template <typename src_type_a,
typename src_type_b,
typename dst_type,
......@@ -296,13 +417,21 @@ struct WmmaSelector
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
#else
return WmmaInstr::wmma_f32_16x16x16_f16;
#endif
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
#else
return WmmaInstr::wmma_f32_16x16x16_bf16;
#endif
}
template <>
......@@ -320,8 +449,13 @@ struct WmmaSelector
template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
#else
return WmmaInstr::wmma_i32_16x16x16_iu8;
#endif
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
......@@ -502,6 +636,9 @@ struct WmmaGemm
__device__ static auto GetSubGroupId()
{
static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
wmma_instr.wave_size,
"");
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
}
......@@ -516,12 +653,20 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
#ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
#endif
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
#ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
#endif
}
__device__ static CIndex GetBeginOfThreadBlk()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment