Commit 840a617d authored by Wenkai

use static kernel

parent 04da3554
...@@ -12,6 +12,7 @@
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_splitk_c_shuffle.hpp"
#include "device_gemm_xdl_splitk_c_shuffle_static.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
...@@ -44,6 +45,19 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#if USEING_STATIC_KERNEL
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffleStatic
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 3, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 2, S<1, 4, 1, 64>, 2>;
<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 2, S<1, 4, 1, 64>, 2>;
//<Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 16, 128, 128, 8, 8, 16, 16, 1, 2, S<1, 16, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 2, S<1, 4, 1, 64>, 2>;
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 4, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 2, S<1, 4, 1, 64>, 2>;
#else
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...@@ -52,9 +66,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 3, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 1, S<1, 16, 1, 16>, 2>;
//<Row, Row, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 4, 256, 16, 128, 32, 8, 2, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 8, 1, 1, S<1, 16, 1, 16>, 2>;
// clang-format on
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
...
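For orientation, a minimal host-side sketch (not part of this commit; the buffer objects and the M/N/K/stride/KBatch values are illustrative) of how the DeviceGemmInstance selected above is typically driven, using only the MakeArgument/MakeInvoker/IsSupportedArgument entry points that the new device op provides:

// Hypothetical driver snippet; a_device_buf/b_device_buf/c_device_buf are assumed DeviceMem buffers.
auto gemm    = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// With USEING_STATIC_KERNEL enabled, the problem shape is fixed by the M_matrix/N_matrix/K_matrix/
// K_batch macros, so the runtime M/N/K/stride values passed here must match those constants.
auto argument = gemm.MakeArgument(static_cast<const ADataType*>(a_device_buf.GetDeviceBuffer()),
                                  static_cast<const BDataType*>(b_device_buf.GetDeviceBuffer()),
                                  static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                  M, N, K, StrideA, StrideB, StrideC,
                                  AElementOp{}, BElementOp{}, CElementOp{}, KBatch);
if(!gemm.IsSupportedArgument(argument))
    throw std::runtime_error("wrong! device_gemm does not support this problem");
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});

The new header, device_gemm_xdl_splitk_c_shuffle_static.hpp, follows.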
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "gemm_specialization.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME
#define CK_RUN_KERNEL_AND_TIME 1
#endif
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmXdlSplitKCShuffleStatic
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto IM = Number<M_matrix>{};
static constexpr auto IN = Number<N_matrix>{};
static constexpr auto IK = Number<K_matrix>{};
static constexpr auto IKBatch = Number<K_batch>{};
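// NOTE: M_matrix/N_matrix/K_matrix/K_batch are the compile-time problem-shape macros that this
// commit adds to common_header.hpp (included above), so the whole GEMM shape, and therefore every
// grid descriptor built below, is known at compile time.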
static auto
MakeAGridDescriptor_KBatch_K0_M_K1()
{
static constexpr auto K = IK;
static constexpr auto KBatch = IKBatch;
static constexpr auto M = IM;
static constexpr auto StrideA = IK;
static constexpr auto KPad = Number<math::integer_divide_ceil(IK, KPerBlock * IKBatch) * (KPerBlock * IKBatch)>{};
assert(KPad % (AK1 * KBatch) == 0);
static constexpr auto AK0 = Number<KPad / (AK1 * KBatch)>{};
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, Number<AK1>{})),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
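// Worked example (illustrative): with the macro settings added to common_header.hpp in this commit,
// the first true MNKB_* branch gives M = 16, N = 5120, K = 384, K_batch = 4, and the instance
// selected in the example uses KPerBlock = 32, AK1 = 8. Then
//     KPad = ceil(K / (KPerBlock * KBatch)) * (KPerBlock * KBatch) = ceil(384 / 128) * 128 = 384
//     AK0  = KPad / (AK1 * KBatch) = 384 / 32 = 12
// so the returned descriptor has compile-time lengths (KBatch, AK0, M, AK1) = (4, 12, 16, 8).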
static auto
MakeBGridDescriptor_KBatch_K0_N_K1()
{
static constexpr auto K = IK;
static constexpr auto KBatch = IKBatch;
static constexpr auto N = IN;
static constexpr auto StrideB = IN;
static constexpr auto KPad = Number<math::integer_divide_ceil(IK, KPerBlock * IKBatch) * (KPerBlock * IKBatch)>{};
assert(KPad % (BK1 * KBatch) == 0);
constexpr auto BK0 = Number<KPad / (BK1 * KBatch)>{};
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding)
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, Number<BK1>{})),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
static auto MakeCGridDescriptor_M_N()
{
static constexpr auto M = IM;
static constexpr auto N = IN;
static constexpr auto StrideC = IN;
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if(c_grid_desc_m_n.IsKnownAtCompileTime())
printf("c_grid_desc_m_n yes\n");
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
}
static auto GetKPad(index_t K, index_t KBatch)
{
const index_t KPad = math::integer_divide_ceil(K, KPerBlock * KBatch) * (KPerBlock * KBatch);
return KPad;
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1());
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1());
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N());
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
NumGemmKPrefetchStage,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
NumGemmKPrefetchStage,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock_Static(CGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t,
index_t,
index_t,
index_t,
index_t,
index_t,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t k_batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
k_batch_{k_batch}
{
//int KPad = DeviceGemmXdlSplitKCShuffleStatic::GetKPad(K, k_batch_);
a_grid_desc_kbatch_k0_m_k1_ =
DeviceGemmXdlSplitKCShuffleStatic::MakeAGridDescriptor_KBatch_K0_M_K1();
b_grid_desc_kbatch_k0_n_k1_ =
DeviceGemmXdlSplitKCShuffleStatic::MakeBGridDescriptor_KBatch_K0_N_K1();
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffleStatic::MakeCGridDescriptor_M_N();
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptorStatic(c_grid_desc_m_n_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock_Static(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t k_batch_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmXdlSplitKCShuffleStatic::Argument;
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) * arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K);
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
if(arg.a_grid_desc_kbatch_k0_m_k1_.IsKnownAtCompileTime())
printf("a_grid_desc_kbatch_k0_m_k1_ known at compile time\n");
if(arg.b_grid_desc_kbatch_k0_n_k1_.IsKnownAtCompileTime())
printf("b_grid_desc_kbatch_k0_n_k1_ known at compile time\n");
if(arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.IsKnownAtCompileTime())
printf("c_grid_desc_mblock_mperblock_nblock_nperblock_ known at compile time\n");
//if(arg.block_2_ctile_map_.IsKnownAtCompileTime())
// printf("block_2_ctile_map_ known at compile time\n");
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::Block2CTileMap>,
true>;
Run(kernel);
}
}
else
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distinguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdlSplitKCShuffleStatic::Block2CTileMap>,
false>;
Run(kernel);
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op,
KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op,
KBatch);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdlSplitKCShuffleStatic"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -239,6 +239,79 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
    CGridDesc_M_N c_grid_desc_m_n_;
};
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static() = default;
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static(const CGridDesc_M_N& c_grid_desc_m_n)
: c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0 * KSplit_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
#if 1
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
#else
const index_t idx_ksplit = block_1d_id % KSplit_;
block_1d_id = block_1d_id / KSplit_;
#endif
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_ksplit,
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
static constexpr auto M01_ = Number<8>{};
static constexpr auto KSplit_ = Number<K_batch>{};
CGridDesc_M_N c_grid_desc_m_n_;
};
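// Worked example (illustrative, using the shape that is active with this commit's macro settings:
// M = 16, N = 5120, K_batch = 4, and the example tile sizes MPerBlock = 16, NPerBlock = 128):
//     M0 = 1, N0 = 40, so CalculateGridSize() returns 1 * 40 * 4 = 160 blocks.
// For block_1d_id = 85, CalculateBottomIndex() gives
//     idx_ksplit = 85 / (M0 * N0) = 2, remainder 5  ->  (ksplit, m0, n0) = (2, 0, 5),
// i.e. the third K-split slice of C tile (0, 5); the partial results of the 4 slices are combined
// by the AtomicAdd variant of the gridwise GEMM selected in the device op above when kbatch != 1.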
// Blocks of row-vectors
template <index_t MPerBlock,
          index_t NPerBlock,
...
...@@ -139,12 +139,12 @@ struct GridwiseGemmPipeline_v2<2>
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop >= 3; // was: return num_loop > 2;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop >= 5; // was: return num_loop > 2;
    }

    template <bool HasMainLoop,
...@@ -179,20 +179,31 @@ struct GridwiseGemmPipeline_v2<2>
                               index_t num_loop)
    {
        // global read 0
        // removed by this commit:
        //     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        //     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
        //     // move to 1
        //     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        //     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        //     // global read 1
        //     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
        //     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
        // replaced with:
        static_for<0, 2, 1>{}([&](auto i_pre){
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
            s_nop();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
            s_nop();

            // move to i_pre + 1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        });

        // Initialize C
        c_thread_buf.Clear();

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        // global Read 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);

        // LDS write 0
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
        // global Read 2
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        index_t i = 0;

        // main body
...@@ -200,87 +211,92 @@ struct GridwiseGemmPipeline_v2<2> ...@@ -200,87 +211,92 @@ struct GridwiseGemmPipeline_v2<2>
{ {
do do
{ {
// move to i + 2 static_for<0, 2, 1>{}([&](auto i_main){
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); block_sync_lds();
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// global Read i + 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
// LDS write i
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// global Read i + 2
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 3 // GEMM i
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1 block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
// global read i + 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
// LDS write i + 1 // move to i + 3
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// global read i + 3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
block_sync_lds(); // LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_main + 1) % 2>{});
// global read i + 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<(i_main + 1) % 2>{});
// GEMM i + 1 // LDS write i + 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_main + 1) % 2>{});
// global read i + 3
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<(i_main + 1) % 2>{});
block_sync_lds(); });
i += 2; i += 2;
} while(i < (num_loop - 2)); } while(i < (num_loop - 4));
} }
// tail // tail
if (i > num_loop - 2) if (i == num_loop - 3)
{ {
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1 // LDS write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 1 // GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// tail block_sync_lds();
else if (i == num_loop - 2)
{
// Write num_loop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 2 // GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
block_sync_lds(); // tail
else if (i == num_loop - 4)
{
static_for<0, 4, 1>{}([&](auto i_res){
block_sync_lds();
// LDS write num_loop - 1 // GEMM num_loop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
block_sync_lds(); if constexpr(i_res < 3)
{
block_sync_lds();
if constexpr(i_res < 1)
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 2>{});
if constexpr(i_res < 1)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 2>{});
if constexpr(i_res < 1)
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
}
});
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
} }
...@@ -300,12 +316,12 @@ struct GridwiseGemmPipeline_v2<3>
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop >= 4; // was: return num_loop > 3;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop >= 7; // was: return num_loop > 3;
    }

    template <bool HasMainLoop,
...@@ -342,8 +358,13 @@ struct GridwiseGemmPipeline_v2<3>
        static_for<0, 3, 1>{}([&](auto i_pre){
            // global read i_pre
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
            s_nop();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
            s_nop();

            // move to i_pre + 1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
...@@ -352,6 +373,16 @@ struct GridwiseGemmPipeline_v2<3>
        // Initialize C
        c_thread_buf.Clear();

        // LDS write i_main
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        // global Read i_main + 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);

        // LDS write i_main
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
        // global Read i_main + 3
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        index_t i = 0;

        // main body
...@@ -360,85 +391,118 @@ struct GridwiseGemmPipeline_v2<3> ...@@ -360,85 +391,118 @@ struct GridwiseGemmPipeline_v2<3>
do do
{ {
static_for<0, 3, 1>{}([&](auto i_main){ static_for<0, 3, 1>{}([&](auto i_main){
block_sync_lds();
// LDS write i_main // GEMM i_main
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_main>{}); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// global Read i_main + 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_main>{});
// LDS write i_main block_sync_lds();
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_main>{});
// global Read i_main + 3
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_main>{});
// move to i_main + 3 // move to i_main + 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds(); // LDS write i_main
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_main + 1) % 3>{});
// GEMM i_main // global Read i_main + 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<(i_main + 1) % 3>{});
block_sync_lds(); // LDS write i_main
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_main + 1) % 3>{});
// global Read i_main + 3
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<(i_main + 1) % 3>{});
}); });
i += 3; i += 3;
} while(i < (num_loop - 3)); } while(i < (num_loop - 6));
} }
// tail // tail
if (i == num_loop - 3) if (i == num_loop - 6)
{ {
static_for<0, I3, 1>{}([&](auto i_res){ static_for<0, 6, 1>{}([&](auto i_res){
// Write num_loop - 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 3 // GEMM num_loop - 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
if constexpr(i_res < 5)
{
block_sync_lds(); block_sync_lds();
if constexpr(i_res < 2)
{
// move to i_res + 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
// LDS write i_res
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 3>{});
if constexpr(i_res < 2)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<(i_res + 1) % 3>{});
// LDS write i_res
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 3>{});
if constexpr(i_res < 2)
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<(i_res + 1) % 3>{});
}
}); });
} }
// tail // tail
else if (i == num_loop - 2) else if (i == num_loop - 5)
{ {
static_for<0, I2, 1>{}([&](auto i_res){ static_for<0, 5, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds(); block_sync_lds();
// GEMM num_loop // GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
if constexpr(i_res < 4)
{
block_sync_lds(); block_sync_lds();
if constexpr(i_res < 1)
{
// move to i_res + 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
// LDS write i_res
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 3>{});
if constexpr(i_res < 1)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<(i_res + 1) % 3>{});
// LDS write i_res
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 3>{});
if constexpr(i_res < 1)
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<(i_res + 1) % 3>{});
}
}); });
} }
// tail // tail
else if (i == num_loop - 1) else if (i == num_loop - 4)
{ {
static_for<0, I1, 1>{}([&](auto i_res){ static_for<0, 4, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds(); block_sync_lds();
// GEMM num_loop // GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds(); if constexpr(i_res < 3)
{
block_sync_lds();
// LDS write i_res
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<(i_res + 1) % 3>{});
// LDS write i_res
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<(i_res + 1) % 3>{});
}
}); });
} }
...@@ -501,8 +565,13 @@ struct GridwiseGemmPipeline_v2<4>
        static_for<0, 4, 1>{}([&](auto i_pre){
            // global read i_pre
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
            s_nop();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
            s_nop();

            // move to i_pre + 1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
...
...@@ -47,6 +47,11 @@ __global__ void
    __shared__ FloatAB p_shared_block[shared_block_size];

    //void* kargs_ptr = (&p_a_grid)+0x40;
    //if(get_block_1d_id()==1&&get_thread_local_1d_id()==0)
    //    printf("kargs=0x%p, kargs+64=%d\n", (&p_a_grid), *static_cast<int*>(kargs_ptr));

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
...@@ -124,6 +129,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto IM = Number<M_matrix>{};
    static constexpr auto IN = Number<N_matrix>{};

    // K1 should be Number<...>
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
...@@ -257,6 +265,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
    }
__host__ __device__ static constexpr auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock_Static(const CMNGridDesc& )
{
const auto M = IM;
const auto N = IN;
const auto MBlock = Number<M / MPerBlock>{};
const auto NBlock = Number<N / NPerBlock>{};
return make_naive_tensor_descriptor_packed(make_tuple(MBlock, Number<MPerBlock>{}, NBlock, Number<NPerBlock>{}));
}
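// Worked example (illustrative): with M_matrix = 16, N_matrix = 5120 (the branch active with this
// commit's macro settings) and MPerBlock = 16, NPerBlock = 128, this returns a packed descriptor
// with compile-time lengths (MBlock, MPerBlock, NBlock, NPerBlock) = (1, 16, 40, 128); note that it
// assumes M and N are exact multiples of the block tile, since no padding transform is applied here.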
    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
...@@ -265,6 +285,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            c_m_n_grid_desc, 8, KBatch);
    }
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptorStatic(
const CMNGridDesc& c_m_n_grid_desc)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt_Static<MPerBlock, NPerBlock, CMNGridDesc>(
c_m_n_grid_desc);
}
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
...@@ -278,9 +306,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                   Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
    }
#if USEING_STATIC_KERNEL
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock_Static(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptorStatic(CMNGridDesc{}));
#else
    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
#endif

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
...
...@@ -47,3 +47,32 @@
#ifdef CK_USE_AMD_MFMA
#include "amd_xdlops.hpp"
#endif
#define USEING_STATIC_KERNEL 1
#define MNKB_16_1152_5120_8 0
#define MNKB_16_5120_384_3 1
#define MNKB_16_1280_5120_8 1
#define MNKB_16_5120_1280_5 1
#if MNKB_16_1152_5120_8
#define M_matrix 16
#define N_matrix 1152
#define K_matrix 5120
#define K_batch 8
#elif MNKB_16_5120_384_3
#define M_matrix 16
#define N_matrix 5120
#define K_matrix 384
#define K_batch 4
#elif MNKB_16_1280_5120_8
#define M_matrix 16
#define N_matrix 1280
#define K_matrix 5120
#define K_batch 8
#elif MNKB_16_5120_1280_5
#define M_matrix 16
#define N_matrix 5120
#define K_matrix 1280
#define K_batch 5
#endif
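// The #if/#elif chain keeps only the first branch whose macro is non-zero, so with the values above
// (MNKB_16_1152_5120_8 == 0, MNKB_16_5120_384_3 == 1) the active shape is
// M_matrix = 16, N_matrix = 5120, K_matrix = 384, K_batch = 4.
//
// Illustrative sanity checks, not part of this commit, spelling out what the static descriptors
// assume, given the example instance's tile sizes (MPerBlock = 16, NPerBlock = 128, KPerBlock = 32):
static_assert(M_matrix % 16 == 0 && N_matrix % 128 == 0,
              "the static C-tile grid is built without padding, so M/N must be whole tiles");
static_assert(K_matrix % (32 * K_batch) == 0,
              "KPad equals K_matrix here, so no K padding transform is needed");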