Commit fc62babb authored by Jing Zhang

separate gfx12 blockwise_gemm

parent f3111877
@@ -13,6 +13,7 @@
namespace ck {
#ifdef __gfx12__
template <index_t BlockSize,
typename FloatA,
typename FloatB,
@@ -66,13 +67,8 @@ struct BlockwiseGemmWMMA
// When LDS is used, each row (16 consecutive lanes) reads the whole data from the source buffer
// When LDS is not used, each row reads half of the data from the source buffer and exchanges it via
// permutation
#ifdef __gfx12__
static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1;
#else
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
#endif
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
@@ -114,11 +110,7 @@ struct BlockwiseGemmWMMA
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
#ifdef __gfx12__
return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
#else
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
#endif
}
else
{
@@ -135,11 +127,7 @@ struct BlockwiseGemmWMMA
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
#ifdef __gfx12__
return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
#else
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
#endif
}
else
{
@@ -203,6 +191,9 @@ struct BlockwiseGemmWMMA
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
static_assert(AEnableLds == true, "only support EnableLds");
static_assert(BEnableLds == true, "only support EnableLds");
}
// transposed WMMA output C' = B' * A'
@@ -303,7 +294,6 @@ struct BlockwiseGemmWMMA
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
#ifdef __gfx12__
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
@@ -428,7 +418,347 @@ struct BlockwiseGemmWMMA
});
}
}
protected:
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
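// Worked example (illustrative only, not part of this commit): with a hypothetical
// KPack = 16 and A_K1 = 8, the gfx12 value A_KRow = 2 gives a leading K0 length of
// KPack / A_K1 / A_KRow = 1, and the MRepeat dimension strides by KPack / A_KRow = 8 elements.
static_assert(16 / 8 / 2 == 1, "example K0 length of a_thread_desc_ for KPack=16, A_K1=8, A_KRow=2");
static_assert(16 / 2 == 8, "example MRepeat stride of a_thread_desc_ for KPack=16, A_KRow=2");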
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#else
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: Read from LDS, a big buffer holds the data required by all threads
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
 * Option: Read from VMEM, a small buffer holds each thread's own required data (skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct BlockwiseGemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode the WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
// When LDS is used, each row (16 consecutive lanes) reads the whole data from the source buffer
// When LDS is not used, each row reads half of the data from the source buffer and exchanges it via
// permutation
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
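// Worked example (hypothetical tile sizes, not taken from this commit): MPerBlock = 128,
// MPerWMMA = 16, MRepeat = 4 and NPerBlock = 64, NPerWMMA = 16, NRepeat = 2 yield
// MWaves = 2 and NWaves = 2, so the block is expected to run MWaves * NWaves * WaveSize = 128 threads.
static_assert(128 / (4 * 16) == 2, "example MWaves");
static_assert(64 / (2 * 16) == 2, "example NWaves");
static_assert(2 * 2 * 32 == 128, "example block size = MWaves * NWaves * WaveSize");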
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
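// Sketch of what the merge transform above computes (illustrative values only, assuming
// MWaves = 2, NWaves = 2, WaveSize = 32; the *_ex names are not part of this commit):
// a flat thread id decomposes row-major into (waveId_m, waveId_n, lane) with the lane fastest.
static constexpr index_t tid_ex      = 37;
static constexpr index_t lane_ex     = tid_ex % 32; // 5
static constexpr index_t wave_ex     = tid_ex / 32; // 1
static constexpr index_t waveId_n_ex = wave_ex % 2; // 1 -> wave_idx[I1]
static constexpr index_t waveId_m_ex = wave_ex / 2; // 0 -> wave_idx[I0]
static_assert(waveId_m_ex == 0 && waveId_n_ex == 1 && lane_ex == 5, "example wave decomposition");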
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
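// Concretely, the unmerge adaptors above recover the block-local row as
// (m0 * MWaves + waveId_m) * MPerWMMA + blk_idx[I0], and the column analogously.
// Worked example with hypothetical indices (not part of this commit):
// m0 = 1, waveId_m = 1, MWaves = 2, MPerWMMA = 16, blk_idx[I0] = 3 -> c_thread_m = 51.
static_assert((1 * 2 + 1) * 16 + 3 == 51, "example c_thread_m for the hypothetical indices above");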
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread-level register descriptor. Vector-write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
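// With the hypothetical tile used in the earlier examples (MRepeat = 4, MWaves = 2, MPerWMMA = 16,
// NRepeat = 2, NWaves = 2, NPerWMMA = 16), the packed descriptor above spans
// 4 * 2 * 16 = 128 rows by 2 * 2 * 16 = 64 columns, i.e. exactly MPerBlock x NPerBlock.
static_assert(4 * 2 * 16 == 128 && 2 * 2 * 16 == 64, "example C block tile extents");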
// Describes how data is allocated in the thread-copy source buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
@@ -560,28 +890,8 @@ struct BlockwiseGemmWMMA
});
}
}
#endif
protected:
#ifdef __gfx12__
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
#else
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
@@ -609,7 +919,6 @@ struct BlockwiseGemmWMMA
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
#endif
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -626,11 +935,7 @@ struct BlockwiseGemmWMMA
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
#ifdef __gfx12__
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
#else
Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
#endif
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -666,11 +971,7 @@ struct BlockwiseGemmWMMA
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
#ifdef __gfx12__
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
#else
Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
#endif
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
@@ -698,5 +999,6 @@ struct BlockwiseGemmWMMA
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#endif
} // namespace ck
@@ -94,13 +94,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
// If true, LDS is used unconditionally
#ifdef __gfx12__
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;
#else
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
#endif
static constexpr auto AEnableLds =
false; // AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds =
true; // BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
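// Minimal sketch of how the restored expression combines the three conditions
// (hypothetical flag values, none taken from this commit): auto-selection off,
// manual override off, single prefetch -> the A path skips LDS.
static constexpr bool AEnableLds_auto_ex = false;
static constexpr bool AEnableLds_manu_ex = false;
static constexpr index_t NumPrefetch_ex  = 1;
static_assert((AEnableLds_auto_ex || AEnableLds_manu_ex || (NumPrefetch_ex > 1)) == false,
              "example: LDS for A is enabled only if at least one condition holds");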
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
...