Commit 489599ba authored by Jing Zhang, committed by root

add multiD support into gridwise and deviceOp

parent ad1597c4
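This commit threads a tuple of auxiliary "D" tensors (DsDataType) through the device operator down to the gridwise GEMM epilogue, so the C elementwise operation can consume extra global inputs. As a rough sketch of the resulting call shape, a hypothetical instance with one D tensor is shown below; DDataType and AddBias are assumed names, not part of this commit, and note that at this point in the series the device op still instantiates the gridwise kernel with Tuple<>, so a non-empty Ds pack is illustrative only:

// Hypothetical usage sketch; DDataType and AddBias are assumptions.
using DeviceGemmWithBias = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, ck::Tuple<DDataType>, CDataType, AccDataType, CShuffleDataType,
    PassThrough, PassThrough, AddBias, GemmDefault,
    /* ...blockwise tuning parameters as in the existing instance... */>;

auto gemm     = DeviceGemmWithBias{};
auto argument = gemm.MakeArgument(p_a, p_b,
                                  {p_d0},      // std::array<const void*, 1>
                                  p_c,
                                  M, N, K,
                                  StrideA, StrideB,
                                  {StrideD0},  // std::array<ck::index_t, 1>
                                  StrideC, KBatch,
                                  a_element_op, b_element_op, c_element_op);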
@@ -25,7 +25,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using DeviceGemmV2Instance =
     ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
         ALayout, BLayout, CLayout,
-        ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
+        ADataType, BDataType, ck::Tuple<>, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
         256,
         224, 256,
...
@@ -133,10 +133,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 #ifdef BUILD_INT4_EXAMPLE
         static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+        {},
         static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
 #else
         static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+        {},
         static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
 #endif
         M,
@@ -144,6 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         K,
         StrideA,
         StrideB,
+        {},
         StrideC,
         KBatch,
         a_element_op,
...
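The two new `{}` arguments compile because this instance sets DsDataType = ck::Tuple<>, which makes NumDTensor zero and turns both new parameters into empty std::arrays. A minimal standalone illustration of why that works (std::tuple standing in for ck::Tuple):

#include <array>
#include <cstddef>
#include <tuple>

// NumDTensor mirrors DsDataType::Size(); for an empty tuple it is 0,
// so `{}` value-initializes zero-length arrays at the call site.
constexpr std::size_t NumDTensor = std::tuple_size_v<std::tuple<>>; // == 0

void make_argument(std::array<const void*, NumDTensor> p_ds,
                   std::array<int, NumDTensor> stride_ds)
{
    (void)p_ds;
    (void)stride_ds;
}

int main() { make_argument({}, {}); }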
@@ -14,21 +14,26 @@ template <typename ALayout,
           typename CLayout,
           typename ADataType,
           typename BDataType,
+          typename DsDataType,
           typename CDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGemmV2 : public BaseOperator
 {
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
                         void* p_c,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
                         ck::index_t StrideC,
                         ck::index_t KSplit,
                         AElementwiseOperation a_element_op,
...
@@ -25,6 +25,7 @@ template <typename ALayout,
           typename CLayout,
           typename ADataType,
           typename BDataType,
+          typename DsDataType,
           typename CDataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
@@ -69,11 +70,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                        CLayout,
                                                        ADataType,
                                                        BDataType,
+                                                       DsDataType,
                                                        CDataType,
                                                        AElementwiseOperation,
                                                        BElementwiseOperation,
                                                        CElementwiseOperation>
 {
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
         ALayout,
@@ -83,6 +87,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         BDataType,
         GemmAccDataType,
         CShuffleDataType,
+        Tuple<>,
         CDataType,
         AElementwiseOperation,
         BElementwiseOperation,
@@ -586,19 +591,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
+                             std::array<const void*, NumDTensor> p_ds,
                              CDataType* p_c,
                              index_t M,
                              index_t N,
                              index_t K,
                              index_t StrideA,
                              index_t StrideB,
+                             std::array<index_t, NumDTensor> StrideDs,
                              index_t StrideC,
                              index_t KBatch,
-                             AElementwiseOperation,
-                             BElementwiseOperation,
-                             CElementwiseOperation)
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
+        return Argument{p_a,
+                        p_b,
+                        p_ds,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideDs,
+                        StrideC,
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -606,28 +627,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
     // polymorphic
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
+                                                      std::array<const void*, NumDTensor> p_ds,
                                                       void* p_c,
                                                       index_t M,
                                                       index_t N,
                                                       index_t K,
                                                       index_t StrideA,
                                                       index_t StrideB,
+                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                       index_t StrideC,
                                                       index_t KBatch,
-                                                      AElementwiseOperation,
-                                                      BElementwiseOperation,
-                                                      CElementwiseOperation) override
+                                                      AElementwiseOperation a_element_op,
+                                                      BElementwiseOperation b_element_op,
+                                                      CElementwiseOperation c_element_op) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
+                                          p_ds,
                                           static_cast<CDataType*>(p_c),
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
+                                          StrideDs,
                                           StrideC,
-                                          KBatch);
+                                          KBatch,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
     }

     // polymorphic
...
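Because the Ds arrays are now part of the virtual MakeArgumentPointer signature in DeviceGemmV2, type-erased client code sees the change as well. A hedged sketch, reusing names from the example file above (gemm, the device buffers, sizes, and strides are assumed from that context):

// Sketch only: with DsDataType = ck::Tuple<>, both array arguments stay empty.
std::unique_ptr<ck::tensor_operation::device::BaseArgument> argument_ptr =
    gemm.MakeArgumentPointer(a_m_k_device_buf.GetDeviceBuffer(),
                             b_k_n_device_buf.GetDeviceBuffer(),
                             {},              // p_ds: no D tensors
                             c_m_n_device_buf.GetDeviceBuffer(),
                             M, N, K,
                             StrideA, StrideB,
                             {},              // StrideDs likewise empty
                             StrideC, KBatch,
                             a_element_op, b_element_op, c_element_op);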
@@ -14,6 +14,8 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
+
 namespace ck {

 // Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
@@ -43,6 +45,7 @@ __global__ void
         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+            karg.p_ds_grid,
             karg.p_c_grid,
             p_shared,
             karg);
@@ -75,6 +78,7 @@ __global__ void
         GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+            karg.p_ds_grid,
             karg.p_c_grid,
             p_shared_0,
             p_shared_1,
@@ -91,6 +95,7 @@ template <typename ALayout,
           typename BDataType,
           typename AccDataType,
           typename CShuffleDataType,
+          typename DsDataType,
           typename CDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -147,6 +152,21 @@ struct GridwiseGemm_xdl_cshuffle_v3
     static constexpr auto AK1Number = Number<AK1Value>{};
     static constexpr auto BK1Number = Number<BK1Value>{};

+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    static constexpr auto MakeDsGridPointer()
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+
+                return static_cast<const DDataType*>(nullptr);
+            },
+            Number<NumDTensor>{});
+    }
+
+    using DsGridPointer = decltype(MakeDsGridPointer());
+
     static constexpr index_t KPack =
         math::max(math::lcm(AK1Number, BK1Number),
                   MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
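MakeDsGridPointer materializes a tuple whose i-th element is a null `const DDataType*` for the i-th entry of DsDataType; Argument default-constructs this tuple and fills it in later. A standalone sketch of the same pattern in standard C++ (std::tuple in place of ck::Tuple/generate_tuple):

#include <tuple>
#include <type_traits>

// Build a tuple of typed null pointers from a pack of D data types; the decltype
// of this call plays the role of DsGridPointer in the gridwise struct.
template <typename... DsDataType>
constexpr auto make_ds_grid_pointer()
{
    return std::tuple<const DsDataType*...>{}; // every element starts as nullptr
}

// e.g. with two hypothetical D tensors of different precision:
using DsGridPointer = decltype(make_ds_grid_pointer<float, double>());
static_assert(std::is_same_v<DsGridPointer, std::tuple<const float*, const double*>>);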
@@ -458,6 +478,28 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
     }

+    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
+        index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
+    {
+        return generate_tuple(
+            [&](auto i) { return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); },
+            Number<NumDTensor>{});
+    }
+
+    template <typename DsGridDesc>
+    __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    ds_grid_desc_m_n[i], MBlock, NBlock);
+            },
+            Number<NumDTensor>{});
+    }
+
+    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
+
     struct Problem
     {
         __host__ Problem(index_t M_,
@@ -465,6 +507,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
+                         std::array<index_t, NumDTensor> StrideDs_,
                          index_t StrideC_,
                          index_t KBatch_)
             : M{M_},
@@ -472,6 +515,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
               K{K_},
               StrideA{StrideA_},
               StrideB{StrideB_},
+              StrideDs{StrideDs_},
               StrideC{StrideC_},
               KBatch{KBatch_},
               MPadded{CalculateMPadded(M_)},
@@ -509,6 +553,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         index_t K;
         index_t StrideA;
         index_t StrideB;
+        std::array<index_t, NumDTensor> StrideDs;
         index_t StrideC;
         index_t KBatch;
         index_t MPadded;
@@ -526,24 +571,46 @@ struct GridwiseGemm_xdl_cshuffle_v3
     {
         __host__ Argument(const ADataType* p_a_grid_,
                           const BDataType* p_b_grid_,
+                          std::array<const void*, NumDTensor> p_ds_grid_,
                           CDataType* p_c_grid_,
                           index_t M_,
                           index_t N_,
                           index_t K_,
                           index_t StrideA_,
                           index_t StrideB_,
+                          std::array<index_t, NumDTensor> StrideDs_,
                           index_t StrideC_,
-                          index_t k_batch_)
-            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
+                          index_t k_batch_,
+                          AElementwiseOperation a_element_op_,
+                          BElementwiseOperation b_element_op_,
+                          CElementwiseOperation c_element_op_)
+            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
-             p_c_grid{p_c_grid_}
+             p_ds_grid{},
+             p_c_grid{p_c_grid_},
+             a_element_op{a_element_op_},
+             b_element_op{b_element_op_},
+             c_element_op{c_element_op_}
         {
+            // populate pointer, desc for Ds
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+
+                // D pointer
+                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
+            });
         }

         const ADataType* p_a_grid;
         const BDataType* p_b_grid;
+        DsGridPointer p_ds_grid;
         CDataType* p_c_grid;
+
+        const AElementwiseOperation a_element_op;
+        const BElementwiseOperation b_element_op;
+        const CElementwiseOperation c_element_op;
     };

     struct SplitKBatchOffset
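The Argument constructor receives the D pointers type-erased (std::array<const void*, NumDTensor>) and casts each one back to its statically known element type. The same idea in standard C++, using an index_sequence where the kernel code uses static_for:

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Cast a type-erased pointer array to a tuple of typed pointers, one cast per D tensor.
template <typename... Ds, std::size_t... Is>
auto typed_ds_pointers_impl(const std::array<const void*, sizeof...(Ds)>& p_ds,
                            std::index_sequence<Is...>)
{
    return std::tuple<const Ds*...>{static_cast<const Ds*>(p_ds[Is])...};
}

template <typename... Ds>
auto typed_ds_pointers(const std::array<const void*, sizeof...(Ds)>& p_ds)
{
    return typed_ds_pointers_impl<Ds...>(p_ds, std::index_sequence_for<Ds...>{});
}

int main()
{
    float  d0 = 0.f;
    double d1 = 0.0;
    auto   ds = typed_ds_pointers<float, double>({&d0, &d1});
    return std::get<0>(ds) == &d0 ? 0 : 1;
}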
@@ -1133,6 +1200,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
               TailNumber TailNum = TailNumber::Odd>
     __device__ static void Run(const ADataType* p_a_grid,
                                const BDataType* p_b_grid,
+                               DsGridPointer& p_ds_grid,
                                CDataType* p_c_grid,
                                void* p_shared,
                                const Problem& problem)
@@ -1407,6 +1475,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                     n_thread_data_on_block_idx[I2]),
                 ck::tensor_operation::element_wise::PassThrough{}};

+#if 0
             // shuffle: blockwise copy C from LDS to global
             auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                 ThisThreadBlock, // ThreadGroup
@@ -1433,6 +1502,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 c_element_op};

+            // space filling curve for threadwise C in VGPR
             constexpr auto sfc_c_vgpr =
                 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -1489,6 +1559,156 @@ struct GridwiseGemm_xdl_cshuffle_v3
                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                 }
             });
+#else
+            using EDataType = CDataType;
+
+            const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
+                problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
+
+            const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
+                MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
+
+            const auto ds_grid_buf = generate_tuple(
+                [&](auto i) {
+                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
+                },
+                Number<NumDTensor>{});
+
+            // tuple of reference to C/Ds tensor descriptors
+            const auto c_ds_desc_refs = concat_tuple_of_reference(
+                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+                generate_tie(
+                    [&](auto i) -> const auto& // return type should be reference
+                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                    Number<NumDTensor>{}));
+
+            // tuple of reference to C/Ds buffers
+            const auto c_ds_buf_refs = concat_tuple_of_reference(
+                tie(c_shuffle_block_buf),
+                generate_tie(
+                    [&](auto i) -> const auto& // return type should be reference
+                    { return ds_grid_buf[i]; },
+                    Number<NumDTensor>{}));
+
+            // tuple of starting index of C/Ds blockwise copy
+            const auto idx_c_ds_block_begin = container_concat(
+                make_tuple(make_multi_index(0, 0, 0, 0)),
+                generate_tuple(
+                    [&](auto) {
+                        return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
+                    },
+                    Number<NumDTensor>{}));
+
+            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+                c_grid_desc_mblock_mperblock_nblock_nperblock;
+
+            using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
+                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
+
+            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
+
+            const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
+                CShuffleBlockTransferScalarPerVector_NPerBlock;
+
+            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
+                ThisThreadBlock,
+                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+                Tuple<EDataType>,
+                decltype(c_ds_desc_refs),
+                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+                CElementwiseOperation,
+                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                                            // support arbitrary type
+                Sequence<1,
+                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                         1,
+                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths
+                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder
+                Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder
+                Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder
+                3,                    // index_t SrcVectorDim
+                3,                    // index_t DstVectorDim
+                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                sequence_merge_t<
+                    Sequence<true>,
+                    uniform_sequence_gen_t<NumDTensor,
+                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
+                {c_ds_desc_refs,
+                 idx_c_ds_block_begin,
+                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
+                 c_element_op};
+
+            // space filling curve for threadwise C in VGPR
+            constexpr auto sfc_c_vgpr =
+                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                  Sequence<CShuffleMXdlPerWavePerShuffle,
+                                           CShuffleNXdlPerWavePerShuffle,
+                                           1,
+                                           1,
+                                           M2,
+                                           1,
+                                           M4,
+                                           1>>{};
+
+            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
+
+            // space filling curve for shuffled blockwise C/D/E
+            constexpr auto sfc_cde_block =
+                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
+                                  Sequence<0, 2, 1, 3>,
+                                  Sequence<1,
+                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                           1,
+                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
+
+            static_for<0, num_access, 1>{}([&](auto access_id) {
+                // make sure it's safe to write to LDS
+                block_sync_lds();
+
+                // each thread writes its data from VGPR to LDS
+                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                              c_thread_buf,
+                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              c_shuffle_block_buf);
+
+                // make sure it's safe to read from LDS
+                block_sync_lds();
+
+                // each block copies its data from LDS to global
+                cde_block_copy_lds_and_global.Run(
+                    c_ds_desc_refs,
+                    c_ds_buf_refs,
+                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                    tie(c_grid_buf));
+
+                if constexpr(access_id < num_access - 1)
+                {
+                    constexpr auto cde_lds_and_global_step =
+                        sfc_cde_block.GetForwardStep(access_id);
+
+                    // move on Ds
+                    static_for<0, NumDTensor, 1>{}([&](auto i) {
+                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                    });
+
+                    // move on E
+                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        I0,
+                        cde_lds_and_global_step);
+                }
+            });
+#endif
         }
     }
@@ -1497,6 +1717,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
               TailNumber TailNum = TailNumber::Odd>
     __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                     const BDataType* p_b_grid,
+                                    DsGridPointer& p_ds_grid,
                                     CDataType* p_c_grid,
                                     void* p_shared_0,
                                     void* p_shared_1,
@@ -1782,6 +2003,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                     n_thread_data_on_block_idx[I2]),
                 ck::tensor_operation::element_wise::PassThrough{}};

+#if 0
             // shuffle: blockwise copy C from LDS to global
             auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                 ThisThreadBlock, // ThreadGroup
@@ -1808,6 +2030,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 c_element_op};

+            // space filling curve for threadwise C in VGPR
             constexpr auto sfc_c_vgpr =
                 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -1864,6 +2087,156 @@ struct GridwiseGemm_xdl_cshuffle_v3
                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                 }
             });
+#else
+            using EDataType = CDataType;
+
+            const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
+                problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
+
+            const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
+                MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
+
+            const auto ds_grid_buf = generate_tuple(
+                [&](auto i) {
+                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
+                },
+                Number<NumDTensor>{});
+
+            // tuple of reference to C/Ds tensor descriptors
+            const auto c_ds_desc_refs = concat_tuple_of_reference(
+                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+                generate_tie(
+                    [&](auto i) -> const auto& // return type should be reference
+                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                    Number<NumDTensor>{}));
+
+            // tuple of reference to C/Ds buffers
+            const auto c_ds_buf_refs = concat_tuple_of_reference(
+                tie(c_shuffle_block_buf),
+                generate_tie(
+                    [&](auto i) -> const auto& // return type should be reference
+                    { return ds_grid_buf[i]; },
+                    Number<NumDTensor>{}));
+
+            // tuple of starting index of C/Ds blockwise copy
+            const auto idx_c_ds_block_begin = container_concat(
+                make_tuple(make_multi_index(0, 0, 0, 0)),
+                generate_tuple(
+                    [&](auto) {
+                        return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
+                    },
+                    Number<NumDTensor>{}));
+
+            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+                c_grid_desc_mblock_mperblock_nblock_nperblock;
+
+            using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
+                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
+
+            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
+
+            const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
+                CShuffleBlockTransferScalarPerVector_NPerBlock;
+
+            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
+                ThisThreadBlock,
+                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+                Tuple<EDataType>,
+                decltype(c_ds_desc_refs),
+                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+                CElementwiseOperation,
+                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                                            // support arbitrary type
+                Sequence<1,
+                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                         1,
+                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths
+                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder
+                Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder
+                Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder
+                3,                    // index_t SrcVectorDim
+                3,                    // index_t DstVectorDim
+                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                sequence_merge_t<
+                    Sequence<true>,
+                    uniform_sequence_gen_t<NumDTensor,
+                                           false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
+                {c_ds_desc_refs,
+                 idx_c_ds_block_begin,
+                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
+                 c_element_op};
+
+            // space filling curve for threadwise C in VGPR
+            constexpr auto sfc_c_vgpr =
+                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                  Sequence<CShuffleMXdlPerWavePerShuffle,
+                                           CShuffleNXdlPerWavePerShuffle,
+                                           1,
+                                           1,
+                                           M2,
+                                           1,
+                                           M4,
+                                           1>>{};
+
+            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
+
+            // space filling curve for shuffled blockwise C/D/E
+            constexpr auto sfc_cde_block =
+                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
+                                  Sequence<0, 2, 1, 3>,
+                                  Sequence<1,
+                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                           1,
+                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
+
+            static_for<0, num_access, 1>{}([&](auto access_id) {
+                // make sure it's safe to write to LDS
+                block_sync_lds();
+
+                // each thread writes its data from VGPR to LDS
+                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                              c_thread_buf,
+                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                              c_shuffle_block_buf);
+
+                // make sure it's safe to read from LDS
+                block_sync_lds();
+
+                // each block copies its data from LDS to global
+                cde_block_copy_lds_and_global.Run(
+                    c_ds_desc_refs,
+                    c_ds_buf_refs,
+                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                    tie(c_grid_buf));
+
+                if constexpr(access_id < num_access - 1)
+                {
+                    constexpr auto cde_lds_and_global_step =
+                        sfc_cde_block.GetForwardStep(access_id);
+
+                    // move on Ds
+                    static_for<0, NumDTensor, 1>{}([&](auto i) {
+                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                    });
+
+                    // move on E
+                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        I0,
+                        cde_lds_and_global_step);
+                }
+            });
+#endif
         }
     }
 };
...