Commit 25e35b59 authored by Chao Liu

rename, clean

parent 8a60a329
add_example_executable(example_gemm_bias_add_fastgelu_xdl_fp16 gemm_bias_add_fastgelu_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_fastgelu_xdl_fp16 gemm_bias_fastgelu_xdl_fp16.cpp)
......@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory(01_gemm)
add_subdirectory(02_gemm_alpha_beta)
add_subdirectory(03_gemm_bias_relu)
add_subdirectory(03_gemm_bias_fastgelu)
add_subdirectory(03_gemm_bias_add_fastgelu)
add_subdirectory(04_gemm_bias_relu_add)
add_subdirectory(06_conv2d_fwd_bias_relu)
add_subdirectory(07_conv2d_fwd_bias_relu_add)
......
......@@ -33,7 +33,7 @@ struct DeviceGemmMultipleD : public BaseOperator
ck::index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op);
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
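// Illustrative usage sketch (not part of this commit): given an instance `op` of a
// concrete derivation such as DeviceGemmMultipleD_Xdl_CShuffle, the interface above is
// typically driven as below; the leading pointer/size/stride arguments are elided
// because they are not shown in this hunk:
//   auto argument = op.MakeArgumentPointer(/* ..., */ StrideE,
//                                          a_element_op, b_element_op, cde_element_op);
//   auto invoker  = op.MakeInvokerPointer();
//   invoker->Run(argument.get(), StreamConfig{});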
......
......@@ -489,7 +489,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
......@@ -500,7 +500,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -512,7 +512,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
}
......@@ -538,13 +538,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_etile_map_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
......@@ -625,10 +626,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
......@@ -782,7 +783,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
return str.str();
}
}; // namespace device
};
} // namespace device
} // namespace tensor_operation
......
......@@ -14,18 +14,24 @@
namespace ck {
// input : A[AK0, M, AK1]
// input : B[BK0, N, BK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
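// Illustrative sketch (not part of this commit): with a single D tensor, a bias-add
// cde_op matching the E = cde_op(C, D0, ...) description above could be written as
// the hypothetical functor below:
//   struct AddBias
//   {
//       template <typename E, typename C, typename D>
//       __host__ __device__ void operator()(E& e, const C& c, const D& d) const
//       {
//           e = c + d; // E = C + D0
//       }
//   };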
template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename DsDataType,
typename FloatC,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -55,8 +61,8 @@ template <typename FloatAB,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
......@@ -153,12 +159,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -168,7 +174,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
......@@ -182,7 +188,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
......@@ -199,58 +205,59 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
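// Worked example (illustrative): with M = 512, N = 256, MPerBlock = 128 and
// NPerBlock = 128, we get MBlock = 4 and NBlock = 2, so the [512, 256] descriptor is
// unmerged into a [4, 128, 2, 128] descriptor indexed as
// (mblock, mperblock, nblock, nperblock).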
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
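// Illustrative: continuing the example above (MBlock = 4, NBlock = 2), the launch grid
// covers the 4 x 2 = 8 E tiles and the map assigns each 1-D block id to one (m0, n0)
// tile; the traversal order depends on the M01 adaptation of
// BlockToCTileMap_M00_N0_M01Adapt.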
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop, typename Block2CTileMap>
template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid,
FloatE* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const StaticallyIndexedArray<CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
const StaticallyIndexedArray<EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>&
ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: use tuple
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different
// type from E
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -266,17 +273,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
},
Number<NumDTensor>{});
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
......@@ -537,27 +544,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// shuffle: blockwise copy C from LDS to global
#if 1
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r3<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
CDEElementwiseOperation, // ElementwiseOperation,
EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data,
remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
FloatC, // typename DstData,
FloatE, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
......@@ -568,37 +575,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
cde_element_op};
#else
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
CDEElementwiseOperation, // ElementwiseOperation,
EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data,
FloatC, // typename DstData,
FloatE, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
cde_element_op};
#endif
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
......@@ -611,8 +618,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
......@@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
......@@ -640,37 +647,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global
#if 1
c_shuffle_block_copy_lds_to_global.Run(
cde_block_copy_lds_and_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0],
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
#else
c_shuffle_block_copy_lds_to_global.Run(
cde_block_copy_lds_and_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
#endif
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto c_global_step = sfc_cde_block.GetForwardStep(access_id);
// move on Ds
c_shuffle_block_copy_lds_to_global.MoveSrc1SliceWindow(
cde_block_copy_lds_and_global.MoveSrc1SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step);
c_shuffle_block_copy_lds_to_global.MoveSrc2SliceWindow(
cde_block_copy_lds_and_global.MoveSrc2SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
......