Commit 8749678a authored by rocking's avatar rocking
Browse files

Rename F and G to mean and var

parent 9a25afe4
......@@ -24,8 +24,8 @@ template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename FDataType,
typename GDataType,
typename MeanDataType,
typename VarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
......@@ -33,8 +33,8 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename FGridDescriptor_MBlock_MPerBlock_NBlock,
typename GGridDescriptor_MBlock_MPerBlock_NBlock,
typename MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -46,8 +46,8 @@ __global__ void
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
FDataType* __restrict__ p_f_grid,
GDataType* __restrict__ p_g_grid,
MeanDataType* __restrict__ p_mean_grid,
VarDataType* __restrict__ p_var_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
......@@ -57,8 +57,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const FGridDescriptor_MBlock_MPerBlock_NBlock f_grid_desc_mblock_mperblock_nblock,
const GGridDescriptor_MBlock_MPerBlock_NBlock g_grid_desc_mblock_mperblock_nblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock mean_grid_desc_mblock_mperblock_nblock,
const VarGridDescriptor_MBlock_MPerBlock_NBlock var_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -68,8 +68,8 @@ __global__ void
p_b_grid,
p_ds_grid,
p_e_grid,
p_f_grid,
p_g_grid,
p_mean_grid,
p_var_grid,
p_shared,
a_element_op,
b_element_op,
......@@ -78,16 +78,16 @@ __global__ void
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
f_grid_desc_mblock_mperblock_nblock,
g_grid_desc_mblock_mperblock_nblock,
mean_grid_desc_mblock_mperblock_nblock,
var_grid_desc_mblock_mperblock_nblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = p_f_grid;
ignore = p_g_grid;
ignore = p_mean_grid;
ignore = p_var_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
......@@ -95,8 +95,8 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = f_grid_desc_mblock_mperblock_nblock;
ignore = g_grid_desc_mblock_mperblock_nblock;
ignore = mean_grid_desc_mblock_mperblock_nblock;
ignore = var_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map;
#endif
}
......@@ -186,8 +186,8 @@ template <typename ALayout,
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using FDataType = CShuffleDataType;
using GDataType = CShuffleDataType;
using MeanDataType = CShuffleDataType;
using VarDataType = CShuffleDataType;
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -268,8 +268,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using FGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using VarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm
......@@ -279,8 +279,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleDataType,
DsDataType,
EDataType,
FDataType,
GDataType,
MeanDataType,
VarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
......@@ -289,8 +289,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
FGridDesc_M_N,
GGridDesc_M_N,
MeanGridDesc_M_N,
VarGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -328,7 +328,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, FDataType, GDataType>;
GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, MeanDataType, VarDataType>;
// Argument
struct Argument : public BaseArgument
......@@ -354,15 +354,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_f_grid_{nullptr},
p_g_grid_{nullptr},
p_mean_grid_{nullptr},
p_var_grid_{nullptr},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
f_grid_desc_m_n_{},
g_grid_desc_m_n_{},
mean_grid_desc_m_n_{},
var_grid_desc_m_n_{},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
......@@ -371,14 +371,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}
{
f_grid_desc_m_n_ =
mean_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
g_grid_desc_m_n_ =
var_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
int welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_f_grid_, sizeof(FDataType) * welford_size));
hip_check_error(hipMalloc(&p_g_grid_, sizeof(GDataType) * welford_size));
hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * welford_size));
hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * welford_size));
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -398,8 +398,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
f_grid_desc_m_n_,
g_grid_desc_m_n_,
mean_grid_desc_m_n_,
var_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
......@@ -410,11 +410,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
f_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(f_grid_desc_m_n_);
mean_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
mean_grid_desc_m_n_);
g_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(g_grid_desc_m_n_);
var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
var_grid_desc_m_n_);
}
// TODO - H
......@@ -436,8 +438,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
FDataType* p_f_grid_; // mean
GDataType* p_g_grid_; // variance * count
MeanDataType* p_mean_grid_; // mean
VarDataType* p_var_grid_; // variance * count
HDataType* p_h_grid_;
// tensor descriptors for problem definiton
......@@ -445,8 +447,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
FGridDesc_M_N f_grid_desc_m_n_;
GGridDesc_M_N g_grid_desc_m_n_;
MeanGridDesc_M_N mean_grid_desc_m_n_;
VarGridDesc_M_N var_grid_desc_m_n_;
HGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
......@@ -456,10 +458,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock
f_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock
g_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock
var_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
......@@ -486,8 +488,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.f_grid_desc_m_n_,
arg.g_grid_desc_m_n_,
arg.mean_grid_desc_m_n_,
arg.var_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
......@@ -508,8 +510,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
FDataType,
GDataType,
MeanDataType,
VarDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
......@@ -517,8 +519,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::FGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::GGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
......@@ -526,8 +528,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
HDataType,
FDataType,
GDataType>;
MeanDataType,
VarDataType>;
avg_time +=
launch_and_time_kernel(stream_config,
......@@ -539,8 +541,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_f_grid_,
arg.p_g_grid_,
arg.p_mean_grid_,
arg.p_var_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
......@@ -548,8 +550,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.f_grid_desc_mblock_mperblock_nblock_,
arg.g_grid_desc_mblock_mperblock_nblock_,
arg.mean_grid_desc_mblock_mperblock_nblock_,
arg.var_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config,
......@@ -558,8 +560,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_f_grid_,
arg.p_g_grid_,
arg.p_mean_grid_,
arg.p_var_grid_,
arg.p_h_grid_);
return avg_time;
......
......@@ -37,8 +37,8 @@ template <typename ABDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename FDataType,
typename GDataType,
typename MeanDataType,
typename VarDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
......@@ -47,8 +47,8 @@ template <typename ABDataType,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename FGridDesc_M_N,
typename GGridDesc_M_N,
typename MeanGridDesc_M_N,
typename VarGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -242,10 +242,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{});
}
// TODO - MakeFGGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1);
......@@ -271,12 +271,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const FGridDesc_M_N& f_grid_desc_m_n,
const GGridDesc_M_N& g_grid_desc_m_n,
const MeanGridDesc_M_N& mean_grid_desc_m_n,
const VarGridDesc_M_N& var_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
......@@ -289,9 +290,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
M == f_grid_desc_m_n.GetLength(I0) && M == g_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == f_grid_desc_m_n.GetLength(I1) &&
N / NPerBlock == g_grid_desc_m_n.GetLength(I1)))
M == mean_grid_desc_m_n.GetLength(I0) && M == var_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_grid_desc_m_n.GetLength(I1) &&
N / NPerBlock == var_grid_desc_m_n.GetLength(I1)))
{
return false;
}
......@@ -355,10 +356,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using FGridDescriptor_MBlock_MPerBlock_NBlock =
remove_cvref_t<decltype(MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(FGridDesc_M_N{}))>;
using GGridDescriptor_MBlock_MPerBlock_NBlock =
remove_cvref_t<decltype(MakeFGGridDescriptor_MBlock_MPerBlock_NBlock(GGridDesc_M_N{}))>;
using MeanGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(MeanGridDesc_M_N{}))>;
using VarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(VarGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......@@ -376,8 +377,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
FDataType* __restrict__ p_f_grid,
GDataType* __restrict__ p_g_grid,
MeanDataType* __restrict__ p_mean_grid,
VarDataType* __restrict__ p_var_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
......@@ -388,8 +389,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const FGridDescriptor_MBlock_MPerBlock_NBlock& f_grid_desc_mblock_mperblock_nblock,
const GGridDescriptor_MBlock_MPerBlock_NBlock& g_grid_desc_mblock_mperblock_nblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock& mean_grid_desc_mblock_mperblock_nblock,
const VarGridDescriptor_MBlock_MPerBlock_NBlock& var_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -409,11 +410,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto f_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_grid, f_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_grid, mean_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto g_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_g_grid, g_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_var_grid, var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
......@@ -989,11 +990,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(mreduce_per_thread % FGTransferScalarPerVector == 0);
auto f_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
FDataType,
MeanDataType,
decltype(thread_welford_desc_I_m_I),
decltype(f_grid_desc_mblock_mperblock_nblock),
decltype(mean_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread, 1>,
Sequence<0, 1, 2>,
......@@ -1001,18 +1002,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
FGTransferScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
false>{f_grid_desc_mblock_mperblock_nblock,
false>{mean_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
c_reduce_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto g_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
GDataType,
VarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(g_grid_desc_mblock_mperblock_nblock),
decltype(var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread, 1>,
Sequence<0, 1, 2>,
......@@ -1020,24 +1021,24 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
FGTransferScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
false>{g_grid_desc_mblock_mperblock_nblock,
false>{var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
c_reduce_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
f_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
f_grid_desc_mblock_mperblock_nblock,
f_grid_buf);
mean_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
g_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
g_grid_desc_mblock_mperblock_nblock,
g_grid_buf);
var_grid_desc_mblock_mperblock_nblock,
var_grid_buf);
});
} // shuffle C + Ds + welford + write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment