Commit 8749678a authored by rocking
Browse files

Rename F and G to mean and var

parent 9a25afe4
...@@ -24,8 +24,8 @@ template <typename GridwiseGemm, ...@@ -24,8 +24,8 @@ template <typename GridwiseGemm,
typename ABDataType, typename ABDataType,
typename DsPointer, typename DsPointer,
typename EDataType, typename EDataType,
typename FDataType, typename MeanDataType,
typename GDataType, typename VarDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -33,8 +33,8 @@ template <typename GridwiseGemm, ...@@ -33,8 +33,8 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename FGridDescriptor_MBlock_MPerBlock_NBlock, typename MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GGridDescriptor_MBlock_MPerBlock_NBlock, typename VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap, typename Block2ETileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -46,8 +46,8 @@ __global__ void ...@@ -46,8 +46,8 @@ __global__ void
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
FDataType* __restrict__ p_f_grid, MeanDataType* __restrict__ p_mean_grid,
GDataType* __restrict__ p_g_grid, VarDataType* __restrict__ p_var_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
...@@ -57,8 +57,8 @@ __global__ void ...@@ -57,8 +57,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const FGridDescriptor_MBlock_MPerBlock_NBlock f_grid_desc_mblock_mperblock_nblock, const MeanGridDescriptor_MBlock_MPerBlock_NBlock mean_grid_desc_mblock_mperblock_nblock,
const GGridDescriptor_MBlock_MPerBlock_NBlock g_grid_desc_mblock_mperblock_nblock, const VarGridDescriptor_MBlock_MPerBlock_NBlock var_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -68,8 +68,8 @@ __global__ void ...@@ -68,8 +68,8 @@ __global__ void
p_b_grid, p_b_grid,
p_ds_grid, p_ds_grid,
p_e_grid, p_e_grid,
p_f_grid, p_mean_grid,
p_g_grid, p_var_grid,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -78,16 +78,16 @@ __global__ void ...@@ -78,16 +78,16 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
f_grid_desc_mblock_mperblock_nblock, mean_grid_desc_mblock_mperblock_nblock,
g_grid_desc_mblock_mperblock_nblock, var_grid_desc_mblock_mperblock_nblock,
block_2_etile_map); block_2_etile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = p_f_grid; ignore = p_mean_grid;
ignore = p_g_grid; ignore = p_var_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = cde_element_op; ignore = cde_element_op;
...@@ -95,8 +95,8 @@ __global__ void ...@@ -95,8 +95,8 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = f_grid_desc_mblock_mperblock_nblock; ignore = mean_grid_desc_mblock_mperblock_nblock;
ignore = g_grid_desc_mblock_mperblock_nblock; ignore = var_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map; ignore = block_2_etile_map;
#endif #endif
} }
...@@ -185,9 +185,9 @@ template <typename ALayout, ...@@ -185,9 +185,9 @@ template <typename ALayout,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{ {
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle; using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using FDataType = CShuffleDataType; using MeanDataType = CShuffleDataType;
using GDataType = CShuffleDataType; using VarDataType = CShuffleDataType;
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -264,13 +264,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -264,13 +264,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using FGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using MeanGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using VarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1)); using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
...@@ -279,8 +279,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -279,8 +279,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleDataType, CShuffleDataType,
DsDataType, DsDataType,
EDataType, EDataType,
FDataType, MeanDataType,
GDataType, VarDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
...@@ -289,8 +289,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -289,8 +289,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K, BGridDesc_N_K,
DsGridDesc_M_N, DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
FGridDesc_M_N, MeanGridDesc_M_N,
GGridDesc_M_N, VarGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -328,7 +328,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -328,7 +328,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm = using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, FDataType, GDataType>; GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, MeanDataType, VarDataType>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -354,15 +354,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -354,15 +354,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_f_grid_{nullptr}, p_mean_grid_{nullptr},
p_g_grid_{nullptr}, p_var_grid_{nullptr},
p_h_grid_{static_cast<HDataType*>(p_h_grid)}, p_h_grid_{static_cast<HDataType*>(p_h_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
f_grid_desc_m_n_{}, mean_grid_desc_m_n_{},
g_grid_desc_m_n_{}, var_grid_desc_m_n_{},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -371,14 +371,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -371,14 +371,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
h_element_op_{h_element_op}, h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)} blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}
{ {
f_grid_desc_m_n_ = mean_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_); DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
g_grid_desc_m_n_ = var_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_); DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
int welford_size = MRaw * blkGroupSize_; int welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_f_grid_, sizeof(FDataType) * welford_size)); hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * welford_size));
hip_check_error(hipMalloc(&p_g_grid_, sizeof(GDataType) * welford_size)); hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * welford_size));
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -398,8 +398,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -398,8 +398,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
f_grid_desc_m_n_, mean_grid_desc_m_n_,
g_grid_desc_m_n_, var_grid_desc_m_n_,
block_2_etile_map_)) block_2_etile_map_))
{ {
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
...@@ -410,11 +410,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -410,11 +410,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
f_grid_desc_mblock_mperblock_nblock_ = mean_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(f_grid_desc_m_n_); GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
mean_grid_desc_m_n_);
g_grid_desc_mblock_mperblock_nblock_ = var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(g_grid_desc_m_n_); GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
var_grid_desc_m_n_);
} }
// TODO - H // TODO - H
...@@ -436,8 +438,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -436,8 +438,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
FDataType* p_f_grid_; // mean MeanDataType* p_mean_grid_; // mean
GDataType* p_g_grid_; // variance * count VarDataType* p_var_grid_; // variance * count
HDataType* p_h_grid_; HDataType* p_h_grid_;
// tensor descriptors for problem definition // tensor descriptors for problem definition
...@@ -445,8 +447,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -445,8 +447,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
FGridDesc_M_N f_grid_desc_m_n_; MeanGridDesc_M_N mean_grid_desc_m_n_;
GGridDesc_M_N g_grid_desc_m_n_; VarGridDesc_M_N var_grid_desc_m_n_;
HGridDesc_M_N h_grid_desc_m_n_; HGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -456,10 +458,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -456,10 +458,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock
f_grid_desc_mblock_mperblock_nblock_; mean_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock
g_grid_desc_mblock_mperblock_nblock_; var_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
...@@ -486,8 +488,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -486,8 +488,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.f_grid_desc_m_n_, arg.mean_grid_desc_m_n_,
arg.g_grid_desc_m_n_, arg.var_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
...@@ -508,8 +510,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -508,8 +510,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
FDataType, MeanDataType,
GDataType, VarDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
...@@ -517,8 +519,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -517,8 +519,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1, typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
...@@ -526,8 +528,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -526,8 +528,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm, kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType, EDataType,
HDataType, HDataType,
FDataType, MeanDataType,
GDataType>; VarDataType>;
avg_time += avg_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -539,8 +541,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -539,8 +541,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_, arg.p_b_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, arg.p_e_grid_,
arg.p_f_grid_, arg.p_mean_grid_,
arg.p_g_grid_, arg.p_var_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
...@@ -548,8 +550,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -548,8 +550,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.f_grid_desc_mblock_mperblock_nblock_, arg.mean_grid_desc_mblock_mperblock_nblock_,
arg.g_grid_desc_mblock_mperblock_nblock_, arg.var_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -558,8 +560,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -558,8 +560,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_e_grid_, arg.p_e_grid_,
arg.p_f_grid_, arg.p_mean_grid_,
arg.p_g_grid_, arg.p_var_grid_,
arg.p_h_grid_); arg.p_h_grid_);
return avg_time; return avg_time;
......
...@@ -37,8 +37,8 @@ template <typename ABDataType, ...@@ -37,8 +37,8 @@ template <typename ABDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
typename EDataType, typename EDataType,
typename FDataType, typename MeanDataType,
typename GDataType, typename VarDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -47,8 +47,8 @@ template <typename ABDataType, ...@@ -47,8 +47,8 @@ template <typename ABDataType,
typename BGridDesc_N_K, typename BGridDesc_N_K,
typename DsGridDesc_M_N, typename DsGridDesc_M_N,
typename EGridDesc_M_N, typename EGridDesc_M_N,
typename FGridDesc_M_N, typename MeanGridDesc_M_N,
typename GGridDesc_M_N, typename VarGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -242,10 +242,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -242,10 +242,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// TODO - MakeFGGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock // TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N> template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n) MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{ {
const auto M = grid_desc_m_n.GetLength(I0); const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1); const auto NBlock = grid_desc_m_n.GetLength(I1);
...@@ -271,13 +271,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -271,13 +271,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap> template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, __host__ __device__ static constexpr bool
const BGridDesc_N_K& b_grid_desc_n_k, CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const DsGridDesc_M_N& ds_grid_desc_m_n, const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n, const DsGridDesc_M_N& ds_grid_desc_m_n,
const FGridDesc_M_N& f_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const GGridDesc_M_N& g_grid_desc_m_n, const MeanGridDesc_M_N& mean_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map) const VarGridDesc_M_N& var_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -289,9 +290,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -289,9 +290,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
M == f_grid_desc_m_n.GetLength(I0) && M == g_grid_desc_m_n.GetLength(I0) && M == mean_grid_desc_m_n.GetLength(I0) && M == var_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == f_grid_desc_m_n.GetLength(I1) && N / NPerBlock == mean_grid_desc_m_n.GetLength(I1) &&
N / NPerBlock == g_grid_desc_m_n.GetLength(I1))) N / NPerBlock == var_grid_desc_m_n.GetLength(I1)))
{ {
return false; return false;
} }
...@@ -353,12 +354,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -353,12 +354,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 = using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using FGridDescriptor_MBlock_MPerBlock_NBlock = using MeanGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
remove_cvref_t<decltype(MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(FGridDesc_M_N{}))>; MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(MeanGridDesc_M_N{}))>;
using GGridDescriptor_MBlock_MPerBlock_NBlock = using VarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
remove_cvref_t<decltype(MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(GGridDesc_M_N{}))>; MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(VarGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
...@@ -376,8 +377,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -376,8 +377,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
FDataType* __restrict__ p_f_grid, MeanDataType* __restrict__ p_mean_grid,
GDataType* __restrict__ p_g_grid, VarDataType* __restrict__ p_var_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
...@@ -388,8 +389,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -388,8 +389,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const FGridDescriptor_MBlock_MPerBlock_NBlock& f_grid_desc_mblock_mperblock_nblock, const MeanGridDescriptor_MBlock_MPerBlock_NBlock& mean_grid_desc_mblock_mperblock_nblock,
const GGridDescriptor_MBlock_MPerBlock_NBlock& g_grid_desc_mblock_mperblock_nblock, const VarGridDescriptor_MBlock_MPerBlock_NBlock& var_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -409,11 +410,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -409,11 +410,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto f_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_grid, f_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); p_mean_grid, mean_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto g_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_g_grid, g_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); p_var_grid, var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -989,11 +990,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -989,11 +990,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(mreduce_per_thread % FGTransferScalarPerVector == 0); static_assert(mreduce_per_thread % FGTransferScalarPerVector == 0);
auto f_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
FDataType, MeanDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(f_grid_desc_mblock_mperblock_nblock), decltype(mean_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread, 1>, Sequence<1, mreduce_per_thread, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
...@@ -1001,18 +1002,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1001,18 +1002,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
FGTransferScalarPerVector, FGTransferScalarPerVector,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{f_grid_desc_mblock_mperblock_nblock, false>{mean_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
c_reduce_thread_data_idx_begin[I0], // mperblock c_reduce_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
auto g_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
GDataType, VarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(g_grid_desc_mblock_mperblock_nblock), decltype(var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread, 1>, Sequence<1, mreduce_per_thread, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
...@@ -1020,24 +1021,24 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1020,24 +1021,24 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
FGTransferScalarPerVector, FGTransferScalarPerVector,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{g_grid_desc_mblock_mperblock_nblock, false>{var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
c_reduce_thread_data_idx_begin[I0], // mperblock c_reduce_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
f_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
mean_thread_buf, mean_thread_buf,
f_grid_desc_mblock_mperblock_nblock, mean_grid_desc_mblock_mperblock_nblock,
f_grid_buf); mean_grid_buf);
g_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
var_thread_buf, var_thread_buf,
g_grid_desc_mblock_mperblock_nblock, var_grid_desc_mblock_mperblock_nblock,
g_grid_buf); var_grid_buf);
}); });
} // shuffle C + Ds + welford + write out } // shuffle C + Ds + welford + write out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment