Unverified commit 5dbbf5d6, authored by Illia Silin, committed by GitHub

Merge pull request #63 from ROCm/lwpck-1559

Merge changes for gfx12.
parents 78f637e4 2d3d7190
@@ -19,48 +19,51 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
-                                                                                 BLayout,
-                                                                                 CLayout,
-                                                                                 ADataType,
-                                                                                 BDataType,
-                                                                                 CDataType,
-                                                                                 AccDataType,
-                                                                                 CShuffleDataType,
-                                                                                 AElementOp,
-                                                                                 BElementOp,
-                                                                                 CElementOp,
-                                                                                 GemmDefault,
-                                                                                 1,
-                                                                                 32,
-                                                                                 16,
-                                                                                 32,
-                                                                                 64,
-                                                                                 8,
-                                                                                 16,
-                                                                                 16,
-                                                                                 1,
-                                                                                 2,
-                                                                                 S<2, 16, 1>,
-                                                                                 S<1, 0, 2>,
-                                                                                 S<1, 0, 2>,
-                                                                                 2,
-                                                                                 8,
-                                                                                 8,
-                                                                                 true,
-                                                                                 S<2, 16, 1>,
-                                                                                 S<1, 0, 2>,
-                                                                                 S<1, 0, 2>,
-                                                                                 2,
-                                                                                 8,
-                                                                                 8,
-                                                                                 true,
-                                                                                 1,
-                                                                                 1,
-                                                                                 S<1, 16, 1, 2>,
-                                                                                 8>;
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
+    < ALayout,
+      BLayout,
+      CLayout,
+      ADataType,
+      BDataType,
+      CDataType,
+      AccDataType,
+      CShuffleDataType,
+      AElementOp,
+      BElementOp,
+      CElementOp,
+      GemmDefault,
+      1,           // Prefetch stage
+      128,         // BlockSize
+      64,          // MPerBlock
+      128,         // NPerBlock
+      64,          // KPerBlock
+      2,           // K1
+      16,          // MPerWmma
+      16,          // NPerWmma
+      2,           // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+      4,           // N-Repeat // N-PerWmma / N-Repeat = N-Wave
+      S<4, 32, 1>,
+      S<1, 0, 2>,
+      S<1, 0, 2>,
+      2,
+      2,
+      2,
+      true,
+      S<4, 32, 1>,
+      S<1, 0, 2>,
+      S<1, 0, 2>,
+      2,
+      2,
+      2,
+      true,
+      1,           // C shuffle (M Repeat) Per store
+      1,           // C shuffle (N Repeat) Per store
+      S<1, 32, 1, 4>,
+      8>;
+// clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
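As a sanity check on the new tile shape, the relationships named in the parameter comments (BlockSize, MPerBlock, M-Repeat, M-PerWmma, and the wave decomposition) can be verified at compile time. A minimal standalone sketch, assuming CK's usual definitions MWaves = MPerBlock / (MRepeat * MPerWmma) and a 32-lane wave; the constants below mirror the template arguments above and are not taken from any CK header:

// Hypothetical compile-time check of the tile configuration above.
constexpr int BlockSize = 128, WaveSize = 32; // 32 lanes per wave on RDNA
constexpr int MPerBlock = 64, NPerBlock = 128;
constexpr int MPerWmma = 16, NPerWmma = 16;
constexpr int MRepeat = 2, NRepeat = 4;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // = 2
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // = 2

// The wave grid must exactly cover the thread block: 2 * 2 * 32 == 128.
static_assert(MWaves * NWaves * WaveSize == BlockSize, "tile/block mismatch");

int main() { return 0; }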
......
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
     case 4:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
         break;
     case 5:
......
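The change to init mode 4 above fills A with random integers in [-5, 5] while B stays at a constant 1, so each element of C reduces to a row sum of A along K, which makes accumulation errors easy to spot. A rough standalone equivalent of that fill, assuming a plain std::vector buffer; the helper below is illustrative and is not CK's FillUniformDistributionIntegerValue:

#include <random>
#include <vector>

// Illustrative integer-valued uniform fill over a flat buffer.
template <typename T>
void fill_uniform_int(std::vector<T>& buf, int lo, int hi, unsigned seed = 0)
{
    std::mt19937 gen(seed);
    std::uniform_int_distribution<int> dist(lo, hi);
    for(auto& x : buf)
        x = static_cast<T>(dist(gen));
}

int main()
{
    std::vector<float> a_m_k(64 * 64), b_k_n(64 * 64);
    fill_uniform_int(a_m_k, -5, 5); // randomized A, as in the new case 4
    fill_uniform_int(b_k_n, 1, 1);  // constant-1 B: C becomes row sums of A
}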
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200 gfx1201)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
   if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
         2,
         4,
         4,
-        true,
+        false,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         4,
         4,
-        true,
+        false,
         1,
         1,
         S<1, 64, 1, 2>,
......
 list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200)
+list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
......
@@ -279,8 +279,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
     switch(conv_param.num_dim_spatial_)
     {
     // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
-    case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
-    // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
+    case 2:
+        return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
+        // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
     }
 
     return false;
......
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 #define CK_MHA_USE_WAVE_1
 #define CK_MHA_USE_WAVE_2
 #define CK_MHA_USE_WAVE_4
-#define CK_MHA_USE_WAVE_8
+//#define CK_MHA_USE_WAVE_8
 using DeviceMHAFactory =
     std::tuple<
 #ifdef CK_MHA_USE_WAVE_1
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
             S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
             // CShuffleBlockTransfer MN
             1, 1, S<1, 64, 1, 2>, 8,
-            MaskingSpec>,
+            MaskingSpec>
 #endif
 #ifdef CK_MHA_USE_WAVE_8
-        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
             NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
             ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
             AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
......
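The reshuffled commas above matter because the WAVE_8 instance is now compiled out: if the WAVE_4 entry kept its trailing comma while CK_MHA_USE_WAVE_8 was undefined, the tuple would end in a dangling comma and fail to parse, so the comma moves inside the #ifdef and travels with the optional entry. A minimal sketch of the pattern, with placeholder types standing in for the device-op instances:

#include <tuple>

struct Wave1 {}; struct Wave2 {}; struct Wave4 {}; struct Wave8 {};

#define USE_WAVE_4
// #define USE_WAVE_8   // disabled, as in the diff above

using Factory = std::tuple<
#ifdef USE_WAVE_1
    Wave1,
#endif
#ifdef USE_WAVE_2
    Wave2,
#endif
#ifdef USE_WAVE_4
    Wave4          // no trailing comma on the last unconditional entry
#endif
#ifdef USE_WAVE_8
    , Wave8        // leading comma travels with the optional entry
#endif
    >;

int main() { return std::tuple_size<Factory>::value == 1 ? 0 : 1; }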
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 #define CK_MHA_USE_WAVE_1
 #define CK_MHA_USE_WAVE_2
 #define CK_MHA_USE_WAVE_4
-#define CK_MHA_USE_WAVE_8
+//#define CK_MHA_USE_WAVE_8
 using DeviceMHAFactory =
     std::tuple<
 #ifdef CK_MHA_USE_WAVE_1
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
             S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
             // CShuffleBlockTransfer MN
             1, 1, S<1, 64, 1, 2>, 8,
-            MaskingSpec>,
+            MaskingSpec>
 #endif
 #ifdef CK_MHA_USE_WAVE_8
-        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
             NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
             ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
             AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
......
 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1200)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
   if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......
@@ -70,9 +70,6 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;
 
-    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
-    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
-
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
@@ -316,7 +313,7 @@ struct BlockwiseGemmWMMA
                     // read A
                     a_thread_copy_.Run(
                         a_block_desc_k0_m0_m1_m2_k1,
-                        make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
+                        make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                         a_block_buf,
                         a_thread_desc_,
                         make_tuple(I0, m0, I0, I0, I0, I0),
@@ -326,8 +323,7 @@ struct BlockwiseGemmWMMA
                         // read B
                         b_thread_copy_.Run(
                             b_block_desc_k0_n0_n1_n2_k1,
-                            make_tuple(
-                                Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
+                            make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                             b_block_buf,
                             b_thread_desc_,
                             make_tuple(I0, n0, I0, I0, I0, I0),
@@ -373,7 +369,7 @@ struct BlockwiseGemmWMMA
                     // read B
                     b_thread_copy_.Run(
                         b_block_desc_k0_n0_n1_n2_k1,
-                        make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
+                        make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                         b_block_buf,
                         b_thread_desc_,
                         make_tuple(I0, n0, I0, I0, I0, I0),
@@ -381,7 +377,7 @@ struct BlockwiseGemmWMMA
                     // read A
                     a_thread_copy_.Run(
                         a_block_desc_k0_m0_m1_m2_k1,
-                        make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
+                        make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                         a_block_buf,
                         a_thread_desc_,
                         make_tuple(I0, m0, I0, I0, I0, I0),
@@ -443,30 +439,76 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
 
-    using AThreadCopyType =
-        ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                         FloatA,
-                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
-                                         decltype(a_thread_desc_),
-                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
-                                         Sequence<0, 1, 2, 3, 4, 5>,
-                                         5,
-                                         A_K1,
-                                         A_K1>;
-
-    using BThreadCopyType =
-        ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                         FloatB,
-                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
-                                         decltype(b_thread_desc_),
-                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
-                                         Sequence<0, 1, 2, 3, 4, 5>,
-                                         5,
-                                         B_K1,
-                                         B_K1>;
-
-    AThreadCopyType a_thread_copy_;
-    BThreadCopyType b_thread_copy_;
+    template <bool EnableLds>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type =
+            ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                             FloatA,
+                                             decltype(a_block_desc_k0_m0_m1_m2_k1),
+                                             decltype(a_thread_desc_),
+                                             Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+                                             Sequence<0, 1, 2, 3, 4, 5>,
+                                             5,
+                                             A_K1,
+                                             A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
+            FloatA,
+            FloatA,
+            decltype(a_block_desc_k0_m0_m1_m2_k1),
+            decltype(a_thread_desc_),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+            Sequence<0, 1, 2, 3, 4, 5>,
+            5,
+            A_K1,
+            false>;
+    };
+
+    template <bool EnableLds>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type =
+            ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                             FloatB,
+                                             decltype(b_block_desc_k0_n0_n1_n2_k1),
+                                             decltype(b_thread_desc_),
+                                             Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+                                             Sequence<0, 1, 2, 3, 4, 5>,
+                                             5,
+                                             B_K1,
+                                             B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
+            FloatB,
+            FloatB,
+            decltype(b_block_desc_k0_n0_n1_n2_k1),
+            decltype(b_thread_desc_),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+            Sequence<0, 1, 2, 3, 4, 5>,
+            5,
+            B_K1,
+            false>;
+    };
+
+    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
+    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
 };
 #else
 template <index_t BlockSize,
......
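The AThreadCopySelector/BThreadCopySelector specializations above are a common way to pick between two class templates whose parameter lists differ, which std::conditional alone cannot express cleanly. A stripped-down sketch of the same idiom, with toy copy types in place of the CK transfer classes:

// Toy stand-ins for ThreadwiseTensorSliceTransfer_v4 (LDS path) and
// ..._StaticToStatic_IntraRow (direct-from-register path).
struct LdsCopy      { static constexpr const char* name = "lds";    };
struct IntraRowCopy { static constexpr const char* name = "direct"; };

template <bool EnableLds>
struct ThreadCopySelector; // primary template: declared only

template <> struct ThreadCopySelector<true>  { using type = LdsCopy;      };
template <> struct ThreadCopySelector<false> { using type = IntraRowCopy; };

// The member is then declared exactly as in the diff:
template <bool AEnableLds>
struct BlockwiseGemmSketch
{
    typename ThreadCopySelector<AEnableLds>::type a_thread_copy_;
};

int main()
{
    BlockwiseGemmSketch<false> g; // picks IntraRowCopy, bypassing LDS
    return *decltype(g.a_thread_copy_)::name == 'd' ? 0 : 1;
}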
@@ -133,12 +133,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
     static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
 
-    static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
-    static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
+    static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
+    static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
+
+    static constexpr auto AEnableLds_auto =
+        (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
+    static constexpr auto BEnableLds_auto =
+        (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
 
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;
 
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
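To see what the tightened heuristic does, it helps to evaluate it for a concrete tile. A hedged recomputation under assumed values (half-precision A, K1 = 8, so K1 * sizeof(ADataType) = 16 bytes, one N-wave, MRepeat = 2); the constants are chosen to exercise the rule, not taken from any instance in this diff:

#include <cstdint>

using ADataType = uint16_t; // e.g. fp16
constexpr int K1 = 8, NWaves = 1, MRepeat = 2, NumPrefetch = 1;

constexpr bool MaxVectorLoadA = K1 * sizeof(ADataType) == 16; // full 16-byte loads
constexpr bool AEnableLds_auto =
    !(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)); // false: bypass LDS
constexpr bool AEnableLds_manu = false;                 // no manual override
constexpr bool AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);

// With MaxVectorLoadA false (narrow loads) and MRepeat > 1, the new rule
// would fall back to true and stage A through LDS instead.
static_assert(!AEnableLds, "A feeds WMMA straight from registers here");
int main() { return 0; }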
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         }
         else
         {
-            if(!(arg.a_kz_stride_ == 1 &&
-                 arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
+            if(!(arg.a_kz_stride_ == 1))
             {
-                printf("DeviceOp: Vector Access A-k check failure\n");
-                return false;
+                index_t LastK =
+                    AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
+                if(LastK % ABlockTransferSrcScalarPerVector == 0)
+                {
+                    printf("DeviceOp: Vector Access A-k check failure\n");
+                    return false;
+                }
             }
         }
......
@@ -322,7 +322,7 @@ __global__ void
     index_t head_size,
     float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
......
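The widened guard follows the pattern used throughout these kernels: the body is compiled for the host pass (where __HIP_DEVICE_COMPILE__ is undefined, so kernel-launch syntax still sees the symbol) and for the whitelisted device families, and collapses to a stub elsewhere so unsupported targets still build. A schematic of the pattern, assuming HIP and the __gfx11__/__gfx12__ family macros used in the surrounding code:

#include <hip/hip_runtime.h>

// Placeholder kernel demonstrating the architecture guard; the body is
// illustrative, not the MHA kernel from the diff.
__global__ void guarded_kernel_sketch(float* out)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    if(out != nullptr) { out[0] = 0.f; } // real work only on RDNA3/RDNA4 passes
#else
    (void)out; // other device passes get an empty, linkable stub
#endif
}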
@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
 
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;
 
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
 
     static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
     static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
     static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
 
-    static constexpr auto AEnableLds_auto =
-        (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
+    static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
+    static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
+
+    static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
+                                             is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+                                                ? false
+                                                : true;
     static constexpr auto BEnableLds_auto =
-        (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
+        (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
+         is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            ? false
+            : true;
 
     // If true, LDS is used unconditionally
     static constexpr auto AEnableLds_manu = false;
......
@@ -48,8 +48,9 @@ __global__ void
     const Block2CTileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) ||        \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
@@ -60,8 +60,7 @@ __global__ void
     bool input_permute,
     bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx11__) || \
-    defined(__gfx12__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
......
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         if constexpr(B0EnableLds)
         {
             // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
             constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif
 
             return transform_tensor_descriptor(
                 B0BlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         if constexpr(B1EnableLds)
         {
             // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
             constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
             constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_LRow = I2;
+#else
             constexpr auto B_LRow = I1;
+#endif
 
             return transform_tensor_descriptor(
                 B1BlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_L1>{})),
......
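The descriptor changes above are index bookkeeping: on gfx12 the WMMA operand carries two K-rows per lane (KRow = 2), so the K0 dimension is unmerged into (K0 / KRow, KRow) rather than (K0, KRow), which keeps the transform's total K extent equal to the source descriptor's. A small compile-time check of that invariant under assumed sizes; the constants are illustrative, not from any particular instance:

// Assumed hunk-local sizes, for illustration only.
constexpr int B_K0 = 8; // K0 extent of the source BK0 x L x BK1 descriptor
#ifdef __gfx12__
constexpr int B_KRow = 2; // two K-rows per lane on gfx12 WMMA
#else
constexpr int B_KRow = 1;
#endif

// make_unmerge_transform(make_tuple(B_K0 / B_KRow, B_KRow)) must cover
// exactly the B_K0 source indices; the old (B_K0, B_KRow) pair would
// have inflated the K extent by a factor of KRow on gfx12.
static_assert((B_K0 / B_KRow) * B_KRow == B_K0, "unmerge preserves K extent");

int main() { return 0; }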
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
        if constexpr(AEnableLds)
        {
            // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
            constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
            constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
            constexpr auto A_KRow = I1;
+#endif
 
            return transform_tensor_descriptor(
                ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                           make_unmerge_transform(make_tuple(
                               Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                           make_pass_through_transform(Number<A_K1>{})),
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
        if constexpr(BEnableLds)
        {
            // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
            constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
            constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
            constexpr auto B_KRow = I1;
+#endif
 
            return transform_tensor_descriptor(
                BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                           make_unmerge_transform(make_tuple(
                               Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                           make_pass_through_transform(Number<B_K1>{})),
......
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
             constexpr auto A_KRow = I1;
+#endif
 
             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
             constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif
 
             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
......
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
+
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
             constexpr auto A_KRow = I1;
+#endif
 
             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
             constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif
 
             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
......