Commit 6a9d7b64 authored by aska-0096

temp save

parent d4adc71a
@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     GemmDefault,
     256,          // BlockSize
     128,          // MPerBlock
-    128,          // NPerBlock
-    64,           // KPerBlock
+    16,           // NPerBlock
+    32,           // KPerBlock
     8,            // K1
     16,           // MPerWmma
     16,           // NPerWmma
     1,            // M Repeat
-    8,            // N-Repeat
+    1,            // N-Repeat
     S<4, 64, 1>,
     S<1, 0, 2>,
     S<1, 0, 2>,
@@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     8,
     8,
     true,
-    S<4, 64, 1>,
+    S<4, 16, 1>,
     S<1, 0, 2>,
     S<1, 0, 2>,
     2,
@@ -59,8 +59,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     8,
     true,
     1,            // C shuffle (M Repeat) Per store
-    4,            // C shuffle (N Repeat) Per store
-    S<1, 64, 1, 4>,
+    1,            // C shuffle (N Repeat) Per store
+    S<1, 128, 1, 2>,
     8>;
 // clang-format on
...
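Note: the retuned tile (NPerBlock 128 → 16, KPerBlock 64 → 32, NRepeat 8 → 1) still has to cover all 256 lanes of the workgroup. A minimal sanity sketch of the wave arithmetic, assuming the RDNA wave32 size and the MWaves/NWaves definitions introduced later in this commit (the names here are illustrative, not part of the patch):

```cpp
#include <cstdint>

// Wave decomposition as defined in the gridwise op below:
//   MWaves = MPerBlock / (MRepeat * MPerWmma)
//   NWaves = NPerBlock / (NRepeat * NPerWmma)
constexpr int32_t WaveSize  = 32;              // assumption: wave32 (RDNA WMMA)
constexpr int32_t BlockSize = 256;
constexpr int32_t MWaves    = 128 / (1 * 16);  // = 8
constexpr int32_t NWaves    = 16 / (1 * 16);   // = 1
static_assert(MWaves * NWaves * WaveSize == BlockSize,
              "the M x N wave grid must fill the whole workgroup");
```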
@@ -94,12 +94,14 @@ using DeviceGemmInstance =
     TensorSpecB1,
     TensorSpecC,
     256,
+    // Gemm 0
     128,          // MPerBlock
     128,          // LPerBlock
-    4,            // K0PerBlock
+    32,           // KPerBlock
     8,            // K1
+    // Gemm 1
     64,           // NPerBlock
-    4,            // L0PerBlock
+    32,           // LPerBlock
     8,            // L1
     16,           // MPerWMMA
     16,           // LPerWMMA
...
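Note: this example now passes the Gemm0 K tile and the Gemm1 reduction (L) tile directly instead of as K0/L0 splits. For these values the refactor is size-preserving; a quick arithmetic check:

```cpp
// Old parameterization: K0PerBlock * K1 spanned the Gemm0 K tile,
// L0PerBlock * L1 spanned the Gemm1 reduction (L) tile.
constexpr int K0PerBlock_old = 4, K1 = 8, KPerBlock_new = 32;
static_assert(K0PerBlock_old * K1 == KPerBlock_new, "same Gemm0 K tile");

constexpr int L0PerBlock_old = 4, L1 = 8, LPerBlock_new = 32;
static_assert(L0PerBlock_old * L1 == LPerBlock_new, "same Gemm1 L tile");
```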
@@ -53,10 +53,10 @@ template <index_t NumDimG,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t LPerBlock,
-          ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
+          ck::index_t KPerBlock,
           ck::index_t K1,
           ck::index_t NPerBlock,
-          ck::index_t L0PerBlock,
+          ck::index_t LPerBlock,
           ck::index_t L1,
           ck::index_t MPerWMMA,
           ck::index_t LPerWMMA,
@@ -128,8 +128,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static constexpr index_t NumDimGemm1N = NumDimN;
     static constexpr index_t NumDimGemm1K = NumDimL;

-    static constexpr index_t KPerBlock = K0PerBlock * K1;
-
     using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle;

     static constexpr auto I0 = Number<0>{};
@@ -137,6 +135,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWMMA);
+    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWMMA);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWMMA);
+    static constexpr auto WmmaK  = 16;
+
+    static constexpr auto AEnableLds  = LWaves == 1 ? false : true;
+    static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
+    static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
+
     using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
         Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
         Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
@@ -146,13 +153,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         B1Spec,
         CSpec>;

-    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
-                                              const std::vector<index_t>& a_gs_ms_ks_strides_vec)
+    static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+                                    const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
+        if constexpr(AEnableLds)
+        {
             return Transform::MakeAGridDescriptor_AK0_M_AK1(
                 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
                 Number<K1>{});
         }
+        else
+        {
+            return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+                Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
+                WmmaK, Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}, Number<K1>{});
+        }
+    }

     static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                                const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
@@ -170,7 +187,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             Number<L1>{});
     }

-    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
+    using AGridDesc            = decltype(MakeAGridDescriptor({}, {}));
     using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
     using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
     using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
@@ -250,17 +267,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         CElementwiseOperation,
         InMemoryDataOperationEnum::Set,
         // InMemory Data Descriptor
-        AGridDesc_AK0_M_AK1,
+        AGridDesc,
         B0GridDesc_BK0_L_BK1,
         B1GridDesc_BL0_N_BL1,
         CGridDesc_M_N,
         // Tiling Family
         MPerBlock,
         LPerBlock,
-        K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
+        KPerBlock,
         K1,
         NPerBlock,
-        L0PerBlock,
+        LPerBlock,
         L1,
         MPerWMMA,
         LPerWMMA,
@@ -277,6 +294,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_K1,
         true,
+        AEnableLds,
         ABlockLdsAddExtraM,
         B0BlockTransferThreadClusterLengths_K0_L_K1,
         B0BlockTransferThreadClusterArrangeOrder,
@@ -285,6 +303,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         B0BlockTransferSrcScalarPerVector,
         B0BlockTransferDstScalarPerVector_K1,
         true,
+        B0EnableLds,
         B0BlockLdsAddExtraL,
         B1BlockTransferThreadClusterLengths_L0_N_L1,
         B1BlockTransferThreadClusterArrangeOrder,
@@ -293,6 +312,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         B1BlockTransferSrcScalarPerVector,
         B1BlockTransferDstScalarPerVector_L1,
         false,
+        B1EnableLds,
         B1BlockLdsAddExtraN,
         CShuffleMRepeatPerShuffle,
         CShuffleNRepeatPerShuffle,
@@ -338,7 +358,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
               p_b1_grid_{p_b1_grid},
               p_c_grid_{p_c_grid},
               a_grid_desc_ak0_m_ak1_{
-                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+                  DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
               b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
                   b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
               b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
@@ -404,7 +424,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         CDataType* p_c_grid_;

         // Tensor Descriptors
-        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        AGridDesc a_grid_desc_ak0_m_ak1_;
         B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
         B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
         CGridDesc_M_N c_grid_desc_m_n_;
@@ -463,7 +483,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             B0DataType,
             B1DataType,
             CDataType,
-            DeviceOp::AGridDesc_AK0_M_AK1,
+            DeviceOp::AGridDesc,
             DeviceOp::B0GridDesc_BK0_L_BK1,
             DeviceOp::B1GridDesc_BL0_N_BL1,
             typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -741,11 +761,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             << BlockSize << ", "
             << MPerBlock << ", "
             << LPerBlock << ", "
-            << K0PerBlock << ", "
+            << KPerBlock << ", "
             << K1 << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << L0PerBlock << ", "
+            << LPerBlock << ", "
             << L1
             << getGemmSpecializationString(GemmSpec) << ", "
             << "ASpec" << getTensorSpecializationString(ASpec) << ", "
...
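Note: with KPerBlock a direct template parameter, the derived constant `KPerBlock = K0PerBlock * K1` disappears and each consumer recomputes the K0 split it needs (e.g. `BK0 = KPerBlock / K1Value` in the gridwise op below). A hedged sketch of the divisibility guard this relies on; this helper is illustrative and not part of the patch:

```cpp
// Illustrative only: every K0 split derived from KPerBlock assumes
// KPerBlock is a multiple of the vector width K1.
template <int KPerBlock, int K1>
struct K0Of
{
    static_assert(KPerBlock % K1 == 0, "KPerBlock must be a multiple of K1");
    static constexpr int value = KPerBlock / K1;
};

static_assert(K0Of<32, 8>::value == 4, "matches the example configuration");
```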
@@ -134,10 +134,10 @@ template <typename FloatA,
           typename CGridDesc_M_N,
           index_t MPerBlock,
           index_t LPerBlock,
-          index_t K0PerBlock, // K0 * K1Value = Gemm0 GEMM_K Dim
+          index_t KPerBlock,
           index_t K1Value,
           index_t NPerBlock,
-          index_t L0PerBlock,
+          index_t LPerBlock,
           index_t L1Value,
           index_t MPerWmma,
           index_t LPerWmma,
@@ -153,6 +153,7 @@ template <typename FloatA,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool AEnableLds,
           bool ABlockLdsExtraM,
           typename B0BlockTransferThreadClusterLengths_K0_L_K1,
           typename B0BlockTransferThreadClusterArrangeOrder,
@@ -161,6 +162,7 @@ template <typename FloatA,
           index_t B0BlockTransferSrcScalarPerVector,
           index_t B0BlockTransferDstScalarPerVector_K1,
           bool B0ThreadTransferSrcResetCoordinateAfterRun,
+          bool B0EnableLds,
           bool B0BlockLdsExtraN,
           typename B1BlockTransferThreadClusterLengths_L0_N_L1,
           typename B1BlockTransferThreadClusterArrangeOrder,
@@ -169,6 +171,7 @@ template <typename FloatA,
           index_t B1BlockTransferSrcScalarPerVector,
           index_t B1BlockTransferDstScalarPerVector_L1,
           bool B1ThreadTransferSrcResetCoordinateAfterRun,
+          bool B1EnableLds,
           bool B1BlockLdsExtraN,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
@@ -190,38 +193,124 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

-    static constexpr auto AK0 = Number<K0PerBlock>{};
     static constexpr auto AK1 = Number<K1Value>{};
-    static constexpr auto BK0 = Number<K0PerBlock>{};
+    static constexpr auto BK0 = Number<KPerBlock / K1Value>{};
     static constexpr auto BK1 = Number<K1Value>{};

+    static constexpr auto L0PerBlock = LPerBlock / L1Value;
     static constexpr auto AL0 = Number<L0PerBlock / 2>{};
     static constexpr auto AL1 = Number<L1Value>{};
     static constexpr auto BL0 = Number<L0PerBlock>{};
     static constexpr auto BL1 = Number<L1Value>{};

+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = 16;
+
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
+        GridwiseGemmPipeline_Selector<PipelineVer,
+                                      AEnableLds,
+                                      B0EnableLds,
+                                      NumGemmKPrefetchStage,
+                                      LoopSched>())>;
-    template <typename A0BlockDesc_AK0_M_AK1>
-    __host__ __device__ static constexpr auto
-    MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
-    {
-        constexpr index_t A_K0   = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
-        constexpr index_t A_K1   = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);
-        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
-
-        return transform_tensor_descriptor(
-            A0BlockDesc_AK0_M_AK1{},
-            make_tuple(make_pass_through_transform(Number<A_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
-                       make_pass_through_transform(Number<A_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
+    __host__ __device__ static constexpr auto MakeABlockDescriptor()
+    {
+        constexpr auto a_block_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // K0->M->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / AK1;
+                constexpr auto max_lds_align = AK1;
+
+                if constexpr(ABlockLdsExtraM)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, AK1),
+                        make_tuple(Number<MPerBlock + 1>{} * AK1, AK1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, AK1),
+                        max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerBlock>{}, Number<MRepeat>{}, I1, I1, I1, AK1),
+                    make_tuple(Number<MRepeat>{} * AK1, AK1, AK1, AK1, AK1, I1));
+            }
+        }();
+
+        return a_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
+    {
+        constexpr auto a_block_copy_step = [&]() {
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / AK1;
+
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return a_block_copy_step;
+    }
+
+    // Describe how data is read from the (LDS/VGPR) buffer
+    template <typename ABlockDesc_>
+    __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
+    {
+        constexpr auto a_wave_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
+                constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                               make_unmerge_transform(make_tuple(
+                                   Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+            }
+            else
+            {
+                // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+                constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
+
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+                               make_pass_through_transform(Number<MRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0, 3>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(
+                        Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+            }
+        }();
+
+        return a_wave_desc;
+    }
     template <typename B0BlockDesc_BK0_L_BK1>
     __host__ __device__ static constexpr auto
@@ -273,14 +362,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
     }

-    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
-    {
-        // A matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
-            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
-    }
-
     __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
@@ -318,19 +399,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     {
         // LDS allocation for A and B: be careful of alignment
         const index_t gemm0_bytes_end =
-            (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
-             SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
+            (SharedMemTrait::a_block_space_size_aligned +
+             SharedMemTrait::b0_block_space_size_aligned);

         const index_t gemm1_bytes_end =
-            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
-            sizeof(FloatB1);
+            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned);

-        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
-                                           SharedMemTrait::reduction_space_size_aligned) *
-                                          sizeof(FloatAcc0);
+        const index_t softmax_bytes_end = SharedMemTrait::reduction_space_offset +
+                                          SharedMemTrait::reduction_space_size_aligned;

-        const index_t c_block_bytes_end =
-            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
+        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size;

         return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
     }
@@ -434,38 +512,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
-        static constexpr auto a_block_desc_ak0_m_ak1 =
-            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        static constexpr auto b0_block_desc_bk0_l_bk1 =
-            GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
-        static constexpr auto b1_block_desc_bl0_n_bl1 =
-            GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
-
         static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);

-        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
-            b0_block_desc_bk0_l_bk1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
-            b1_block_desc_bl0_n_bl1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto a_block_space_size_aligned =
+            AEnableLds ? math::integer_least_multiple(
+                             MakeABlockDescriptor().GetElementSpaceSize() * sizeof(FloatA),
+                             max_lds_align)
+                       : 0;
+        static constexpr auto b0_block_space_size_aligned =
+            B0EnableLds ? math::integer_least_multiple(
+                              GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
+                                      .GetElementSpaceSize() *
+                                  sizeof(FloatB0),
+                              max_lds_align)
+                        : 0;
+        static constexpr auto b1_block_space_size_aligned =
+            B1EnableLds ? math::integer_least_multiple(
+                              GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
+                                      .GetElementSpaceSize() *
+                                  sizeof(FloatB1),
+                              max_lds_align)
+                        : 0;

         static constexpr auto a_block_space_offset  = 0;
         static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
         static constexpr auto b1_block_space_offset = 0;

         // LDS allocation for reduction
+        // Feature to add: intra-thread reduction
         static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align);
+            math::integer_least_multiple(BlockSize, max_lds_align) * sizeof(FloatAcc0);

         static constexpr auto reduction_space_offset = 0;

         // LDS allocation for C shuffle in LDS
-        static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
-            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
-        static constexpr auto c_block_space_size =
-            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
-                .GetElementSpaceSize();
+        static constexpr auto c_block_space_size =
+            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
+                .GetElementSpaceSize() *
+            sizeof(FloatCShuffle);
     };
     template <bool HasMainKBlockLoop,
@@ -520,12 +590,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         /*******************************************************************************/
         // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto K = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
+            }
+            else
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                       a_grid_desc.GetLength(I5);
+            }
+        }();

-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        constexpr auto a_block_desc = MakeABlockDescriptor();
         constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
+        auto a_block_trait = [&]() {
             // A matrix blockwise copy
+            if constexpr(AEnableLds)
+            {
+                constexpr auto AK0PerBlock = KPerBlock / AK1;
+                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
+                    SharedMemTrait::a_block_space_size_aligned);
+
                 auto a_blockwise_copy =
                     ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                         /* typename SrcElementwiseOperation, */ AElementwiseOperation,
@@ -537,7 +621,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                         /* typename SrcData, */ FloatA,
                         /* typename DstData, */ FloatA,
                         /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
-                        /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
+                        /* typename DstDesc, */ decltype(a_block_desc),
                         /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
                         /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
                         /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
@@ -551,10 +635,49 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                         a_grid_desc_k0_m_k1,
                         make_multi_index(0, m_block_data_idx_on_grid, 0),
                         a_element_op,
-                        a_block_desc_k0perblock_mperblock_k1,
+                        a_block_desc,
                         make_multi_index(0, 0, 0),
                         ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+            else
+            {
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+                    a_block_desc.GetElementSpaceSize());
+
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto a_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v2<FloatA,
+                                                     FloatA,
+                                                     decltype(a_grid_desc),
+                                                     decltype(a_block_desc),
+                                                     Sequence<Number<KWmmaPerBlock>{},
+                                                              Number<MRepeat>{},
+                                                              I1,
+                                                              I1,
+                                                              I1,
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2, 3, 4, 5>,
+                                                     5,
+                                                     ABlockTransferSrcScalarPerVector,
+                                                     AThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                        a_grid_desc,
+                        make_multi_index(0,
+                                         m_block_data_idx_on_grid / (MWaves * MPerWmma),
+                                         get_thread_local_1d_id() / 32,
+                                         (get_thread_local_1d_id() % 32) / 16,
+                                         get_thread_local_1d_id() % 16,
+                                         0));
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+        };
         // B matrix blockwise copy
         auto b0_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -585,6 +708,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                 make_multi_index(0, 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{});

+        auto a_block_buf      = a_block_trait()[I0];
+        auto a_blockwise_copy = a_block_trait()[I1];
+
         /*******************************************************************************/
         // Gemm0
         constexpr auto WmmaK = 16;
@@ -595,7 +721,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                 FloatA,
                 FloatB0,
                 FloatAcc0,
-                decltype(MakeA0BlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
+                decltype(MakeAWaveDescriptor(a_block_desc)),
                 decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
                 MPerBlock,
                 LPerBlock,
/*******************************************************************************/ /*******************************************************************************/
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset, auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize()); b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());
// Shift Per SUB_K // Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto a_block_reset_copy_step = [&](){
if constexpr(AEnableLds){
return make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
else{
return make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0, 0, 0, 0);
}
}();
const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0); const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
         /*******************************************************************************/
         // softmax
         /*******************************************************************************/
@@ -734,7 +867,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                 // 0x76543210 0xfedcba98
                 // src Rowlane
                 0x76543210, 0xfedcba98,
-                false>{tensor_operation::element_wise::PassThrough{}};
+                false>{};
         // B1 matrix blockwise copy
         auto b1_blockwise_copy =
@@ -815,7 +948,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         }

         // gemm0 start, A-B swapped
         GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
-                                                          a_block_desc_k0perblock_mperblock_k1,
+                                                          a_block_desc,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
                                                           a_block_buf,
@@ -828,7 +961,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                           b0_block_slice_copy_step,
                                                           blockwise_gemm0,
                                                           acc0_thread_buf,
-                                                          K0BlockMainLoop);
+                                                          KBlockMainLoop);

         // do MNK padding or upper triangular masking
         if constexpr(MaskOutUpperTriangle || PadN)
         {
...
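Note: the `AEnableLds == false` path above streams A fragments straight into VGPRs with one ThreadwiseTensorSliceTransfer_v2 per lane. A simplified sketch of the lane-to-coordinate mapping its start index encodes; the wave32 size and the 16-lane WMMA row split are assumptions carried over from the divisions in the code above:

```cpp
#include <cstdint>

// Illustrative decomposition of a flat thread id into the coordinates
// used for the VGPR-direct A load (wave, KRow half, lane within a row).
struct ALaneCoord
{
    int32_t wave;        // which wave in the workgroup
    int32_t k_row;       // upper/lower 16-lane half of the wave
    int32_t lane_in_row; // position inside the 16-wide WMMA row
};

constexpr ALaneCoord a_lane_coord(int32_t tid)
{
    return {tid / 32,        // assumes wave32
            (tid % 32) / 16, // two 16-lane rows per wave
            tid % 16};
}

static_assert(a_lane_coord(53).wave == 1 && a_lane_coord(53).k_row == 1 &&
                  a_lane_coord(53).lane_in_row == 5,
              "53 = 1 * 32 + 1 * 16 + 5");
```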
@@ -343,7 +343,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

-            a_block_buf = a_block_buf_switch;
+            // a_block_buf = a_block_buf_switch;
             ++i;
         } while(i < (num_loop - 1));
     }
...
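Note: the pipeline selector now takes the per-operand LDS flags next to the pipeline version, and the `<1, false, true>` specialization touched here corresponds to the A-in-VGPR / B-in-LDS variant. A minimal sketch of that dispatch shape with stand-in types; this is not the real CK selector, only an illustration of the pattern:

```cpp
// Stand-in for the real pipeline structs; only the flags matter here.
template <int NumPrefetch, bool AEnableLds, bool BEnableLds>
struct PipelineV1
{
    static constexpr bool a_through_lds = AEnableLds;
    static constexpr bool b_through_lds = BEnableLds;
};

// Hypothetical selector: version + LDS flags pick a concrete pipeline type.
template <int PipelineVer, bool AEnableLds, bool BEnableLds, int NumPrefetch>
constexpr auto pipeline_selector()
{
    static_assert(PipelineVer == 1, "only v1 is sketched here");
    return PipelineV1<NumPrefetch, AEnableLds, BEnableLds>{};
}

// The specialization patched above would come from:
using APassthroughPipe = decltype(pipeline_selector<1, false, true, 1>());
static_assert(!APassthroughPipe::a_through_lds && APassthroughPipe::b_through_lds,
              "A bypasses LDS, B stages through LDS");
```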
@@ -130,8 +130,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

-    static constexpr auto B_K0 = BGridDesc_K0_N_K1{}.GetLength(I0);
-    static constexpr auto B_K1 = BGridDesc_K0_N_K1{}.GetLength(I2);
-
     // FIX ME: To be deprecated
     static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&) MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
{ {
constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
return transform_tensor_descriptor( return transform_tensor_descriptor(
BBlockDesc_BK0_N_BK1{}, BBlockDesc_BK0_N_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}), make_tuple(make_pass_through_transform(Number<B_K0>{}),
@@ -528,8 +529,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             }
         }();

-        // printf("---------------K = %d\n", K);
-
         constexpr auto a_block_desc = MakeABlockDescriptor();
         constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
/*******************************************************************************/ /*******************************************************************************/
// Shift Per SUB_K // Shift Per SUB_K
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
// printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline // gridwise GEMM pipeline
......
@@ -1395,34 +1395,28 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
             // apply element-wise operation
             element_op_(v_this_row, src_buf[Number<src_offset>{}]);

-            // if (get_thread_local_1d_id() < 16)
-            // printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
-            // *(reinterpret_cast<uint16_t*>(&v_this_row)) );
-
+            // apply intra-row swizzle permute
             if constexpr(IntraRowSwizzlePerm)
             {
-                temp = __builtin_amdgcn_permlane16( // 0x76543210, 0xfedcba98
-                    temp,
-                    type_convert<int>(v_this_row),
-                    0xb3a29180,
-                    0xf7e6d5c4,
-                    1,
-                    0);
+                // temp = __builtin_amdgcn_permlane16(
+                //     temp,
+                //     type_convert<int>(v_this_row),
+                //     0xb3a29180,
+                //     0xf7e6d5c4,
+                //     1,
+                //     0);
                 v_this_row = type_convert<SrcData>(temp);
-                // if (get_thread_local_1d_id() < 16)
-                // printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
-                // *(reinterpret_cast<uint16_t*>(&v_this_row)) );
             }

             // apply inter-row permute.
-            temp = __builtin_amdgcn_permlanex16(temp,
-                                                type_convert<int>(v_this_row),
-                                                LowEightRowlaneIdx,
-                                                HighEightRowLaneIdx,
-                                                1,
-                                                0);
+            // temp = __builtin_amdgcn_permlanex16(temp,
+            //                                     type_convert<int>(v_this_row),
+            //                                     LowEightRowlaneIdx,
+            //                                     HighEightRowLaneIdx,
+            //                                     1,
+            //                                     0);
             v_theother_row = type_convert<SrcData>(temp);
-            // printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
-            // *(reinterpret_cast<uint16_t*>(&v_theother_row)) );

             if(get_thread_local_1d_id() % 32 < 16)
             {
                 // apply type convert
...
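Note: with both permlane calls commented out, only the type-convert path stays active here. For reference, a hedged sketch of the exchange the disabled permlanex16 call performed; the lane-selector semantics are paraphrased from the AMDGPU builtin, and the constants are the ones threaded in as LowEightRowlaneIdx/HighEightRowLaneIdx above:

```cpp
// __builtin_amdgcn_permlanex16(old, src, sel0, sel1, fi, bound_ctrl) lets each
// lane of a 16-lane row read src from a lane of the *other* row half; sel0 and
// sel1 each pack eight 4-bit lane selectors. With 0x76543210 / 0xfedcba98 the
// mapping is the identity within a row, so the two 16-lane halves of a wave32
// simply swap their values.
__device__ int exchange_row_halves(int old_val, int v_this_row)
{
    return __builtin_amdgcn_permlanex16(old_val, v_this_row,
                                        0x76543210, 0xfedcba98,
                                        /*fi=*/1, /*bound_ctrl=*/0);
}
```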
@@ -179,6 +179,26 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

+    template <typename AGridDesc_M_K, typename Number>
+    __host__ __device__ static constexpr auto
+    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+        const AGridDesc_M_K& a_grid_desc_m_k, const Number& WmmaK, const Number& MRepeat,
+        const Number& MWaves, const Number& MPerWmma, const Number& AK1)
+    {
+        const auto M0     = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
+        const auto K      = a_grid_desc_m_k.GetLength(I1);
+        const auto AKWmma = K / WmmaK;
+        constexpr auto AKRow = WmmaK / AK1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(AKWmma, Number<AKRow>{}, AK1)),
+                       make_unmerge_transform(make_tuple(M0 * MRepeat, MWaves, MPerWmma))),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+    }
+
     //
     // B (alias of B0)
     //
...
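Note: the new A grid descriptor unmerges (M, K) into (AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1) so the LDS-bypass path can address WMMA fragments directly. A worked size check with this commit's example values (KPerBlock = 32, WmmaK = 16, K1 = 8, MPerWmma = 16); purely illustrative:

```cpp
// K axis: K = AKWmma * AKRow * AK1, with AKRow = WmmaK / AK1.
constexpr int KPerBlock = 32, WmmaK = 16, AK1 = 8;
constexpr int AKWmma = KPerBlock / WmmaK; // = 2
constexpr int AKRow  = WmmaK / AK1;       // = 2
static_assert(AKWmma * AKRow * AK1 == KPerBlock, "K unmerge is exact");

// M axis: M = (M0 * MRepeat) * MWaves * MPerWmma.
constexpr int MPerBlock = 128, MRepeat = 1, MWaves = 8, MPerWmma = 16;
static_assert(MRepeat * MWaves * MPerWmma == MPerBlock,
              "one block repeat covers MPerBlock for these values");
```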