"...composable_kernel_rocm.git" did not exist on "79a4b17f97a7ee29573780354727fa402f344e26"
Commit dd0255ba authored by rocking's avatar rocking
Browse files

Add gridwise gemm + welford

parent e9f656fa
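
The new F and G workspace tensors added in this commit hold per-tile Welford statistics (F is the running mean, G is variance times count), one value per M row and per N block, which is why the Argument allocates MRaw * integer_divide_ceil(NRaw, NPerBlock) elements for each. As a rough orientation, here is a minimal host-side sketch of the standard Welford update and merge that such buffers support; the struct and function names below are illustrative, not part of this commit.

// Illustrative sketch only (hypothetical names, not from this commit): the
// Welford recurrence whose partial results the new kernel writes out, with F
// holding the running mean and G holding variance * count per (M row, N block).
struct WelfordState
{
    float mean = 0.f; // running mean          -> F / p_f_grid_
    float m2   = 0.f; // variance * count (M2) -> G / p_g_grid_
    int count  = 0;
};

// Numerically stable single-value update.
inline void welford_update(WelfordState& s, float x)
{
    s.count += 1;
    const float delta = x - s.mean;
    s.mean += delta / s.count;
    s.m2 += delta * (x - s.mean);
}

// Merging two partial results (e.g. one per N block) is why storing
// variance * count is convenient: the M2 terms combine additively with a
// mean-shift correction (Chan et al.).
inline WelfordState welford_merge(const WelfordState& a, const WelfordState& b)
{
    WelfordState out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;
    const float delta = b.mean - a.mean;
    out.mean = a.mean + delta * b.count / out.count;
    out.m2   = a.m2 + b.m2 +
               delta * delta * (static_cast<float>(a.count) * b.count / out.count);
    return out;
}
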
@@ -12,10 +12,96 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename FDataType,
typename GDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename FGridDescriptor_MBlock_MPerBlock_NBlock,
typename GGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_welford_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
FDataType* __restrict__ p_f_grid,
GDataType* __restrict__ p_g_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const FGridDescriptor_MBlock_MPerBlock_NBlock f_grid_desc_mblock_mperblock_nblock,
const GGridDescriptor_MBlock_MPerBlock_NBlock g_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_f_grid,
p_g_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
f_grid_desc_mblock_mperblock_nblock,
g_grid_desc_mblock_mperblock_nblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = p_f_grid;
ignore = p_g_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = f_grid_desc_mblock_mperblock_nblock;
ignore = g_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
@@ -43,7 +129,7 @@ template <typename ALayout,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename HDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
@@ -82,6 +168,8 @@ template <typename ALayout,
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using FDataType = CShuffleDataType;
using GDataType = CShuffleDataType;
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -162,8 +250,64 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using FGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDWelford_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
FDataType,
GDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
FGridDesc_M_N,
GGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock,
ReduceThreadTransferScalarPerVector_NPerBlock,
LoopSched>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
@@ -171,7 +315,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
void* p_h_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -188,18 +332,70 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_f_grid_{nullptr},
p_g_grid_{nullptr},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
f_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(
MRaw,
math::integer_divide_ceil(NRaw, NPerBlock),
math::integer_divide_ceil(NRaw, NPerBlock))},
g_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(
MRaw,
math::integer_divide_ceil(NRaw, NPerBlock),
math::integer_divide_ceil(NRaw, NPerBlock))},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
h_element_op_{h_element_op}
{
int welford_size = MRaw * math::integer_divide_ceil(NRaw, NPerBlock);
hip_check_error(hipMalloc(&p_f_grid_, sizeof(FDataType) * welford_size));
hip_check_error(hipMalloc(&p_g_grid_, sizeof(GDataType) * welford_size));
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E/F/G
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
f_grid_desc_m_n_,
g_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
f_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(f_grid_desc_m_n_);
g_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(g_grid_desc_m_n_);
}
// TODO - H
}
void Print() const
@@ -216,20 +412,35 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
FDataType* p_f_grid_; // mean
GDataType* p_g_grid_; // variance * count
HDataType* p_h_grid_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
FGridDesc_M_N f_grid_desc_m_n_;
GGridDesc_M_N g_grid_desc_m_n_;
HGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock
f_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock
g_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
@@ -243,10 +454,79 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// TODO
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.f_grid_desc_m_n_,
arg.g_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_welford_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
FDataType,
GDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1,
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_f_grid_,
arg.p_g_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.f_grid_desc_mblock_mperblock_nblock_,
arg.g_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
@@ -264,7 +544,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return false;
}
return true;
}
// polymorphic
@@ -277,7 +557,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -295,7 +575,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b,
p_ds,
p_e,
p_h,
MRaw,
NRaw,
KRaw,
@@ -317,7 +597,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -335,7 +615,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b,
p_ds,
p_e,
p_h,
MRaw,
NRaw,
KRaw,
...