Commit 5dce9c9d authored by wangshaojie6

make device/grid level code

parent f0d63f25
......@@ -56,28 +56,14 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
auto f_tensor_length_stride_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout){
if (std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return {std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1})};
}
else
{
return {std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride})};
}
};
std::vector<ck::index_t> a_ms_ks_lengths = f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[0];
std::vector<ck::index_t> a_ms_ks_strides = f_tensor_length_stride_descriptor(M, K, StrideA, ALayout{})[1];
std::vector<ck::index_t> b_ns_ks_lengths = f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[0];
std::vector<ck::index_t> b_ns_ks_strides = f_tensor_length_stride_descriptor(N, K, StrideB, Row{})[1];
std::vector<ck::index_t> d_ms_ns_lengths = f_tensor_length_stride_descriptor(M, N, 0, Row{})[0];
std::vector<ck::index_t> d_ms_ns_strides = f_tensor_length_stride_descriptor(M, N, 0, Row{})[1];
std::vector<ck::index_t> e_ms_ns_lengths = f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[0];
std::vector<ck::index_t> e_ms_ns_strides = f_tensor_length_stride_descriptor(M, N, StrideE, ELayout{})[1];
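// NOTE: the hard-coded strides below assume the innermost dimension is contiguous (K for A/B, N for D/E); D's stride of 0 along M makes it a length-N bias broadcast over the rows of E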
std::vector<ck::index_t> a_ms_ks_lengths = {M, K};
std::vector<ck::index_t> a_ms_ks_strides = {StrideA, 1};
std::vector<ck::index_t> b_ns_ks_lengths = {N, K};
std::vector<ck::index_t> b_ns_ks_strides = {StrideB, 1};
std::vector<ck::index_t> d_ms_ns_lengths = {M, N};
std::vector<ck::index_t> d_ms_ns_strides = {0, 1};
std::vector<ck::index_t> e_ms_ns_lengths = {M, N};
std::vector<ck::index_t> e_ms_ns_strides = {StrideE, 1};
switch(config.init_method)
{
......@@ -176,7 +162,7 @@ bool run_splitK_gemm_bias(const ProblemSize& problem_size, const ExecutionConfig
Tensor<CDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, c_element_op);
a_m_k, b_k_n, e_m_n_host_result, d_m_n, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
......
......@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_splitK_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -51,7 +51,7 @@ static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceOpInstanceKKN = ck::tensor_operation::device::
......
......@@ -15,12 +15,13 @@
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_splitk_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
template <typename GridwiseGemmAtomicAdd,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
......@@ -57,7 +58,7 @@ __global__ void
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemmAtomicAdd::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -80,7 +81,7 @@ __global__ void
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
GridwiseGemmAtomicAdd::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
......@@ -538,56 +539,8 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
EGridDesc_G_M_N e_grid_desc_g_m_n_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
// GridwiseGemm
using GridwiseGemmAtomicAdd = GridwiseGemmMultipleD_xdl_cshuffle<
// GridwiseGemm specialized for split-K accumulation via AtomicAdd
using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
......@@ -635,11 +588,11 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
GridwiseGemmAtomicAdd::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
GridwiseGemmAtomicAdd::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using Block2ETileMap = typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
......@@ -676,12 +629,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
e_grid_desc_g_m_n_{
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
GridwiseGemmAtomicAdd::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
GridwiseGemmAtomicAdd::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, KBatch)},
block_2_etile_map_{GridwiseGemmAtomicAdd::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, 1, 1, KBatch)},
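// the two 1's are the (ignored) M01/N01 arguments of the split-K tile map; KBatch controls how many k-batches share each E tile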
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -711,18 +664,18 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
});
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
if(GridwiseGemmAtomicAdd::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
GridwiseGemmAtomicAdd::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
GridwiseGemmAtomicAdd::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
......@@ -753,7 +706,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
typename GridwiseGemmAtomicAdd::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
......@@ -768,9 +721,9 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemmAtomicAdd::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemmAtomicAdd::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
......@@ -804,7 +757,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
if(!GridwiseGemmAtomicAdd::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......@@ -826,19 +779,19 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
GridwiseGemm,
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
typename GridwiseGemmAtomicAdd::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmAtomicAdd::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmAtomicAdd::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
typename GridwiseGemm::DefaultBlock2ETileMap,
typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
......@@ -862,7 +815,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
arg.block_2_etile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
if(GridwiseGemmAtomicAdd::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
......@@ -887,7 +840,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
return false;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
if(!GridwiseGemmAtomicAdd::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
......
......@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -231,17 +232,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
// __host__ __device__ static constexpr auto
// MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
// {
// return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
// e_grid_desc_m_n);
// }
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
const EGridDesc_M_N& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch = 1)
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& c_m_n_grid_desc,
index_t /* M01 */,
index_t /* N01 */,
index_t KBatch = 1)
{
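// KBatch replicates the (m0, n0) tile grid along a leading k-batch dimension; the literal 8 is the M01 swizzle factor of the adapter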
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
c_m_n_grid_desc, 8, KBatch);
......@@ -263,6 +259,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
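// split-K launches KBatch blocks per E tile, each writing a partial sum, so this gridwise kernel is only valid with the AtomicAdd output operation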
if(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
return false;
}
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
......@@ -332,7 +333,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
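// the tile map's type does not depend on the runtime KBatch value, so dummy (1, 1, 1) arguments suffice for the decltype below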
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, KBatch))>;
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1, 1, 1))>;
using DsGridPointer = decltype(MakeDsGridPointer());
......@@ -378,19 +379,22 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// k batch id
const index_t k_batch_id = block_work_idx[I0];
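// the split-K tile map returns (k_batch, m_block, n_block); blocks that share (m_block, n_block) but differ in k_batch_id accumulate into the same E tile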
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
......@@ -426,7 +430,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
make_multi_index(k_batch_id, m_block_data_idx_on_grid, 0),
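// split-K: the AK0 start coordinate is k_batch_id, so each k-batch reads only its own slice of K (the B copy below is offset the same way)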
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
......@@ -457,7 +461,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
make_multi_index(k_batch_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
......@@ -640,6 +644,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// add the multiple-D (bias) tensors only at the first atomic-add position (k_batch_id == 0); see the branch further below
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
......@@ -665,12 +672,6 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
},
Number<NumDTensor>{}));
// only do bias at the 1st atomic add position
if constexpr(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd)
{
}
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
......@@ -679,8 +680,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
......@@ -698,9 +700,35 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
cde_element_op};
// blockwise copy of E from LDS to global, used by the non-zero k-batches (plain atomic add of the partial result, no D tensors)
auto e_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // index_t BlockSize,
ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation,
EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
EDataType, // typename SrcData,
EDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
ck::tensor_operation::element_wise::PassThrough{}};
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -741,12 +769,23 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// make sure it's safe to read from LDS
block_sync_lds();
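// only the first k-batch adds the D (bias) tensors and applies the CDE elementwise op; the remaining k-batches atomically add their raw partial results to E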
if(k_batch_id == 0)
{
// each block copies its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
}
else
{
e_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
}
if constexpr(access_id < num_access - 1)
{
......