Commit 579f84c6 authored by aska-0096

tempsave

parent 7e003d31
@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         GemmDefault,
         256, // BlockSize
         128, // MPerBlock
-        16,  // NPerBlock
+        128, // NPerBlock
         32,  // KPerBlock
         8,   // K1
         16,  // MPerWmma
         16,  // NPerWmma
-        1,   // M Repeat
-        1,   // N-Repeat
+        2,   // M Repeat
+        4,   // N-Repeat
         S<4, 64, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         true,
         1, // C shuffle (M Repeat) Per store
         1, // C shuffle (N Repeat) Per store
-        S<1, 128, 1, 2>,
+        S<1, 64, 1, 4>,
         8>;
 // clang-format on
...
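Reviewer note: the retuned 128x128 tile can be sanity-checked arithmetically against the 256-thread workgroup. A minimal standalone sketch (illustrative constants, wave32 assumed for RDNA WMMA):

    // Wave decomposition implied by the new tile parameters.
    constexpr int BlockSize = 256, MPerBlock = 128, NPerBlock = 128;
    constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 2, NRepeat = 4;
    constexpr int WaveSize = 32; // assumption: RDNA wave32

    constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // = 4
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // = 2
    static_assert(MWaves * NWaves * WaveSize == BlockSize, "waves must fill the workgroup");

The new CShuffle store cluster S<1, 64, 1, 4> likewise covers 1 * 64 * 1 * 4 = 256 threads, matching the block size.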
@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
     case 4:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
         break;
     default:
...
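Reviewer note: case 4 now fills both operands with the constant 1, so every element of a plain GEMM output should equal K exactly; this isolates indexing bugs from accumulation noise. A hedged host-side probe (hypothetical helper, not part of the example):

    #include <cassert>
    #include <vector>

    // With A == 1 and B == 1, C[m][n] == K; any other value pinpoints a broken tile or loop.
    void check_all_ones_gemm(const std::vector<float>& c, int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                assert(c[m * N + n] == static_cast<float>(K));
    }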
@@ -129,7 +129,7 @@ using DeviceGemmInstance =
         S<0, 2, 1>,
         1,
         8,
-        1,
+        1, // should this be 8?
         false,
         1, // CShuffleMWmmaPerWavePerShuffle
         2, // CShuffleNWmmaPerWavePerShuffle
...
@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
     printf("Warm up 1 time\n");
 #endif
     // warm up
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+    // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

-    const int nrepeat = 100;
+    const int nrepeat = 1;
 #if DEBUG_LOG
     printf("Start running %d times...\n", nrepeat);
 #endif
...
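Reviewer note: with the warm-up launch commented out and nrepeat dropped from 100 to 1, the reported time now includes first-launch overhead; that is convenient while correctness-debugging (one pass of device printfs) but should not be read as a performance number. For reference, a minimal hipEvent timing sketch (launch elided, standard HIP event APIs):

    #include <hip/hip_runtime.h>

    float time_once_ms(hipStream_t stream)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);
        hipEventRecord(start, stream);
        // kernel<<<grid_dim, block_dim, lds_byte, stream>>>(args...);
        hipEventRecord(stop, stream);
        hipEventSynchronize(stop);
        float ms = 0.f;
        hipEventElapsedTime(&ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
        return ms;
    }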
@@ -27,6 +27,8 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack,
+          bool AEnableLds = true,
+          bool BEnableLds = true,
           bool TransposeC = false>
 /* Option: Read from LDS, big buffer hold all threads required data
  * Source
@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

-    static constexpr bool AEnableLds = NWaves == 1 ? false : true;
-    static constexpr bool BEnableLds = MWaves == 1 ? false : true;
-
     // Read from LDS: data is duplicated twice; read from VGPR: no duplication.
     static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
     static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
...
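Reviewer note: the LDS-bypass decision moves from a hard-wired constant inside BlockwiseGemmWMMA to template parameters supplied by the caller (the device op still derives them from wave counts, see the next hunk). A sketch of the heuristic that used to live here, written as an illustrative free function:

    // When a single wave spans N, every wave already owns its rows of A, so the LDS
    // staging copy for A can be skipped and the fragment kept in VGPRs
    // (symmetrically for B when a single wave spans M).
    constexpr bool enable_lds_for_A(int NPerBlock, int NRepeat, int NPerWmma)
    {
        const int NWaves = NPerBlock / (NRepeat * NPerWmma);
        return NWaves != 1;
    }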
@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto AEnableLds = NWaves == 1 ? false : true;
     static constexpr auto BEnableLds = MWaves == 1 ? false : true;

+    // static constexpr auto AEnableLds = true;
+    // static constexpr auto BEnableLds = true;
+
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
     // Describe how data read from Global memory
...
@@ -45,7 +45,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op,
             const index_t batch_count,
-            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+            const AGridDesc_AK0_M_AK1 a_grid_desc,
             const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -84,7 +84,7 @@ __global__ void
             p_ds_grid_grp,
             p_e_grid + e_batch_offset,
             p_shared,
-            a_grid_desc_k0_m_k1,
+            a_grid_desc,
             b_grid_desc_k0_n_k1,
             ds_grid_desc_mblock_mperblock_nblock_nperblock,
             e_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -98,7 +98,7 @@ __global__ void
     ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = batch_count;
-    ignore = a_grid_desc_k0_m_k1;
+    ignore = a_grid_desc;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -115,7 +115,7 @@ template <typename GridwiseOp,
           typename BDataType,
           typename DsPointer,
           typename EDataType,
-          typename AGridDesc_K0_M_K1,
+          typename AGridDesc,
           typename BGridDesc_K0_N_K1,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -135,7 +135,7 @@ __global__ void
             DsPointer p_ds_grid,
             EDataType* __restrict__ p_e_grid,
             const index_t batch_count,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+            const AGridDesc a_grid_desc,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -176,7 +176,7 @@ __global__ void
             p_ds_grid_grp,
             p_e_grid + e_batch_offset,
             p_shared,
-            a_grid_desc_k0_m_k1,
+            a_grid_desc,
             b_grid_desc_k0_n_k1,
             ds_grid_desc_mblock_mperblock_nblock_nperblock,
             e_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -193,7 +193,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
-    ignore = a_grid_desc_k0_m_k1;
+    ignore = a_grid_desc;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -207,7 +207,7 @@ template <typename GridwiseOp,
           typename BDataType,
           typename DsPointer,
           typename EDataType,
-          typename AGridDesc_K0_M_K1,
+          typename AGridDesc,
           typename BGridDesc_K0_N_K1,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -225,7 +225,7 @@ __global__ void
             const BDataType* __restrict__ p_b_grid,
             DsPointer p_ds_grid,
             EDataType* __restrict__ p_e_grid,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+            const AGridDesc a_grid_desc,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -244,7 +244,7 @@ __global__ void
             p_ds_grid,
             p_e_grid,
             p_shared,
-            a_grid_desc_k0_m_k1,
+            a_grid_desc,
             b_grid_desc_k0_n_k1,
             ds_grid_desc_mblock_mperblock_nblock_nperblock,
             e_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -257,7 +257,7 @@ __global__ void
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = a_grid_desc_k0_m_k1;
+    ignore = a_grid_desc;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -276,7 +276,7 @@ template < // DataType Family
           typename DsDataType,
           typename EDataType,
           // InMemory Data Descriptor
-          typename AGridDesc_K0_M_K1,
+          typename AGridDesc,
           typename BGridDesc_K0_N_K1,
           typename DsGridDesc_M_N,
           typename EGridDesc_M_N,
@@ -288,7 +288,7 @@ template < // DataType Family
           // Tiling Family
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t K0PerBlock,
+          index_t KPerBlock,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t K1Value,
@@ -303,6 +303,7 @@ template < // DataType Family
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool AEnableLds,
           bool ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
@@ -311,6 +312,7 @@ template < // DataType Family
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BEnableLds,
           bool BBlockLdsExtraN,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
@@ -335,36 +337,161 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = 16;
+
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
+                                                              AEnableLds,
+                                                              BEnableLds,
+                                                              NumGemmKPrefetchStage,
+                                                              LoopSched>())>;

-    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        return a_block_desc_k0perblock_mperblock_k1;
-    }
+    // Describe how data is stored to the (LDS/VGPR) buffer from Global memory
+    __host__ __device__ static constexpr auto MakeABlockDescriptor()
+    {
+        constexpr auto a_block_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // K0->M->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / K1;
+                constexpr auto max_lds_align = K1;
+
+                if constexpr(ABlockLdsExtraM)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                        max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
+                    make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
+            }
+        }();
+
+        return a_block_desc;
+    }
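Reviewer note: the two branches describe very different buffers. Rough element counts for the example config (MPerBlock = 128, KPerBlock = 32, K1 = 8, MRepeat = 2, WmmaK = 16), assuming no LDS padding:

    // LDS path: one block-wide staging tile of shape (K0PerBlock, MPerBlock, K1).
    constexpr int lds_elems  = (32 / 8) * 128 * 8; // 4096 elements per workgroup
    // VGPR path: one per-thread fragment of shape (KWmmaPerBlock, MRepeat, 1, 1, 1, K1).
    constexpr int vgpr_elems = (32 / 16) * 2 * 8;  // 32 elements per thread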
+    __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
+    {
+        constexpr auto a_block_copy_step = [&]() {
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return a_block_copy_step;
+    }
+
+    __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
+    {
+        constexpr auto b_block_copy_step = [&]() {
+            if constexpr(BEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return b_block_copy_step;
+    }
+    // Describe how data is read from the (LDS/VGPR) buffer
+    template <typename ABlockDesc_>
+    __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
+    {
+        constexpr auto a_wave_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // AK0_M_AK1 -> AK0_MRepeat_MWaves_MPerWmma_AK1
+                constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                               make_unmerge_transform(make_tuple(
+                                   Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+            }
+            else
+            {
+                // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_MWaves_MPerWmma_K1
+                constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+                               make_pass_through_transform(Number<MRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0, 3>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+            }
+        }();
+
+        return a_wave_desc;
+    }
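Reviewer note: in the LDS branch the unmerge splits M into (MRepeat, MWaves, MPerWmma), i.e. m = (mrepeat * MWaves + mwave) * MPerWmma + m_in_wmma. A worked decode under the example config (MWaves = 4, MPerWmma = 16):

    constexpr int MWaves = 4, MPerWmma = 16;
    constexpr int m         = 77;
    constexpr int m_in_wmma = m % MPerWmma;            // 13
    constexpr int mwave     = (m / MPerWmma) % MWaves; // 0
    constexpr int mrepeat   = m / (MPerWmma * MWaves); // 1
    static_assert((mrepeat * MWaves + mwave) * MPerWmma + m_in_wmma == m, "");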
+    template <typename BBlockDesc_BK0_N_BK1>
+    __host__ __device__ static constexpr auto
+    MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+    {
+        constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
+        constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
+        return transform_tensor_descriptor(
+            BBlockDesc_BK0_N_BK1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }

     __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
+        constexpr auto K0PerBlock    = KPerBlock / K1;

         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
@@ -416,28 +543,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-        constexpr auto max_lds_align = K1;
-
-        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        return (a_block_space_size_aligned * sizeof(ADataType) +
-                b_block_space_size_aligned * sizeof(BDataType));
+        const index_t gemm_bytes_end =
+            SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
+            SharedMemTrait::b_block_space_size_aligned * sizeof(BDataType);
+
+        const index_t c_block_bytes_end =
+            SharedMemTrait::c_shuffle_block_space_size * sizeof(CShuffleDataType);
+
+        return math::max(gemm_bytes_end, c_block_bytes_end);
     }
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+    CheckValidity(const AGridDesc& a_grid_desc,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                   const DsGridDesc_M_N& ds_grid_desc_m_n,
                   const EGridDesc_M_N& e_grid_desc_m_n,
@@ -450,9 +569,41 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                           (NPerBlock % (NRepeat * NPerWmma)) == 0,
                       "Invalid tuning param!");

-        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
-        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto GetAProblemsizeMK = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return make_tuple(a_grid_desc.GetLength(I1),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
+                                      a_grid_desc.GetLength(I4),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                                      a_grid_desc.GetLength(I5));
+            }
+        };
+
+        const auto GetBProblemsizeNK = [&]() {
+            if constexpr(BEnableLds)
+            {
+                return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
+                                  b_grid_desc_k0_n_k1.GetLength(I0) *
+                                      b_grid_desc_k0_n_k1.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(
+                    b_grid_desc_k0_n_k1.GetLength(I1) * b_grid_desc_k0_n_k1.GetLength(I2) *
+                        b_grid_desc_k0_n_k1.GetLength(I4),
+                    b_grid_desc_k0_n_k1.GetLength(I0) * b_grid_desc_k0_n_k1.GetLength(I3) *
+                        b_grid_desc_k0_n_k1.GetLength(I5));
+            }
+        };
+
+        const auto M = GetAProblemsizeMK()[I0];
+        const auto N = GetBProblemsizeNK()[I0];
+        const auto K = GetAProblemsizeMK()[I1];

         bool valid = true;
@@ -468,21 +619,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         }

         if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
-             K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
-             K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
-             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+             K == GetBProblemsizeNK()[I1]))
         {
             printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
             return false;
         }

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
         {
             printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
             return false;
         }

         // check gridwise gemm pipeline
-        const auto num_k_loop = K0 / K0PerBlock;
+        const auto num_k_loop = K / KPerBlock;

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
@@ -546,6 +696,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
             e_grid_desc_m_n);
     }
+    struct SharedMemTrait
+    {
+        // LDS allocation for A and B: be careful of alignment
+        static constexpr auto max_lds_align = K1;
+
+        static constexpr auto a_block_space_size_aligned =
+            AEnableLds ? math::integer_least_multiple(
+                             MakeABlockDescriptor().GetElementSpaceSize(), max_lds_align)
+                       : 0;
+        static constexpr auto b_block_space_size_aligned =
+            BEnableLds ? math::integer_least_multiple(
+                             GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
+                             max_lds_align)
+                       : 0;
+
+        static constexpr auto a_block_space_offset = 0;
+        static constexpr auto b_block_space_offset = a_block_space_size_aligned;
+
+        // LDS allocation for C shuffle in LDS
+        static constexpr auto c_shuffle_block_space_size =
+            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
+                .GetElementSpaceSize();
+        static constexpr auto c_shuffle_block_space_offset = 0;
+    };
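Reviewer note: SharedMemTrait lets the A/B staging buffers and the C-shuffle scratch share one allocation, which is why GetSharedMemoryNumberOfByte above now takes the max of the two phases rather than their sum. A toy version of the same computation (fp16 operands assumed; the C-shuffle size below is a placeholder, the real value comes from the CShuffle block descriptor):

    #include <algorithm>

    constexpr int a_bytes         = 4096 * 2; // A staging tile, 2-byte elements
    constexpr int b_bytes         = 4096 * 2; // B staging tile
    constexpr int c_shuffle_bytes = 16384;    // placeholder value
    constexpr int lds_bytes       = std::max(a_bytes + b_bytes, c_shuffle_bytes);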
     using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
@@ -560,7 +735,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                DsGridPointer p_ds_grid,
                                EDataType* __restrict__ p_e_grid,
                                void* __restrict__ p_shared,
-                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                               const AGridDesc& a_grid_desc,
                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -575,7 +750,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         /*******************************************************************************/
         // Memory buffer zone.
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
         const auto ds_grid_buf = generate_tuple(
@@ -603,23 +778,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         /*******************************************************************************/
         // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
-        const auto K0                = a_grid_desc_k0_m_k1.GetLength(I0);
-        constexpr auto max_lds_align = K1;
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        const auto K = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
+            }
+            else
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                       a_grid_desc.GetLength(I5);
+            }
+        }();
+
+        constexpr auto a_block_desc = MakeABlockDescriptor();
+        constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

-        // A matrix blockwise copy
-        auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+        auto a_block_trait = [&]() {
+            // A matrix blockwise copy
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<FloatA*>(p_shared), a_block_desc.GetElementSpaceSize());
+
+                auto a_blockwise_copy =
+                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                 /* typename SrcElementwiseOperation, */ AElementwiseOperation,
                 /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
                 /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
                 /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
                 /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
                 /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
-                /* typename SrcData, */ ADataType,
-                /* typename DstData, */ ADataType,
-                /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
-                /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
+                /* typename SrcData, */ FloatA,
+                /* typename DstData, */ FloatA,
+                /* typename SrcDesc, */ decltype(a_grid_desc),
+                /* typename DstDesc, */ decltype(a_block_desc),
                 /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
                 /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
                 /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
@@ -630,62 +821,138 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                 /* index_t DstScalarStrideInVector, */ 1,
                 /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
                 /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
-                a_grid_desc_k0_m_k1,
+                a_grid_desc,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
-                a_block_desc_k0perblock_mperblock_k1,
+                a_block_desc,
                 make_multi_index(0, 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{});

-        // B matrix blockwise copy
-        auto b_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                BElementwiseOperation,
-                                                ck::tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<K0PerBlock, NPerBlock, K1>,
-                                                BBlockTransferThreadClusterLengths_K0_N_K1,
-                                                BBlockTransferThreadClusterArrangeOrder,
-                                                BDataType,
-                                                BDataType,
-                                                decltype(b_grid_desc_k0_n_k1),
-                                                decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                BBlockTransferSrcAccessOrder,
-                                                Sequence<0, 1, 2>,
-                                                BBlockTransferSrcVectorDim,
-                                                2,
-                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_K1,
-                                                1,
-                                                1,
-                                                BThreadTransferSrcResetCoordinateAfterRun,
-                                                true>(
-                b_grid_desc_k0_n_k1,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
-                b_element_op,
-                b_block_desc_k0perblock_nperblock_k1,
-                make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+            else
+            {
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+                    a_block_desc.GetElementSpaceSize());
+
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto a_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v2<FloatA,
+                                                     FloatA,
+                                                     decltype(a_grid_desc),
+                                                     decltype(a_block_desc),
+                                                     Sequence<Number<KWmmaPerBlock>{},
+                                                              Number<MRepeat>{},
+                                                              I1,
+                                                              I1,
+                                                              I1,
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2, 3, 4, 5>,
+                                                     5,
+                                                     ABlockTransferSrcScalarPerVector,
+                                                     AThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                        a_grid_desc,
+                        make_multi_index(0,
+                                         m_block_data_idx_on_grid / (MWaves * MPerWmma),
+                                         get_thread_local_1d_id() / 32,
+                                         (get_thread_local_1d_id() % 32) / 16,
+                                         get_thread_local_1d_id() % 16,
+                                         0));
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+        };
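Reviewer note: the VGPR-path source origin decodes the flat thread id under a wave32 assumption; spelling the mapping out (illustrative, mirrors the make_multi_index above):

    const int tid    = get_thread_local_1d_id();
    const int wave   = tid / 32;        // wave index within the workgroup
    const int k_row  = (tid % 32) / 16; // which 16-lane half feeds which KRow
    const int lane16 = tid % 16;        // lane within the 16-wide WMMA row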
+        auto b_block_trait = [&]() {
+            if constexpr(BEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
+                    b_block_desc.GetElementSpaceSize());
+
+                auto b_blockwise_copy =
+                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                        BElementwiseOperation,
+                                                        ck::tensor_operation::element_wise::PassThrough,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        Sequence<K0PerBlock, NPerBlock, K1>,
+                                                        BBlockTransferThreadClusterLengths_K0_N_K1,
+                                                        BBlockTransferThreadClusterArrangeOrder,
+                                                        FloatB,
+                                                        FloatB,
+                                                        decltype(b_grid_desc_k0_n_k1),
+                                                        decltype(b_block_desc),
+                                                        BBlockTransferSrcAccessOrder,
+                                                        Sequence<0, 1, 2>,
+                                                        BBlockTransferSrcVectorDim,
+                                                        2,
+                                                        BBlockTransferSrcScalarPerVector,
+                                                        BBlockTransferDstScalarPerVector_K1,
+                                                        1,
+                                                        1,
+                                                        BThreadTransferSrcResetCoordinateAfterRun,
+                                                        true>(
+                        b_grid_desc_k0_n_k1,
+                        make_multi_index(0, n_block_data_idx_on_grid, 0),
+                        b_element_op,
+                        b_block_desc,
+                        make_multi_index(0, 0, 0),
+                        ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(b_block_buf, b_blockwise_copy);
+            }
+            else
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+                    b_block_desc.GetElementSpaceSize());
+
+                auto b_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                                     FloatB,
+                                                     decltype(b_grid_desc_k0_n_k1),
+                                                     decltype(b_block_desc),
+                                                     Sequence<Number<K0PerBlock>{},
+                                                              Number<NRepeat>{},
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2>,
+                                                     2,
+                                                     BBlockTransferSrcScalarPerVector,
+                                                     1>(
+                        make_multi_index(
+                            0,
+                            get_thread_local_1d_id() / 32 * 16 + get_thread_local_1d_id() % 16,
+                            0));
+
+                return make_tuple(b_block_buf, b_blockwise_copy);
+            }
+        };
+
+        auto a_block_buf      = a_block_trait()[I0];
+        auto a_blockwise_copy = a_block_trait()[I1];
+        auto b_block_buf      = b_block_trait()[I0];
+        auto b_blockwise_copy = b_block_trait()[I1];
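Reviewer note: each trait lambda is invoked twice here, so the buffer and the copy object are constructed twice and one pair is discarded. Materializing each tuple once would avoid that (untested sketch):

    auto a_trait          = a_block_trait();
    auto a_block_buf      = a_trait[I0];
    auto a_blockwise_copy = a_trait[I1];
    auto b_trait          = b_block_trait();
    auto b_block_buf      = b_trait[I0];
    auto b_blockwise_copy = b_trait[I1];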
         /*******************************************************************************/
         // GEMM
-        constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+            BlockwiseGemmWMMA<BlockSize,
                               ADataType,
                               BDataType,
                               AccDataType,
-                              decltype(a_block_desc_k0perblock_mperblock_k1),
-                              decltype(b_block_desc_k0perblock_nperblock_k1),
-                              MPerWmma,
-                              NPerWmma,
-                              MRepeat,
-                              NRepeat,
-                              KPack,
-                              false,
-                              true>{};
+                              decltype(MakeAWaveDescriptor(a_block_desc)),
+                              decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
+                              MPerBlock,
+                              NPerBlock,
+                              KPerBlock,
+                              MPerWmma,
+                              NPerWmma,
+                              MRepeat,
+                              NRepeat,
+                              KPack>{};

         // Prepare Register for C matrix
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
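Reviewer note: assuming CK's math::integer_least_multiple rounds its first argument up to a multiple of the second, KPack resolves to 16 here, i.e. one WMMA instruction consumes two K1 = 8 vectors along K:

    constexpr int K1 = 8, WmmaK = 16;
    constexpr int KPack = ((K1 + WmmaK - 1) / WmmaK) * WmmaK; // = 16
    static_assert(KPack == 16, "");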
@@ -702,7 +969,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         // gridwise GEMM pipeline
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                           a_block_desc_k0perblock_mperblock_k1,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
...
@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true>
                                    CThreadBuffer& c_thread_buf,
                                    index_t num_loop)
     {
+        if(get_thread_local_1d_id() < 32)
+            printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
+
         // preload data into LDS
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
                           },
                           Number<a_block_desc.GetLengths().GetSize()>{});
 #endif
+        if(get_thread_local_1d_id() < 32)
+            printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
+
         constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
         auto a_block_buf_switch           = a_block_buf;
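Reviewer note: the "< 32" guard prints one line per lane of the first wave; if a single line per workgroup is intended, a thread-0 guard is tighter:

    if(get_thread_local_1d_id() == 0)
    {
        printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
    }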
...
@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                                NPerWmma,
                                MRepeat,
                                NRepeat,
-                               KPack>{};
+                               KPack,
+                               AEnableLds,
+                               BEnableLds>{};

         // Prepare Register for C matrix
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...