Commit 5dce9c9d authored by wangshaojie6's avatar wangshaojie6

make device/grid level code

parent f0d63f25
@@ -56,28 +56,14 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;

-    auto f_tensor_length_stride_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return {std::vector<std::size_t>({row, col}),
-                        std::vector<std::size_t>({stride, 1})};
-            }
-            else
-            {
-                return {std::vector<std::size_t>({row, col}),
-                        std::vector<std::size_t>({1, stride})};
-            }
-        };
-    std::vector<ck::index_t> a_ms_ks_lengths = f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[0];
-    std::vector<ck::index_t> a_ms_ks_strides = f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[1];
-    std::vector<ck::index_t> b_ns_ks_lengths = f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[0];
-    std::vector<ck::index_t> b_ns_ks_strides = f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[1];
-    std::vector<ck::index_t> d_ms_ns_lengths = f_tensor_length_stride_descriptor(M, N, 0, Row{})[0];
-    std::vector<ck::index_t> d_ms_ns_strides = f_tensor_length_stride_descriptor(M, N, 0, Row{})[1];
-    std::vector<ck::index_t> e_ms_ns_lengths = f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[0];
-    std::vector<ck::index_t> e_ms_ns_strides = f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[1];
+    std::vector<ck::index_t> a_ms_ks_lengths = {M, K};
+    std::vector<ck::index_t> a_ms_ks_strides = {StrideA, 1};
+    std::vector<ck::index_t> b_ns_ks_lengths = {N, K};
+    std::vector<ck::index_t> b_ns_ks_strides = {StrideB, 1};
+    std::vector<ck::index_t> d_ms_ns_lengths = {M, N};
+    std::vector<ck::index_t> d_ms_ns_strides = {0, 1};
+    std::vector<ck::index_t> e_ms_ns_lengths = {M, N};
+    std::vector<ck::index_t> e_ms_ns_strides = {StrideE, 1};

     switch(config.init_method)
     {
@@ -176,7 +162,7 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
     Tensor<CDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

     auto ref_argument = ref_gemm.MakeArgument(
-        a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, c_element_op);
+        a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, cde_element_op);

     ref_invoker.Run(ref_argument);
...
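The example now writes the contraction's length/stride vectors out by hand in row-major form, which is what the removed layout-dispatching lambda produced for these layouts anyway. As a plain illustration (not part of the commit; the helper name is hypothetical), this is how such {lengths, strides} pairs address elements:

#include <cstddef>
#include <vector>

// Hypothetical helper: flat offset of a multi-index given per-dimension strides.
// For a row-major M x K matrix the example passes lengths = {M, K} and
// strides = {StrideA, 1}, so element (i, j) lives at i * StrideA + j.
std::size_t flat_offset(const std::vector<std::size_t>& idx,
                        const std::vector<std::size_t>& strides)
{
    std::size_t offset = 0;
    for(std::size_t d = 0; d < idx.size(); ++d)
        offset += idx[d] * strides[d];
    return offset;
}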
@@ -8,7 +8,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
+#include "ck/tensor_operation/gpu/device/device_contraction_splitK_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/utility/check_err.hpp"
@@ -51,7 +51,7 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

 // clang-format off
 using DeviceOpInstanceKKN = ck::tensor_operation::device::
...
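GemmSpecialization::MNKPadding lets the fixed MPerBlock/NPerBlock/KPerBlock tiling handle arbitrary problem sizes by padding each GEMM dimension up to a tile multiple. A minimal sketch of that rounding; the tile size below is illustrative, not this instance's actual configuration:

#include <cstdint>

// MNKPadding conceptually rounds each GEMM dimension up to its tile multiple.
constexpr std::int64_t pad_up(std::int64_t x, std::int64_t tile)
{
    return ((x + tile - 1) / tile) * tile;
}

// Example: M = 1000 with an (illustrative) MPerBlock of 128 pads to 1024.
static_assert(pad_up(1000, 128) == 1024, "1000 rounds up to 8 tiles of 128");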
@@ -15,12 +15,13 @@
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_splitk_multiple_d_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

 namespace ck {

-template <typename GridwiseGemm,
+template <typename GridwiseGemmAtomicAdd,
           typename FloatAB,
           typename FloatDsPointer,
           typename FloatE,
@@ -57,7 +58,7 @@ __global__ void
         const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared[GridwiseGemmAtomicAdd::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -80,7 +81,7 @@ __global__ void
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+    GridwiseGemmAtomicAdd::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_ds_grid_grp,
                                                   p_e_grid + e_batch_offset,
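The wrapper kernel divides the 1-D grid evenly among the batches: each workgroup derives its batch index from its block id and shifts the A/B/Ds/E pointers by per-batch offsets before invoking the gridwise GEMM. A host-side model of that index math (names are illustrative):

// Host-side model of the device-side mapping: grid_size blocks serve
// batch_count batches, num_blocks_per_batch consecutive blocks per batch.
int batch_of_block(int block_id, int grid_size, int batch_count)
{
    const int num_blocks_per_batch = grid_size / batch_count;
    return block_id / num_blocks_per_batch; // the remainder indexes an E tile
}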
@@ -538,56 +539,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         EGridDesc_G_M_N e_grid_desc_g_m_n_;
     };

-    // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
-        ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
-        CShuffleDataType,
-        DsDataType,
-        EDataType,
-        AElementwiseOperation,
-        BElementwiseOperation,
-        CDEElementwiseOperation,
-        InMemoryDataOperationEnum::Set,
-        AGridDesc_M_K,
-        BGridDesc_N_K,
-        DsGridDesc_M_N,
-        EGridDesc_M_N,
-        NumGemmKPrefetchStage,
-        BlockSize,
-        MPerBlock,
-        NPerBlock,
-        KPerBlock,
-        AK1,
-        BK1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
-        ABlockTransferThreadClusterLengths_AK0_M_AK1,
-        ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder,
-        ABlockTransferSrcVectorDim,
-        ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_AK1,
-        false,
-        ABlockLdsExtraM,
-        BBlockTransferThreadClusterLengths_BK0_N_BK1,
-        BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder,
-        BBlockTransferSrcVectorDim,
-        BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_BK1,
-        false,
-        BBlockLdsExtraN,
-        CShuffleMXdlPerWavePerShuffle,
-        CShuffleNXdlPerWavePerShuffle,
-        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
-
-    // GridwiseGemm
-    using GridwiseGemmAtomicAdd = GridwiseGemmMultipleD_xdl_cshuffle<
+    // GridwiseGemmAtomicAdd
+    using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CShuffleDataType,
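With the Set-path alias gone, the device op reaches E exclusively through atomic adds: split-K partitions the K dimension into KBatch slices, assigns each slice to a different group of workgroups, and lets the partial results meet in global memory. A serial sketch of that reduction (plain C++, not the kernel itself):

#include <vector>

// Split-K in miniature: each k-batch computes a partial dot product over its
// K slice and accumulates into the zero-initialized output, which is the role
// InMemoryDataOperationEnum::AtomicAdd plays on the GPU.
float splitk_dot(const std::vector<float>& a, const std::vector<float>& b, int k_batch)
{
    const int K           = static_cast<int>(a.size());
    const int k_per_batch = (K + k_batch - 1) / k_batch;

    float e = 0.0f; // stands in for the E buffer; must start at zero
    for(int kb = 0; kb < k_batch; ++kb) // each iteration = one k-batch's workgroups
    {
        float partial = 0.0f;
        for(int k = kb * k_per_batch; k < K && k < (kb + 1) * k_per_batch; ++k)
            partial += a[k] * b[k];
        e += partial; // on device: atomic add into E
    }
    return e;
}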
@@ -635,11 +588,11 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         LoopSched>;

     using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+        GridwiseGemmAtomicAdd::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
     using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+        GridwiseGemmAtomicAdd::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;

-    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+    using Block2ETileMap = typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap;

     // Argument
     struct Argument : public BaseArgument
@@ -676,12 +629,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
               e_grid_desc_g_m_n_{
                   DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
               a_grid_desc_ak0_m_ak1_{
-                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+                  GridwiseGemmAtomicAdd::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
               b_grid_desc_bk0_n_bk1_{
-                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+                  GridwiseGemmAtomicAdd::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
               ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
               e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, KBatch)},
+              block_2_etile_map_{GridwiseGemmAtomicAdd::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, 1, 1, KBatch)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
@@ -711,18 +664,18 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
             });

             // populate desc for Ds/E
-            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+            if(GridwiseGemmAtomicAdd::CheckValidity(a_grid_desc_m_k_,
                                            b_grid_desc_n_k_,
                                            ds_grid_desc_m_n_,
                                            e_grid_desc_m_n_,
                                            block_2_etile_map_))
             {
                 e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemmAtomicAdd::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         e_grid_desc_m_n_);

                 ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemmAtomicAdd::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         ds_grid_desc_m_n_);
             }
@@ -753,7 +706,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        typename GridwiseGemm::DsGridPointer p_ds_grid_;
+        typename GridwiseGemmAtomicAdd::DsGridPointer p_ds_grid_;
         EDataType* p_e_grid_;

         // tensor descriptors for problem definition
@@ -768,9 +721,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemmAtomicAdd::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             ds_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemmAtomicAdd::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             e_grid_desc_mblock_mperblock_nblock_nperblock_;

         // block-to-e-tile map
@@ -804,7 +757,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+            if(!GridwiseGemmAtomicAdd::CheckValidity(arg.a_grid_desc_m_k_,
                                             arg.b_grid_desc_n_k_,
                                             arg.ds_grid_desc_m_n_,
                                             arg.e_grid_desc_m_n_,
@@ -826,19 +779,19 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
                 constexpr bool has_main_loop = has_main_k_block_loop.value;

                 const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
-                    GridwiseGemm,
+                    GridwiseGemmAtomicAdd,
                     ADataType, // TODO: distinguish A/B datatype
-                    typename GridwiseGemm::DsGridPointer,
+                    typename GridwiseGemmAtomicAdd::DsGridPointer,
                     EDataType,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CDEElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
-                    typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemmAtomicAdd::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    typename GridwiseGemmAtomicAdd::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     ComputePtrOffsetOfStridedBatch,
-                    typename GridwiseGemm::DefaultBlock2ETileMap,
+                    typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
                     has_main_loop>;

                 return launch_and_time_kernel(stream_config,
@@ -862,7 +815,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
                     arg.block_2_etile_map_);
             };

-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            if(GridwiseGemmAtomicAdd::CalculateHasMainKBlockLoop(K))
             {
                 return launch_kernel(integral_constant<bool, true>{});
             }
@@ -887,7 +840,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
             return false;
         }

-        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+        if(!GridwiseGemmAtomicAdd::CheckValidity(arg.a_grid_desc_m_k_,
                                         arg.b_grid_desc_n_k_,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
...
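Inside Run(), the runtime answer of CalculateHasMainKBlockLoop(K) is lifted into the compile-time HasMainKBlockLoop template parameter before the kernel is instantiated. The dispatch idiom in isolation, as a generic sketch rather than CK's actual helper:

#include <type_traits>

// Lift a runtime bool into a compile-time constant so the callee can be
// instantiated on it, mirroring launch_kernel(integral_constant<bool, ...>{}).
// F must return the same type for both instantiations.
template <typename F>
auto dispatch_bool(bool value, F&& f)
{
    if(value)
        return f(std::integral_constant<bool, true>{});
    return f(std::integral_constant<bool, false>{});
}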
@@ -11,6 +11,7 @@
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -231,17 +232,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             Number<NumDTensor>{});
     }

-    // return block_id to E matrix tile idx (m0, n0) mapping
-    // __host__ __device__ static constexpr auto
-    // MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
-    // {
-    //     return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
-    //         e_grid_desc_m_n);
-    // }

     // return block_id to C matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
-        const EGridDesc_M_N& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch = 1)
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& c_m_n_grid_desc,
+                              index_t /* M01 */,
+                              index_t /* N01 */,
+                              index_t KBatch = 1)
     {
         return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
             c_m_n_grid_desc, 8, KBatch);
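The default tile map now carries KBatch as an extra, slowest-varying coordinate, so the bottom index a block decodes becomes (k_batch, m0, n0) instead of (m0, n0). A deliberately simplified decomposition; the real BlockToCTileMap_KSplit_M00_N0_M01Adapt additionally swizzles m0 through an M01 grouping (hard-coded to 8 above), which this sketch omits:

#include <array>

// Simplified block-id -> (k_batch, m0, n0) decomposition for a KBatch-replicated
// grid of M0 x N0 output tiles; the M01 swizzle of the real adapter is omitted.
std::array<int, 3> block_to_ktile(int block_id, int m0_tiles, int n0_tiles)
{
    const int tiles_per_batch = m0_tiles * n0_tiles;
    const int k_batch         = block_id / tiles_per_batch;
    const int rest            = block_id % tiles_per_batch;
    return {k_batch, rest / n0_tiles, rest % n0_tiles}; // (k_batch, m0, n0)
}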
@@ -263,6 +259,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = a_grid_desc_m_k.GetLength(I1);

+        if(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
+        {
+            return false;
+        }
+
         // check consistency of desc
         if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
         {
@@ -332,7 +333,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
         MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;

     using DefaultBlock2ETileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, KBatch))>;
+        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1, 1, 1))>;

     using DsGridPointer = decltype(MakeDsGridPointer());
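DefaultBlock2ETileMap is recovered by "calling" the factory in an unevaluated decltype context with dummy arguments (1, 1, 1); only the return type is inspected, and nothing executes at runtime. The idiom in isolation, with stand-in types:

#include <type_traits>

struct Desc {};                 // stands in for EGridDesc_M_N
struct TileMap { int kbatch; }; // stands in for the KSplit tile map

TileMap MakeMap(const Desc&, int /*M01*/, int /*N01*/, int KBatch = 1)
{
    return TileMap{KBatch};
}

// decltype only asks for the return type; MakeMap is never actually called.
using DefaultMap =
    std::remove_cv_t<std::remove_reference_t<decltype(MakeMap(Desc{}, 1, 1, 1))>>;
static_assert(std::is_same_v<DefaultMap, TileMap>);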
@@ -378,19 +379,22 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

         if(!block_2_etile_map.ValidCTileIndex(
-               block_work_idx,
+               make_tuple(block_work_idx[I1], block_work_idx[I2]),
                make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                           e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
         {
             return;
         }

+        // k batch id
+        const index_t k_batch_id = block_work_idx[I0];
+
         // HACK: this forces m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

         const index_t n_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

         // lds max alignment
         constexpr auto max_lds_align = math::lcm(AK1, BK1);
@@ -426,7 +430,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             true,
             NumGemmKPrefetchStage>(
                 a_grid_desc_ak0_m_ak1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                make_multi_index(k_batch_id, m_block_data_idx_on_grid, 0),
                 a_element_op,
                 a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
@@ -457,7 +461,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             true,
             NumGemmKPrefetchStage>(
                 b_grid_desc_bk0_n_bk1,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
+                make_multi_index(k_batch_id, n_block_data_idx_on_grid, 0),
                 b_element_op,
                 b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -640,6 +644,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
                     n_thread_data_on_block_idx[I2]),
                 ck::tensor_operation::element_wise::PassThrough{}};

+        // add multiple d at the first atomic add position
+        // if(k_batch_id == 0)
+
         // tuple of reference to C/Ds tensor descriptors
         const auto c_ds_desc_refs = concat_tuple_of_reference(
             tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
@@ -665,12 +672,6 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             },
             Number<NumDTensor>{}));

-        // only do bias at the 1st atomic add position
-        if constexpr(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd)
-        {
-        }
-
         // blockwise copy C/D/E between LDS and global
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
             ThisThreadBlock,
@@ -679,8 +680,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             decltype(c_ds_desc_refs),
             decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
             CDEElementwiseOperation,
-            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                                        // support arbitrary type
+            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                                        // Sequence support
+                                                                        // arbitrary type
             Sequence<1,
                      CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                      1,
@@ -698,9 +700,35 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
             {c_ds_desc_refs,
              idx_c_ds_block_begin,
              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
+             make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
              cde_element_op};

+        // blockwise copy E between LDS and global
+        auto e_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+            ThisThreadBlock,                                 // index_t BlockSize
+            ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation
+            EGlobalMemoryDataOperation,                      // DstInMemOp
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths
+            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder
+            EDataType,            // typename SrcData
+            EDataType,            // typename DstData
+            decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+            decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
+            Sequence<0, 1, 2, 3>,                             // typename DimAccessOrder
+            3,                                                // index_t VectorDim
+            CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector
+            true,  // bool ThreadTransferSrcResetCoordinateAfterRun
+            false> // bool ThreadTransferDstResetCoordinateAfterRun
+            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+             make_multi_index(0, 0, 0, 0),
+             e_grid_desc_mblock_mperblock_nblock_nperblock,
+             make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
+             ck::tensor_operation::element_wise::PassThrough{}};
+
         // space filling curve for threadwise C in VGPR before shuffle
         constexpr auto sfc_c_vgpr =
             SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -741,12 +769,23 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
                 // make sure it's safe to read from LDS
                 block_sync_lds();

-                // each block copies its data from LDS to global
-                cde_block_copy_lds_and_global.Run(
-                    c_ds_desc_refs,
-                    c_ds_buf_refs,
-                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                    tie(e_grid_buf));
+                if(k_batch_id == 0)
+                {
+                    // each block copies its data from LDS to global
+                    cde_block_copy_lds_and_global.Run(
+                        c_ds_desc_refs,
+                        c_ds_buf_refs,
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        tie(e_grid_buf));
+                }
+                else
+                {
+                    e_block_copy_lds_to_global.Run(
+                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                        c_shuffle_block_buf,
+                        e_grid_desc_mblock_mperblock_nblock_nperblock,
+                        e_grid_buf);
+                }

                 if constexpr(access_id < num_access - 1)
                 {
...
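This branch realizes the bias-once scheme noted in the removed comment: workgroups of k-batch 0 run the full C/Ds elementwise path, so the bias tensors are applied exactly once, while every other k-batch atomically adds only its raw partial C into E. A serial model of the resulting invariant, assuming E is zero-filled before launch and the CDE op adds the bias:

#include <cstddef>
#include <vector>

// Serial model of the split-K epilogue: k-batch 0 contributes partial + bias,
// every other k-batch contributes only its partial, so bias lands in E once.
float splitk_epilogue(const std::vector<float>& partials, float bias)
{
    float e = 0.0f; // E must be zero-initialized before the kernel runs
    for(std::size_t kb = 0; kb < partials.size(); ++kb)
    {
        if(kb == 0)
            e += partials[kb] + bias; // CDE path: C + Ds, bias applied here only
        else
            e += partials[kb]; // E-only path: raw partial via atomic add
    }
    return e;
}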