Commit 1d7290fb authored by rocking's avatar rocking
Browse files

Update interface

parent 003ec407
......@@ -63,7 +63,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | LayernormThreadClusterSize_M_N, LayernormThreadSliceSize_M_N
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
// clang-format on
......
......@@ -13,8 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
// #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
......@@ -103,23 +102,56 @@ __global__ void
#endif
}
// template <typename GridwiseWelfordLayernorm,
// typename EDataType,
// typename HDataType,
// typename MeanDataType,
// typename VarDataType>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// #endif
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
// const MeanDataType* __restrict__ p_mean_grid,
// const VarDataType* __restrict__ p_var_grid,
// HDataType* __restrict__ p_y_grid,
// index_t blkgroup_size)
// {
// // GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
// }
// Device entry point for the second half of the fused GEMM + LayerNorm pipeline.
// This kernel is a thin wrapper: it forwards every argument unchanged to
// GridwiseWelfordLayernorm::Run, which consumes the per-block partial Welford
// results (mean / variance / count, produced by the first-half GEMM kernel)
// together with gamma/beta and writes the normalized output H.
//
// Parameters (as forwarded):
//   p_e_grid                  - E matrix (GEMM output) read buffer
//   p_in_welford_*_grid       - partial Welford mean / var / count inputs
//   p_gamma_grid, p_beta_grid - layernorm affine parameters
//   p_h_grid                  - output buffer
//   *_grid_desc_*             - tensor descriptors for the corresponding buffers
//   blkgroup_size             - number of GEMM N-blocks per row group
//   num_*_k_block_tile_iteration - tile-iteration counts along the reduced dim
//   epsilon                   - layernorm numerical-stability epsilon
template <typename GridwiseWelfordLayernorm,
typename EDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename GammaBetaGridDesc_N>
__global__ void
// Optionally cap occupancy so the compiler can budget registers per thread.
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(
const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t blkgroup_size,
index_t num_mean_var_count_k_block_tile_iteration,
index_t num_xy_k_block_tile_iteration,
ComputeDataType epsilon)
{
// All real work happens in the gridwise policy class; keep this wrapper
// argument-for-argument identical to Run's signature.
GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid,
p_in_welford_var_grid,
p_in_welford_count_grid,
p_gamma_grid,
p_beta_grid,
p_h_grid,
e_grid_desc_m_n,
h_grid_desc_m_n,
mean_var_count_grid_desc_m_n,
gamma_grid_desc_n,
beta_grid_desc_n,
blkgroup_size,
num_mean_var_count_k_block_tile_iteration,
num_xy_k_block_tile_iteration,
epsilon);
}
} // namespace ck
......@@ -204,6 +236,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
static constexpr index_t NumDTensor = DsDataType::Size();
using LayernormBlockTileSize_M_N =
Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M_N::At(0),
LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_M_N::At(1)>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -330,11 +366,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using EHGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......@@ -351,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
......@@ -388,27 +422,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
// using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType,
// MeanDataType,
// VarDataType,
// AccDataType,
// HGridDesc_M_N,
// MeanVarGridDesc_M_N,
// GammaBetaGridDesc_N,
// MeanVarGridDesc_M,
// BlockSize,
// LayernormThreadClusterSize_M_N::At(I0),
// LayernormThreadClusterSize_M_N::At(I1),
// LayernormThreadSliceSize_M_N::At(I0),
// LayernormThreadSliceSize_M_N::At(I1),
// LayernormESrcHDstVectorDim,
// LayernormESrcVectorSize,
// LayernormHDstVectorSize,
// LayernormGammaSrcVectorSize,
// LayernormBetaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType,
HDataType,
MeanDataType,
VarDataType,
GammaDataType,
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
GammaBetaGridDesc_N,
BlockSize,
LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1),
LayernormThreadSliceSize_M_N::At(I0),
LayernormThreadSliceSize_M_N::At(I1),
LayernormESrcHDstVectorDim,
LayernormESrcVectorSize,
LayernormHDstVectorSize,
LayernormGammaSrcVectorSize,
LayernormBetaSrcVectorSize,
LayernormMeanVarSrcDstVectorSize>;
// Argument
struct Argument : public BaseArgument
......@@ -449,20 +484,24 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)},
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
mean_var_count_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, gemm_nblock_, gemm_nblock_);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
int gemm_welford_size = MRaw * blkGroupSize_;
int gemm_welford_size = MRaw * gemm_nblock_;
hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(
......@@ -502,8 +541,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_n_);
}
// TODO - H
}
void Print() const
......@@ -533,11 +570,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
HGridDesc_M_N h_grid_desc_m_n_;
EHGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -558,7 +595,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CDEElementwiseOperation cde_element_op_;
HElementwiseOperation h_element_op_;
int blkGroupSize_;
int gemm_nblock_;
AccDataType epsilon_;
};
......@@ -581,9 +618,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -612,12 +650,18 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>;
// const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType,
// HDataType,
// MeanDataType,
// VarDataType>;
const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
HDataType,
MeanDataType,
VarDataType,
GammaDataType,
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_N,
GammaBetaGridDesc_N>;
avg_time +=
launch_and_time_kernel(stream_config,
......@@ -642,16 +686,39 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
// avg_time += launch_and_time_kernel(stream_config,
// kernel_welford_layernorm,
// dim3(grid_size),
// dim3(BlockSize),
// 0,
// arg.p_e_grid_,
// arg.p_welford_mean_grid_,
// arg.p_welford_var_grid_,
// arg.p_h_grid_,
// arg.blkGroupSize_);
grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) /
LayernormBlockTileSize_M_N::At(0);
index_t numMeanVarCountBlockTileIteration_N =
math::integer_least_multiple(arg.gemm_nblock_,
LayernormThreadClusterSize_M_N::At(I1)) /
LayernormThreadClusterSize_M_N::At(I1);
index_t numXBlockTileIteration_N =
math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
LayernormBlockTileSize_M_N::At(I1);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_n_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
arg.gemm_nblock_,
numMeanVarCountBlockTileIteration_N,
numXBlockTileIteration_N,
arg.epsilon_);
return avg_time;
};
......@@ -681,6 +748,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return false;
}
// TODO
return true;
}
......
......@@ -23,25 +23,26 @@ template <typename EDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename XYGridDesc_M_N,
typename MeanVarGridDesc_M_N,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
typename GammaBetaGridDesc_N,
typename MeanVarGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t NThreadClusterSize,
index_t MThreadSliceSize,
index_t NThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t ESrcYDstVectorDim,
index_t ESrcVectorSize,
index_t YDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
static constexpr bool reorder_thread_cluster = (ESrcYDstVectorDim == 0);
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
......@@ -76,57 +77,38 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
/// Second-half gridwise kernel body: merge the per-block partial Welford
/// results and apply layernorm (gamma/beta, epsilon) to produce H.
///
/// NOTE(review): this is currently an unimplemented stub — every parameter is
/// discarded via `ignore` so the kernel compiles without "unused parameter"
/// warnings while the implementation is being brought up.
///
/// Defect fixed here: the previous text was a raw diff merge that contained
/// BOTH the old and new versions of this function (duplicate mean/var
/// parameters, two conflicting signatures, and the old thread-cluster body
/// interleaved with the new `ignore` list), which is not valid C++. This is
/// the reconstructed post-change (added-lines-only) version.
///
/// @param p_e_grid          GEMM output E (input to normalization)
/// @param p_in_welford_mean_grid / p_in_welford_var_grid / p_in_welford_count_grid
///                          partial Welford statistics from the first-half kernel
/// @param p_gamma_grid / p_beta_grid  layernorm affine parameters
/// @param p_h_grid          normalized output
/// @param e_grid_desc_m_n / h_grid_desc_m_n / mean_var_count_grid_desc_m_n /
///        gamma_grid_desc_n / beta_grid_desc_n  tensor descriptors
/// @param gemm_nblock_      number of GEMM N-blocks contributing partials
/// @param num_mean_var_count_k_block_tile_iteration / num_xy_k_block_tile_iteration
///                          tile-iteration counts along the reduced dimension
/// @param epsilon           numerical-stability epsilon
__device__ static void Run(const EDataType* __restrict__ p_e_grid,
                           const MeanDataType* __restrict__ p_in_welford_mean_grid,
                           const VarDataType* __restrict__ p_in_welford_var_grid,
                           const int32_t* __restrict__ p_in_welford_count_grid,
                           const GammaDataType* __restrict__ p_gamma_grid,
                           const BetaDataType* __restrict__ p_beta_grid,
                           HDataType* __restrict__ p_h_grid,
                           const EHGridDesc_M_N& e_grid_desc_m_n,
                           const EHGridDesc_M_N& h_grid_desc_m_n,
                           const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
                           const GammaBetaGridDesc_N& gamma_grid_desc_n,
                           const GammaBetaGridDesc_N& beta_grid_desc_n,
                           index_t gemm_nblock_,
                           index_t num_mean_var_count_k_block_tile_iteration,
                           index_t num_xy_k_block_tile_iteration,
                           ComputeDataType epsilon)
{
    // Stub: silence unused-parameter warnings until the welford-merge +
    // layernorm implementation lands.
    ignore = p_e_grid;
    ignore = p_in_welford_mean_grid;
    ignore = p_in_welford_var_grid;
    ignore = p_in_welford_count_grid;
    ignore = p_gamma_grid;
    ignore = p_beta_grid;
    ignore = p_h_grid;
    ignore = e_grid_desc_m_n;
    ignore = h_grid_desc_m_n;
    ignore = mean_var_count_grid_desc_m_n;
    ignore = gamma_grid_desc_n;
    ignore = beta_grid_desc_n;
    ignore = gemm_nblock_;
    ignore = num_mean_var_count_k_block_tile_iteration;
    ignore = num_xy_k_block_tile_iteration;
    ignore = epsilon;
} // run
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment