Commit f278b2a5 authored by rocking

Add EMeanVarDataType parameter.

parent d3f2dbbd
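This commit replaces the hard-wired EDataType/MeanDataType/VarDataType aliases (all previously forced to HDataType) with a single user-visible template parameter, EMeanVarDataType, which sets the data type of the intermediate GEMM output E and of the Welford mean/variance buffers. A minimal sketch of what the new parameter controls, with illustrative values that are not part of the diff:

using HDataType        = F16;
using EMeanVarDataType = F32; // hypothetical: keep E and the Welford statistics in higher precision than H

// When EMeanVarDataType differs from HDataType, the device op now allocates a separate
// workspace buffer for E (see SetWorkSpacePointer below); when the two types match, the
// E workspace aliases the H output buffer, as it did before this change.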
@@ -39,6 +39,7 @@ using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EMeanVarDataType = F16;
using GammaDataType = F16;
using BetaDataType = F16;
using HDataType = F16;
@@ -60,11 +61,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
// clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
@@ -86,7 +87,7 @@ auto f_host_tensor_descriptor2d =
}
};
void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n,
Tensor<HDataType>& h_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
@@ -109,7 +110,7 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
BElementOp,
PassThrough>;
using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<HDataType,
using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<EMeanVarDataType,
GammaDataType,
BetaDataType,
HDataType,
@@ -229,7 +230,7 @@ int main()
if(do_verification)
{
Tensor<HDataType> e_m_n_host(HostTensorDescriptor{M, N});
Tensor<EMeanVarDataType> e_m_n_host(HostTensorDescriptor{M, N});
Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
host_gemm_layernorm(e_m_n_host,
......
@@ -23,9 +23,7 @@ namespace ck {
template <typename GridwiseGemmWelford,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename MeanDataType,
typename VarDataType,
typename EMeanVarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
@@ -45,9 +43,9 @@ __global__ void
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
@@ -111,10 +109,8 @@ __global__ void
}
template <typename GridwiseWelfordLayernorm,
typename EDataType,
typename EMeanVarDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
@@ -128,9 +124,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(
const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_in_welford_var_grid,
const EMeanVarDataType* __restrict__ p_e_grid,
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
@@ -192,6 +188,7 @@ template <typename ALayout,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EMeanVarDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
@@ -249,16 +246,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
CDEElementwiseOperation,
HElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
// EDataType, MeanDataType and VarDataType must be the same.
// eg. M, N, K = [1, 1, 1],
// in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
// if (x - mean) != 0, (x - mean) * divisor * gamma might be too large
// However, (x - mean) * divisor * gamma should be 0 in this case
using EDataType = HDataType;
using MeanDataType = HDataType;
using VarDataType = HDataType;
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector;
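The comment block deleted above records why E, the Welford mean, and the variance must share one data type, which is exactly the constraint the new EMeanVarDataType parameter makes explicit. A small worked sketch of the hazard it describes, with hypothetical values of my own:

#include <cmath>

// Layernorm over a single element: the variance is 0, so the divisor is
// 1 / sqrt(0 + 1e-5) ~= 316.227783. Since x == mean, (x - mean) * divisor * gamma
// should be exactly 0; if the mean were stored in a narrower type than E, the
// rounding error in (e - mean) would be amplified by that factor instead.
float layernorm_one_element(float e, float mean_as_stored)
{
    const float var     = 0.0f;                          // variance of one element
    const float divisor = 1.0f / std::sqrt(var + 1e-5f); // ~316.23
    return (e - mean_as_stored) * divisor;               // exact mean -> 0.0f;
                                                          // e = 0.3333333f, mean rounded to 0.333f -> ~0.105
}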
@@ -392,9 +387,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
MeanDataType,
VarDataType,
EMeanVarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
@@ -442,10 +435,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType,
GridwiseWelfordSecondHalfLayernorm2d<EMeanVarDataType,
HDataType,
MeanDataType,
VarDataType,
GammaDataType,
BetaDataType,
AccDataType,
@@ -488,7 +479,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_h_grid)},
p_workspace_e_grid_{nullptr},
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
@@ -611,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
void* p_workspace_e_grid_;
void* p_workspace_mean_;
void* p_workspace_var_;
void* p_workspace_count_;
@@ -694,9 +685,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
GridwiseGemmWelford,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemmWelford::DsGridPointer,
EDataType,
MeanDataType,
VarDataType,
EMeanVarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
@@ -713,10 +702,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
EMeanVarDataType,
HDataType,
MeanDataType,
VarDataType,
GammaDataType,
BetaDataType,
AccDataType,
@@ -735,9 +722,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
static_cast<MeanDataType*>(arg.p_workspace_mean_),
static_cast<VarDataType*>(arg.p_workspace_var_),
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_),
arg.a_element_op_,
arg.b_element_op_,
@@ -760,29 +747,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
avg_time +=
launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
static_cast<const MeanDataType*>(arg.p_workspace_mean_),
static_cast<const VarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.layernorm_e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
NBlockClusterLength,
arg.epsilon_,
arg.h_element_op_);
avg_time += launch_and_time_kernel(
stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.layernorm_e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
NBlockClusterLength,
arg.epsilon_,
arg.h_element_op_);
return avg_time;
};
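For orientation, the Run path above launches two kernels; a rough outline of the data flow, using the pointer names from the diff (the exact elementwise-op ordering lives in the gridwise kernels and is not shown in this hunk):

// 1) GEMM + Welford first half:
//      E = CDEElementOp(A x B, D0, D1, ...)   -> p_workspace_e_grid_ (aliases p_h_grid_
//                                                when EMeanVarDataType == HDataType)
//      per-N-block Welford mean / var / count -> p_workspace_mean_ / p_workspace_var_ / p_workspace_count_
// 2) kernel_welford_layernorm2d_second_half:
//      merges the partial mean / var / count along N, normalizes E with the merged
//      statistics, applies gamma, beta and the H elementwise op, and writes H to p_h_grid_.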
@@ -814,14 +801,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size += gemm_welford_size * sizeof(VarDataType) + 64;
workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64;
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);
return (workspace_size);
};
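A sketch of the workspace layout implied by GetWorkSpaceSize above and by SetWorkSpacePointer in the next hunk (my annotation; each region is rounded up to a 64-byte multiple, per the code):

// pArg_->p_workspace_
//   [ Welford mean     : gemm_welford_size * sizeof(EMeanVarDataType) ] -> p_workspace_mean_
//   [ Welford variance : gemm_welford_size * sizeof(EMeanVarDataType) ] -> p_workspace_var_
//   [ Welford count    : int32_t entries                              ] -> p_workspace_count_
//   [ E buffer         : MRaw_ * NRaw_ * sizeof(EMeanVarDataType)     ] -> p_workspace_e_grid_,
//     present only when EMeanVarDataType != HDataType; otherwise p_workspace_e_grid_
//     points directly at the H output buffer (p_h_grid_).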
@@ -836,20 +826,27 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = gemm_welford_size * sizeof(MeanDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford variance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = gemm_welford_size * sizeof(VarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count
pArg_->p_workspace_count_ =
reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
pArg_->p_workspace_e_grid_ =
reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
else
pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
};
static bool IsSupportedArgument(const Argument& arg)
......
@@ -36,9 +36,7 @@ template <typename ABDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename MeanDataType,
typename VarDataType,
typename EMeanVarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
@@ -329,7 +327,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
{
return false;
}
@@ -370,9 +368,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
@@ -825,7 +823,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EDataType,
EMeanVarDataType,
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
@@ -1042,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
MeanDataType,
EMeanVarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
@@ -1062,7 +1060,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
VarDataType,
EMeanVarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
......
@@ -19,10 +19,8 @@
namespace ck {
template <typename EDataType,
template <typename EMeanVarDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
@@ -87,9 +85,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_in_welford_var_grid,
__device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid,
@@ -176,7 +174,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
// IO
auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
@@ -192,7 +190,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_n_cluster_id));
auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<VarDataType,
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
@@ -224,7 +222,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_n_cluster_id));
auto threadwise_e_load_m_n =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType,
decltype(e_grid_desc_m_n),
decltype(thread_buffer_desc_m_n),
......