Commit b58d5a7d authored by rocking's avatar rocking
Browse files

Fix bug of mean var padding for layernorm

parent db0a27ad
...@@ -254,7 +254,8 @@ int main() ...@@ -254,7 +254,8 @@ int main()
h_device_buf.FromDevice(h_m_n.mData.data()); h_device_buf.FromDevice(h_m_n.mData.data());
pass &= ck::utils::check_err(e_m_n, e_m_n_host); pass &= ck::utils::check_err(e_m_n, e_m_n_host);
pass &= ck::utils::check_err(h_m_n, h_m_n_host); pass &=
ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -111,7 +111,7 @@ template <typename GridwiseWelfordLayernorm, ...@@ -111,7 +111,7 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename EHGridDesc_M_N, typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock, typename LayernormMeanVarCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N, typename GammaBetaGridDesc_N,
typename HElementwiseOperation> typename HElementwiseOperation>
__global__ void __global__ void
...@@ -128,7 +128,7 @@ __global__ void ...@@ -128,7 +128,7 @@ __global__ void
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock, const LayernormMeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
...@@ -314,14 +314,25 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -314,14 +314,25 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
static auto MakeMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock) static auto MakeGemmMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
{ {
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock)); const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
// TODO - padding according to MNperBlock of Gemm and Layernorm // TODO - padding according to MNperBlock of Gemm
// CAUSION - GetWorkSpaceSize
return grid_desc_m_n; return grid_desc_m_n;
} }
static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
{
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
return PadTensorDescriptor(
grid_desc_m_n,
make_tuple(LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)),
Sequence<true, true>{});
}
static auto MakeDescriptor_M(index_t MRaw) static auto MakeDescriptor_M(index_t MRaw)
{ {
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
...@@ -375,9 +386,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -375,9 +386,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1)); // We have to separate mean var descriptor for gemm and layernorm bacause of different grid
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1)); // layout(different padding)
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1)); using GemmMeanVarCountGridDesc_M_NBlock =
decltype(MakeGemmMeanVarCountGridDescriptor_M_NBlock(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeLayernormMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -395,7 +411,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -395,7 +411,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K, BGridDesc_N_K,
DsGridDesc_M_N, DsGridDesc_M_N,
EHGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock, GemmMeanVarCountGridDesc_M_NBlock,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -440,7 +456,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -440,7 +456,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType, BetaDataType,
AccDataType, AccDataType,
EHGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock, LayernormMeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N, GammaBetaGridDesc_N,
HElementwiseOperation, HElementwiseOperation,
BlockSize, BlockSize,
...@@ -491,7 +507,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -491,7 +507,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_var_count_grid_desc_m_nblock_{}, gemm_mean_var_count_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
...@@ -507,8 +524,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -507,8 +524,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
mean_var_count_grid_desc_m_nblock_ = gemm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_); DeviceOp::MakeGemmMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeLayernormMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -540,7 +560,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -540,7 +560,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_ = mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_nblock_); gemm_mean_var_count_grid_desc_m_nblock_);
} }
} }
...@@ -572,7 +592,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -572,7 +592,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_; EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_; GemmMeanVarCountGridDesc_M_NBlock gemm_mean_var_count_grid_desc_m_nblock_;
LayernormMeanVarCountGridDesc_M_NBlock layernorm_mean_var_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_; GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_; GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_; EHGridDesc_M_N h_grid_desc_m_n_;
...@@ -660,7 +681,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -660,7 +681,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType, BetaDataType,
AccDataType, AccDataType,
EHGridDesc_M_N, EHGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock, LayernormMeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N, GammaBetaGridDesc_N,
HElementwiseOperation>; HElementwiseOperation>;
...@@ -710,7 +731,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -710,7 +731,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_, arg.p_h_grid_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_, arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_nblock_, arg.layernorm_mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_, arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_, arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
...@@ -745,7 +766,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -745,7 +766,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
size_t workspace_size = 0; size_t workspace_size = 0;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize(); // FIXME - padding
int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// workspace for welford intermediate mean // workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64; workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
...@@ -765,7 +788,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -765,7 +788,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
pArg_->p_workspace_ = p_workspace; pArg_->p_workspace_ = p_workspace;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize(); int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_; // int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// setup buffer used for intermediate welford mean // setup buffer used for intermediate welford mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment