Commit f278b2a5 authored by rocking's avatar rocking
Browse files

Add EMeanVarDataType parameter.

parent d3f2dbbd
...@@ -39,6 +39,7 @@ using CShuffleDataType = F32; ...@@ -39,6 +39,7 @@ using CShuffleDataType = F32;
using D0DataType = F16; using D0DataType = F16;
using D1DataType = F16; using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EMeanVarDataType = F16;
using GammaDataType = F16; using GammaDataType = F16;
using BetaDataType = F16; using BetaDataType = F16;
using HDataType = F16; using HDataType = F16;
...@@ -60,11 +61,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -60,11 +61,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| //######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M| //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>; < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
// clang-format on // clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
...@@ -86,7 +87,7 @@ auto f_host_tensor_descriptor2d = ...@@ -86,7 +87,7 @@ auto f_host_tensor_descriptor2d =
} }
}; };
void host_gemm_layernorm(Tensor<HDataType>& e_m_n, void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n,
Tensor<HDataType>& h_m_n, Tensor<HDataType>& h_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n, const Tensor<BDataType>& b_k_n,
...@@ -109,7 +110,7 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n, ...@@ -109,7 +110,7 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
BElementOp, BElementOp,
PassThrough>; PassThrough>;
using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<HDataType, using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<EMeanVarDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
HDataType, HDataType,
...@@ -229,7 +230,7 @@ int main() ...@@ -229,7 +230,7 @@ int main()
if(do_verification) if(do_verification)
{ {
Tensor<HDataType> e_m_n_host(HostTensorDescriptor{M, N}); Tensor<EMeanVarDataType> e_m_n_host(HostTensorDescriptor{M, N});
Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N}); Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
host_gemm_layernorm(e_m_n_host, host_gemm_layernorm(e_m_n_host,
......
...@@ -23,9 +23,7 @@ namespace ck { ...@@ -23,9 +23,7 @@ namespace ck {
template <typename GridwiseGemmWelford, template <typename GridwiseGemmWelford,
typename ABDataType, typename ABDataType,
typename DsPointer, typename DsPointer,
typename EDataType, typename EMeanVarDataType,
typename MeanDataType,
typename VarDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -45,9 +43,9 @@ __global__ void ...@@ -45,9 +43,9 @@ __global__ void
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EMeanVarDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid, EMeanVarDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid, EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count_grid, int32_t* __restrict__ p_welford_count_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -111,10 +109,8 @@ __global__ void ...@@ -111,10 +109,8 @@ __global__ void
} }
template <typename GridwiseWelfordLayernorm, template <typename GridwiseWelfordLayernorm,
typename EDataType, typename EMeanVarDataType,
typename HDataType, typename HDataType,
typename MeanDataType,
typename VarDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
...@@ -128,9 +124,9 @@ __global__ void ...@@ -128,9 +124,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_welford_layernorm2d_second_half( kernel_welford_layernorm2d_second_half(
const EDataType* __restrict__ p_e_grid, const EMeanVarDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid, const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_in_welford_var_grid, const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid, const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid, const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid, const BetaDataType* __restrict__ p_beta_grid,
...@@ -192,6 +188,7 @@ template <typename ALayout, ...@@ -192,6 +188,7 @@ template <typename ALayout,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
typename EMeanVarDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename HDataType, typename HDataType,
...@@ -249,16 +246,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -249,16 +246,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
CDEElementwiseOperation, CDEElementwiseOperation,
HElementwiseOperation> HElementwiseOperation>
{ {
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
// EDataType, MeanDataType and VarDataType must be the same. // EDataType, MeanDataType and VarDataType must be the same.
// eg. M, N, K = [1, 1, 1], // eg. M, N, K = [1, 1, 1],
// in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783 // in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
// if (x - mean) != 0, (x - mean) * divisor * gamma might be too large // if (x - mean) != 0, (x - mean) * divisor * gamma might be too large
// However, (x - mean) * divisor * gamma should be 0 in this case // However, (x - mean) * divisor * gamma should be 0 in this case
using EDataType = HDataType;
using MeanDataType = HDataType; using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using VarDataType = HDataType; using ELayout = HLayout;
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector; static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector;
...@@ -392,9 +387,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -392,9 +387,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
DsDataType, DsDataType,
EDataType, EMeanVarDataType,
MeanDataType,
VarDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
...@@ -442,10 +435,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -442,10 +435,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm = using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, GridwiseWelfordSecondHalfLayernorm2d<EMeanVarDataType,
HDataType, HDataType,
MeanDataType,
VarDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
...@@ -488,7 +479,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -488,7 +479,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_h_grid)}, p_workspace_e_grid_{nullptr},
p_workspace_mean_{nullptr}, p_workspace_mean_{nullptr},
p_workspace_var_{nullptr}, p_workspace_var_{nullptr},
p_workspace_count_{nullptr}, p_workspace_count_{nullptr},
...@@ -611,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -611,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; void* p_workspace_e_grid_;
void* p_workspace_mean_; void* p_workspace_mean_;
void* p_workspace_var_; void* p_workspace_var_;
void* p_workspace_count_; void* p_workspace_count_;
...@@ -694,9 +685,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -694,9 +685,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
GridwiseGemmWelford, GridwiseGemmWelford,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemmWelford::DsGridPointer, typename GridwiseGemmWelford::DsGridPointer,
EDataType, EMeanVarDataType,
MeanDataType,
VarDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
...@@ -713,10 +702,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -713,10 +702,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
const auto kernel_welford_layernorm = const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm, kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType, EMeanVarDataType,
HDataType, HDataType,
MeanDataType,
VarDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
...@@ -735,9 +722,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -735,9 +722,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
static_cast<MeanDataType*>(arg.p_workspace_mean_), static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<VarDataType*>(arg.p_workspace_var_), static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_), static_cast<int32_t*>(arg.p_workspace_count_),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -760,29 +747,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -760,29 +747,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil( index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1)); arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
avg_time += avg_time += launch_and_time_kernel(
launch_and_time_kernel(stream_config, stream_config,
kernel_welford_layernorm, kernel_welford_layernorm,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_e_grid_, static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
static_cast<const MeanDataType*>(arg.p_workspace_mean_), static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<const VarDataType*>(arg.p_workspace_var_), static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_), static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_, arg.p_gamma_grid_,
arg.p_beta_grid_, arg.p_beta_grid_,
arg.p_h_grid_, arg.p_h_grid_,
arg.layernorm_e_grid_desc_m_n_, arg.layernorm_e_grid_desc_m_n_,
arg.h_grid_desc_m_n_, arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_grid_desc_m_nblock_, arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_, arg.layernorm_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_, arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_, arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
NBlockClusterLength, NBlockClusterLength,
arg.epsilon_, arg.epsilon_,
arg.h_element_op_); arg.h_element_op_);
return avg_time; return avg_time;
}; };
...@@ -814,14 +801,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -814,14 +801,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
// workspace for welford intermediate mean // workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64; workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
// workspace for welford intermediate variance // workspace for welford intermediate variance
workspace_size += gemm_welford_size * sizeof(VarDataType) + 64; workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
// workspace for welford intermediate count // workspace for welford intermediate count
workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64; workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64;
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);
return (workspace_size); return (workspace_size);
}; };
...@@ -836,20 +826,27 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -836,20 +826,27 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// setup buffer used for intermediate welford mean // setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_); pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = gemm_welford_size * sizeof(MeanDataType); index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford variance // setup buffer used for intermediate welford variance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz; pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = gemm_welford_size * sizeof(VarDataType); index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count // setup buffer used for intermediate welford count
pArg_->p_workspace_count_ = pArg_->p_workspace_count_ =
reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz; reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
if constexpr(!is_same_v<EMeanVarDataType, HDataType>)
pArg_->p_workspace_e_grid_ =
reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
else
pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
}; };
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
......
...@@ -36,9 +36,7 @@ template <typename ABDataType, ...@@ -36,9 +36,7 @@ template <typename ABDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
typename EDataType, typename EMeanVarDataType,
typename MeanDataType,
typename VarDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -329,7 +327,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -329,7 +327,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
{ {
return false; return false;
} }
...@@ -370,9 +368,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -370,9 +368,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Run(const ABDataType* __restrict__ p_a_grid, Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EMeanVarDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid, EMeanVarDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid, EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count, int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -825,7 +823,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -825,7 +823,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
EDataType, EMeanVarDataType,
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock), decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1042,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1042,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
MeanDataType, EMeanVarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock), decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1062,7 +1060,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1062,7 +1060,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
VarDataType, EMeanVarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_var_grid_desc_mblock_mperblock_nblock), decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
......
...@@ -19,10 +19,8 @@ ...@@ -19,10 +19,8 @@
namespace ck { namespace ck {
template <typename EDataType, template <typename EMeanVarDataType,
typename HDataType, typename HDataType,
typename MeanDataType,
typename VarDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
...@@ -87,9 +85,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -87,9 +85,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize; static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
__device__ static void Run(const EDataType* __restrict__ p_e_grid, __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid, const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const VarDataType* __restrict__ p_in_welford_var_grid, const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid, const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid, const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid, const BetaDataType* __restrict__ p_beta_grid,
...@@ -176,7 +174,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -176,7 +174,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
// IO // IO
auto threadwise_mean_load_m_nblock = auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<MeanDataType, ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType, ComputeDataType,
MeanVarGridDesc_M_NBlock, MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
...@@ -192,7 +190,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -192,7 +190,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_n_cluster_id)); thread_n_cluster_id));
auto threadwise_var_load_m_nblock = auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<VarDataType, ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType, ComputeDataType,
MeanVarGridDesc_M_NBlock, MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
...@@ -224,7 +222,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -224,7 +222,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_n_cluster_id)); thread_n_cluster_id));
auto threadwise_e_load_m_n = auto threadwise_e_load_m_n =
ThreadwiseTensorSliceTransfer_v2<EDataType, ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
ComputeDataType, ComputeDataType,
decltype(e_grid_desc_m_n), decltype(e_grid_desc_m_n),
decltype(thread_buffer_desc_m_n), decltype(thread_buffer_desc_m_n),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment