Commit 1d7290fb authored by rocking's avatar rocking
Browse files

Update interface

parent 003ec407
...@@ -63,7 +63,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern ...@@ -63,7 +63,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| //######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| //######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | LayernormThreadClusterSize_M_N, LayernormThreadSliceSize_M_N
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>; < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
// clang-format on // clang-format on
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
// #include #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
...@@ -103,23 +102,56 @@ __global__ void ...@@ -103,23 +102,56 @@ __global__ void
#endif #endif
} }
// template <typename GridwiseWelfordLayernorm, template <typename GridwiseWelfordLayernorm,
// typename EDataType, typename EDataType,
// typename HDataType, typename HDataType,
// typename MeanDataType, typename MeanDataType,
// typename VarDataType> typename VarDataType,
// __global__ void typename GammaDataType,
// #if CK_USE_LAUNCH_BOUNDS typename BetaDataType,
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) typename ComputeDataType,
// #endif typename EHGridDesc_M_N,
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid, typename MeanVarCountGridDesc_M_N,
// const MeanDataType* __restrict__ p_mean_grid, typename GammaBetaGridDesc_N>
// const VarDataType* __restrict__ p_var_grid, __global__ void
// HDataType* __restrict__ p_y_grid, #if CK_USE_LAUNCH_BOUNDS
// index_t blkgroup_size) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// { #endif
// // GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size); kernel_welford_layernorm2d_second_half(
// } const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t blkgroup_size,
index_t num_mean_var_count_k_block_tile_iteration,
index_t num_xy_k_block_tile_iteration,
ComputeDataType epsilon)
{
GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid,
p_in_welford_var_grid,
p_in_welford_count_grid,
p_gamma_grid,
p_beta_grid,
p_h_grid,
e_grid_desc_m_n,
h_grid_desc_m_n,
mean_var_count_grid_desc_m_n,
gamma_grid_desc_n,
beta_grid_desc_n,
blkgroup_size,
num_mean_var_count_k_block_tile_iteration,
num_xy_k_block_tile_iteration,
epsilon);
}
} // namespace ck } // namespace ck
...@@ -204,6 +236,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -204,6 +236,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
using LayernormBlockTileSize_M_N =
Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M_N::At(0),
LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_M_N::At(1)>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -330,11 +366,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -330,11 +366,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1)); using EHGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -351,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -351,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_N_K, BGridDesc_N_K,
DsGridDesc_M_N, DsGridDesc_M_N,
EGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_N, MeanVarCountGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -388,27 +422,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -388,27 +422,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
// using GridwiseWelfordLayernorm = using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType, GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType, HDataType,
// MeanDataType, MeanDataType,
// VarDataType, VarDataType,
// AccDataType, GammaDataType,
// HGridDesc_M_N, BetaDataType,
// MeanVarGridDesc_M_N, AccDataType,
// GammaBetaGridDesc_N, EHGridDesc_M_N,
// MeanVarGridDesc_M, MeanVarCountGridDesc_M_N,
// BlockSize, GammaBetaGridDesc_N,
// LayernormThreadClusterSize_M_N::At(I0), BlockSize,
// LayernormThreadClusterSize_M_N::At(I1), LayernormThreadClusterSize_M_N::At(I0),
// LayernormThreadSliceSize_M_N::At(I0), LayernormThreadClusterSize_M_N::At(I1),
// LayernormThreadSliceSize_M_N::At(I1), LayernormThreadSliceSize_M_N::At(I0),
// LayernormESrcHDstVectorDim, LayernormThreadSliceSize_M_N::At(I1),
// LayernormESrcVectorSize, LayernormESrcHDstVectorDim,
// LayernormHDstVectorSize, LayernormESrcVectorSize,
// LayernormGammaSrcVectorSize, LayernormHDstVectorSize,
// LayernormBetaSrcVectorSize, LayernormGammaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>; LayernormBetaSrcVectorSize,
LayernormMeanVarSrcDstVectorSize>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -449,20 +484,24 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -449,20 +484,24 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
h_element_op_{h_element_op}, h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
mean_var_count_grid_desc_m_n_ = mean_var_count_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_); DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, gemm_nblock_, gemm_nblock_);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw)); hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
int gemm_welford_size = MRaw * blkGroupSize_; int gemm_welford_size = MRaw * gemm_nblock_;
hip_check_error( hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size)); hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error( hip_check_error(
...@@ -502,8 +541,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -502,8 +541,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_n_); mean_var_count_grid_desc_m_n_);
} }
// TODO - H
} }
void Print() const void Print() const
...@@ -533,11 +570,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -533,11 +570,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_; MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
GammaBetaGridDesc_N gamma_grid_desc_n_; GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_; GammaBetaGridDesc_N beta_grid_desc_n_;
HGridDesc_M_N h_grid_desc_m_n_; EHGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -558,7 +595,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -558,7 +595,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_; HElementwiseOperation h_element_op_;
int blkGroupSize_; int gemm_nblock_;
AccDataType epsilon_; AccDataType epsilon_;
}; };
...@@ -581,9 +618,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -581,9 +618,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
} }
const index_t grid_size = index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -612,12 +650,18 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -612,12 +650,18 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemmWelford::DefaultBlock2ETileMap, typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
// const auto kernel_welford_layernorm = const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm, kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType, EDataType,
// HDataType, HDataType,
// MeanDataType, MeanDataType,
// VarDataType>; VarDataType,
GammaDataType,
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
GammaBetaGridDesc_N>;
avg_time += avg_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -642,16 +686,39 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -642,16 +686,39 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_, arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
// avg_time += launch_and_time_kernel(stream_config, grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) /
// kernel_welford_layernorm, LayernormBlockTileSize_M_N::At(0);
// dim3(grid_size),
// dim3(BlockSize), index_t numMeanVarCountBlockTileIteration_N =
// 0, math::integer_least_multiple(arg.gemm_nblock_,
// arg.p_e_grid_, LayernormThreadClusterSize_M_N::At(I1)) /
// arg.p_welford_mean_grid_, LayernormThreadClusterSize_M_N::At(I1);
// arg.p_welford_var_grid_,
// arg.p_h_grid_, index_t numXBlockTileIteration_N =
// arg.blkGroupSize_); math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
LayernormBlockTileSize_M_N::At(I1);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_n_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
arg.gemm_nblock_,
numMeanVarCountBlockTileIteration_N,
numXBlockTileIteration_N,
arg.epsilon_);
return avg_time; return avg_time;
}; };
...@@ -681,6 +748,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -681,6 +748,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return false; return false;
} }
// TODO
return true; return true;
} }
......
...@@ -23,25 +23,26 @@ template <typename EDataType, ...@@ -23,25 +23,26 @@ template <typename EDataType,
typename HDataType, typename HDataType,
typename MeanDataType, typename MeanDataType,
typename VarDataType, typename VarDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename XYGridDesc_M_N, typename EHGridDesc_M_N,
typename MeanVarGridDesc_M_N, typename MeanVarCountGridDesc_M_N,
typename GammaBetaGridDesc_N, typename GammaBetaGridDesc_N,
typename MeanVarGridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t NThreadClusterSize, index_t NThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t NThreadSliceSize, index_t NThreadSliceSize,
index_t XSrcYDstVectorDim, index_t ESrcYDstVectorDim,
index_t XSrcVectorSize, index_t ESrcVectorSize,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t MeanVarSrcDstVectorSize> index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d struct GridwiseWelfordSecondHalfLayernorm2d
{ {
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); static constexpr bool reorder_thread_cluster = (ESrcYDstVectorDim == 0);
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>; using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
...@@ -76,57 +77,38 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -76,57 +77,38 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize; static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EDataType* __restrict__ p_e_grid, __device__ static void Run(const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_mean_grid, const MeanDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_var_grid, const VarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
/*const MeanVarGridDesc_M_N& mean_grid_desc_m_k, const EHGridDesc_M_N& e_grid_desc_m_n,
const MeanVarGridDesc_M_N& var_grid_desc_m_k, const EHGridDesc_M_N& h_grid_desc_m_n,
const GammaBetaGridDesc_N& gamma_grid_desc_m, const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const GammaBetaGridDesc_N& beta_grid_desc_m, const GammaBetaGridDesc_N& gamma_grid_desc_n,
const MeanVarGridDesc_M& mean_var_grid_desc_m,*/ const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t blkgroup_size) index_t gemm_nblock_,
index_t num_mean_var_count_k_block_tile_iteration,
index_t num_xy_k_block_tile_iteration,
ComputeDataType epsilon)
{ {
ignore = p_e_grid; ignore = p_e_grid;
ignore = p_mean_grid; ignore = p_in_welford_mean_grid;
ignore = p_var_grid; ignore = p_in_welford_var_grid;
ignore = p_in_welford_count_grid;
ignore = p_gamma_grid;
ignore = p_beta_grid;
ignore = p_h_grid; ignore = p_h_grid;
ignore = e_grid_desc_m_n;
const index_t thread_local_id = get_thread_local_1d_id(); ignore = h_grid_desc_m_n;
const index_t block_global_id = get_block_1d_id(); ignore = mean_var_count_grid_desc_m_n;
const index_t blkgroup_id = block_global_id / blkgroup_size; ignore = gamma_grid_desc_n;
const index_t block_local_id = block_global_id % blkgroup_size; ignore = beta_grid_desc_n;
ignore = gemm_nblock_;
const auto thread_cluster_idx = ignore = num_mean_var_count_k_block_tile_iteration;
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); ignore = num_xy_k_block_tile_iteration;
ignore = epsilon;
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
/*
auto threadwise_mean_load_m_n =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType,
MeanVarGridDesc_M_N,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_grid_desc_m_n,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * 1));*/
} // run } // run
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment