Commit b58d5a7d authored by rocking's avatar rocking
Browse files

Fix bug of mean var padding for layernorm

parent db0a27ad
......@@ -254,7 +254,8 @@ int main()
h_device_buf.FromDevice(h_m_n.mData.data());
pass &= ck::utils::check_err(e_m_n, e_m_n_host);
pass &= ck::utils::check_err(h_m_n, h_m_n_host);
pass &=
ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
}
return pass ? 0 : 1;
......
......@@ -111,7 +111,7 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock,
typename LayernormMeanVarCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation>
__global__ void
......@@ -128,7 +128,7 @@ __global__ void
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const LayernormMeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
......@@ -314,14 +314,25 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{});
}
static auto MakeMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
static auto MakeGemmMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
{
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
// TODO - padding according to MNperBlock of Gemm and Layernorm
// TODO - padding according to MNperBlock of Gemm
// CAUSION - GetWorkSpaceSize
return grid_desc_m_n;
}
static auto MakeLayernormMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
{
const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
return PadTensorDescriptor(
grid_desc_m_n,
make_tuple(LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)),
Sequence<true, true>{});
}
static auto MakeDescriptor_M(index_t MRaw)
{
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
......@@ -375,7 +386,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding)
using GemmMeanVarCountGridDesc_M_NBlock =
decltype(MakeGemmMeanVarCountGridDescriptor_M_NBlock(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeLayernormMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
......@@ -395,7 +411,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
GemmMeanVarCountGridDesc_M_NBlock,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -440,7 +456,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
LayernormMeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation,
BlockSize,
......@@ -491,7 +507,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_var_count_grid_desc_m_nblock_{},
gemm_mean_var_count_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
......@@ -507,8 +524,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
gemm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeGemmMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeLayernormMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -540,7 +560,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_nblock_);
gemm_mean_var_count_grid_desc_m_nblock_);
}
}
......@@ -572,7 +592,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_;
MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_;
GemmMeanVarCountGridDesc_M_NBlock gemm_mean_var_count_grid_desc_m_nblock_;
LayernormMeanVarCountGridDesc_M_NBlock layernorm_mean_var_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_;
......@@ -660,7 +681,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
MeanVarCountGridDesc_M_NBlock,
LayernormMeanVarCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation>;
......@@ -710,7 +731,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_nblock_,
arg.layernorm_mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
......@@ -745,7 +766,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
size_t workspace_size = 0;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// FIXME - padding
int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
......@@ -765,7 +788,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
pArg_->p_workspace_ = p_workspace;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
int gemm_welford_size =
pArg_->gemm_mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// setup buffer used for intermediate welford mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment