Commit 694057a7 authored by rocking

Rewrite the device interface and rename some var

parent 982c85a3
@@ -38,7 +38,8 @@ using CShuffleDataType = F32;
 using D0DataType = F16;
 using D1DataType = F16;
 using DsDataType = ck::Tuple<D0DataType, D1DataType>;
-using EDataType = F16;
+using GammaDataType = F16;
+using BetaDataType = F16;
 using HDataType = F16;
 // Layout
@@ -59,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 // clang-format off
 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
-//######| ALayout| BLayout| DsLayout| ELayout| HLayout| AData| BData| AccData| CShuffle| DsData| EData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| ReduceThreadTransfer| |
-//######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer|
-//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock|
-< ALayout, BLayout, DsLayout, ELayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4>;
+//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
+//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
+//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
 // clang-format on
 auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
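A hedged reading of the new trailing template arguments in the instance above, matched by position against the template parameters this commit adds to DeviceGemmMultipleDLayernorm_Xdl_CShuffle (the mapping is inferred from argument order and is not stated in the commit itself):

// S<64, 4> -> PostShuffleThreadClusterSize_M_N
// 4        -> PostShuffleScalarPerVector
// S<8, 32> -> LayernormThreadClusterSize_M_N
// S<1, 8>  -> LayernormThreadSliceSize_M_N
// 1        -> LayernormESrcHDstVectorDim
// 8        -> LayernormESrcVectorSize
// 8        -> LayernormHDstVectorSize
// 8        -> LayernormGammaSrcVectorSize
// 8        -> LayernormBetaSrcVectorSize
// 1        -> LayernormMeanVarSrcDstVectorSize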
@@ -96,33 +97,39 @@ int main()
     ck::index_t StrideB = 1024;
     ck::index_t StrideD0 = 0;
     ck::index_t StrideD1 = 1024;
-    ck::index_t StrideE = 1024;
     ck::index_t StrideH = 1024;
-    // TODO - gamma and beta
+    float epsilon = 1e-5;
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
     Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
     Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
-    Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
+    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
+    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
     Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
     a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
     b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
     d0_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
     d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
+    gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
+    beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
     DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize());
     DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
+    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
     DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
     d0_device_buf.ToDevice(d0_n.mData.data());
     d1_device_buf.ToDevice(d1_m_n.mData.data());
+    gamma_device_buf.ToDevice(gamma_n.mData.data());
+    beta_device_buf.ToDevice(beta_n.mData.data());
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
@@ -135,7 +142,8 @@ int main()
         device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                b_device_buf.GetDeviceBuffer(),
                                {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
-                               e_device_buf.GetDeviceBuffer(),
+                               gamma_device_buf.GetDeviceBuffer(),
+                               beta_device_buf.GetDeviceBuffer(),
                                h_device_buf.GetDeviceBuffer(),
                                M,
                                N,
@@ -143,8 +151,8 @@ int main()
                               StrideA,
                               StrideB,
                               {StrideD0, StrideD1},
-                              StrideE,
                               StrideH,
+                              epsilon,
                               a_element_op,
                               b_element_op,
                               cde_element_op,
......
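For orientation, a minimal sketch of the full host-side call after this commit, assembled from the example lines above and the MakeArgument signature shown later in this diff. The `auto argument =` binding and the K argument (which falls in lines the diff does not display) are assumptions of the sketch, not additional committed code:

auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                       b_device_buf.GetDeviceBuffer(),
                                       {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
                                       gamma_device_buf.GetDeviceBuffer(),   // new: per-column gamma
                                       beta_device_buf.GetDeviceBuffer(),    // new: per-column beta
                                       h_device_buf.GetDeviceBuffer(),
                                       M, N, K,
                                       StrideA, StrideB,
                                       {StrideD0, StrideD1},
                                       StrideH,                              // StrideE is gone; E is internal now
                                       epsilon,                              // new: layernorm epsilon
                                       a_element_op, b_element_op, cde_element_op, h_element_op);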
@@ -13,7 +13,8 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
-#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
+// #include
+// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "device_base.hpp"
@@ -101,22 +102,23 @@
 #endif
 }
-template <typename GridwiseWelfordLayernorm,
-          typename XDataType,
-          typename YDataType,
-          typename MeanDataType,
-          typename VarDataType>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-    kernel_welford_layernorm2d_second_half(const XDataType* __restrict__ p_x_grid,
-                                           const MeanDataType* __restrict__ p_mean_grid,
-                                           const VarDataType* __restrict__ p_var_grid,
-                                           YDataType* __restrict__ p_y_grid)
-{
-    GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid);
-}
+// template <typename GridwiseWelfordLayernorm,
+//           typename EDataType,
+//           typename HDataType,
+//           typename MeanDataType,
+//           typename VarDataType>
+// __global__ void
+// #if CK_USE_LAUNCH_BOUNDS
+// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+// #endif
+// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
+//                                        const MeanDataType* __restrict__ p_mean_grid,
+//                                        const VarDataType* __restrict__ p_var_grid,
+//                                        HDataType* __restrict__ p_y_grid,
+//                                        index_t blkgroup_size)
+// {
+//     GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
+// }
 } // namespace ck
@@ -139,14 +141,14 @@ namespace device {
 template <typename ALayout,
           typename BLayout,
           typename DsLayout,
-          typename ELayout,
           typename HLayout,
           typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename CShuffleDataType,
           typename DsDataType,
-          typename EDataType,
+          typename GammaDataType,
+          typename BetaDataType,
           typename HDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -180,12 +182,22 @@ template <typename ALayout,
           bool BBlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
-          typename ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock,
-          index_t ReduceThreadTransferScalarPerVector_NPerBlock,
+          typename PostShuffleThreadClusterSize_M_N,
+          index_t PostShuffleScalarPerVector,
+          typename LayernormThreadClusterSize_M_N,
+          typename LayernormThreadSliceSize_M_N,
+          index_t LayernormESrcHDstVectorDim,
+          index_t LayernormESrcVectorSize,
+          index_t LayernormHDstVectorSize,
+          index_t LayernormGammaSrcVectorSize,
+          index_t LayernormBetaSrcVectorSize,
+          index_t LayernormMeanVarSrcDstVectorSize,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
 {
     using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
+    using ELayout = HLayout;
+    using EDataType = HDataType;
     using MeanDataType = CShuffleDataType;
     using VarDataType = CShuffleDataType;
@@ -264,13 +276,64 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                               Number<NumDTensor>{});
     }
-    using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
-    using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
-    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
-    using MeanGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
-    using VarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
-    using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
+    static auto MakeDescriptor_M(index_t MRaw)
+    {
+        const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
+
+        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto MPad = M - MRaw;
+
+        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                     GemmSpec == GemmSpecialization::MNPadding ||
+                     GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad M
+            return transform_tensor_descriptor(grid_desc_mraw,
+                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                               make_tuple(Sequence<0>{}),
+                                               make_tuple(Sequence<0>{}));
+        }
+        else
+        {
+            // not pad N
+            return grid_desc_mraw;
+        }
+    };
+
+    static auto MakeDescriptor_N(index_t NRaw)
+    {
+        const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
+
+        const auto N    = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+        const auto NPad = N - NRaw;
+
+        if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                     GemmSpec == GemmSpecialization::MNPadding ||
+                     GemmSpec == GemmSpecialization::NKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad N
+            return transform_tensor_descriptor(grid_desc_nraw,
+                                               make_tuple(make_right_pad_transform(NRaw, NPad)),
+                                               make_tuple(Sequence<0>{}),
+                                               make_tuple(Sequence<0>{}));
+        }
+        else
+        {
+            // not pad N
+            return grid_desc_nraw;
+        }
+    };
+    using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
+    using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
+    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
+    using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
+    using MeanVarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
+    using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
+    using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
+    using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
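A worked example of the padding arithmetic in MakeDescriptor_M / MakeDescriptor_N above (hypothetical sizes chosen only to illustrate the integer_divide_ceil rounding):

// NRaw = 1000, NPerBlock = 128
// N    = integer_divide_ceil(1000, 128) * 128 = 8 * 128 = 1024
// NPad = N - NRaw = 24
// With an N-padding GemmSpecialization the 1-D gamma/beta descriptor is right-padded by 24 elements;
// otherwise the raw descriptor of length 1000 is returned unchanged.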
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
@@ -289,8 +352,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BGridDesc_N_K,
         DsGridDesc_M_N,
         EGridDesc_M_N,
-        MeanGridDesc_M_N,
-        VarGridDesc_M_N,
+        MeanVarGridDesc_M_N,
+        MeanVarGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -320,15 +383,34 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BBlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
-        ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock,
-        ReduceThreadTransferScalarPerVector_NPerBlock,
+        PostShuffleThreadClusterSize_M_N,
+        PostShuffleScalarPerVector,
         1,
         LoopSched>;
     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
-    using GridwiseWelfordLayernorm =
-        GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, MeanDataType, VarDataType>;
+    // using GridwiseWelfordLayernorm =
+    //     GridwiseWelfordSecondHalfLayernorm2d<EDataType,
+    //                                          HDataType,
+    //                                          MeanDataType,
+    //                                          VarDataType,
+    //                                          AccDataType,
+    //                                          HGridDesc_M_N,
+    //                                          MeanGridDesc_M_N,
+    //                                          GammaBetaGridDesc_N,
+    //                                          MeanVarGridDesc_M,
+    //                                          BlockSize,
+    //                                          LayernormMThreadClusterSize,
+    //                                          LayernormNThreadClusterSize,
+    //                                          LayernormMThreadSliceSize,
+    //                                          LayernormNThreadSliceSize,
+    //                                          LayernormESrcHDstVectorDim,
+    //                                          LayernormESrcVectorSize,
+    //                                          LayernormHDstVectorSize,
+    //                                          LayernormGammaSrcVectorSize,
+    //                                          LayernormBetaSrcVectorSize,
+    //                                          LayernormMeanVarSrcDstVectorSize>;
     // Argument
     struct Argument : public BaseArgument
@@ -336,7 +418,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         Argument(const void* p_a_grid,
                  const void* p_b_grid,
                  std::array<const void*, NumDTensor> p_ds_grid,
-                 void* p_e_grid,
+                 const void* p_gamma_grid,
+                 const void* p_beta_grid,
                  void* p_h_grid,
                  index_t MRaw,
                  index_t NRaw,
@@ -344,41 +427,48 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                  index_t StrideA,
                  index_t StrideB,
                  std::array<index_t, NumDTensor> StrideDs,
-                 index_t StrideE,
                  index_t StrideH,
+                 AccDataType epsilon,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CDEElementwiseOperation cde_element_op,
-                 BElementwiseOperation h_element_op)
+                 HElementwiseOperation h_element_op)
             : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
               p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
               p_ds_grid_{},
-              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
+              p_e_grid_{nullptr},
              p_mean_grid_{nullptr},
              p_var_grid_{nullptr},
+              p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
+              p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
              p_h_grid_{static_cast<HDataType*>(p_h_grid)},
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
              ds_grid_desc_m_n_{},
-              e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
+              e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
              mean_grid_desc_m_n_{},
              var_grid_desc_m_n_{},
+              gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
+              beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
              h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              h_element_op_{h_element_op},
-              blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}
+              blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)},
+              epsilon_{epsilon}
         {
             mean_grid_desc_m_n_ =
                 DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
             var_grid_desc_m_n_ =
                 DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
-            int welford_size = MRaw * blkGroupSize_;
-            hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * welford_size));
-            hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * welford_size));
+            hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
+
+            int gemm_welford_size = MRaw * blkGroupSize_;
+            hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
+            hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * gemm_welford_size));
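A hedged arithmetic check of the workspace sizes allocated above (hypothetical problem size, not taken from the commit):

// MRaw = 3840, NRaw = 4096, NPerBlock = 128
// blkGroupSize_     = integer_divide_ceil(4096, 128) = 32
// E workspace       = 3840 * 4096 elements of EDataType (the un-normalized GEMM+D output)
// gemm_welford_size = 3840 * 32 partial mean/variance elements, one per (row, N-block) pair,
//                     to be reduced by the (currently commented-out) second-half layernorm kernel.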
             // populate pointer, desc for Ds
             static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -440,6 +530,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         EDataType* p_e_grid_;
         MeanDataType* p_mean_grid_; // mean
         VarDataType* p_var_grid_; // variance * count
+        const GammaDataType* p_gamma_grid_;
+        const BetaDataType* p_beta_grid_;
         HDataType* p_h_grid_;
         // tensor descriptors for problem definiton
@@ -447,8 +539,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BGridDesc_N_K b_grid_desc_n_k_;
         DsGridDesc_M_N ds_grid_desc_m_n_;
         EGridDesc_M_N e_grid_desc_m_n_;
-        MeanGridDesc_M_N mean_grid_desc_m_n_;
-        VarGridDesc_M_N var_grid_desc_m_n_;
+        MeanVarGridDesc_M_N mean_grid_desc_m_n_;
+        MeanVarGridDesc_M_N var_grid_desc_m_n_;
+        GammaBetaGridDesc_N gamma_grid_desc_n_;
+        GammaBetaGridDesc_N beta_grid_desc_n_;
         HGridDesc_M_N h_grid_desc_m_n_;
         // tensor descriptors for block/thread-wise copy
@@ -473,6 +567,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         HElementwiseOperation h_element_op_;
         int blkGroupSize_;
+        AccDataType epsilon_;
     };
     // Invoker
@@ -524,12 +619,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                     typename GridwiseGemm::DefaultBlock2ETileMap,
                     has_main_loop>;
-            const auto kernel_welford_layernorm =
-                kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
-                                                       EDataType,
-                                                       HDataType,
-                                                       MeanDataType,
-                                                       VarDataType>;
+            // const auto kernel_welford_layernorm =
+            //     kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
+            //                                            EDataType,
+            //                                            HDataType,
+            //                                            MeanDataType,
+            //                                            VarDataType>;
             avg_time +=
                 launch_and_time_kernel(stream_config,
@@ -554,15 +649,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                        arg.var_grid_desc_mblock_mperblock_nblock_,
                                        arg.block_2_etile_map_);
-            avg_time += launch_and_time_kernel(stream_config,
-                                               kernel_welford_layernorm,
-                                               dim3(grid_size),
-                                               dim3(BlockSize),
-                                               0,
-                                               arg.p_e_grid_,
-                                               arg.p_mean_grid_,
-                                               arg.p_var_grid_,
-                                               arg.p_h_grid_);
+            // avg_time += launch_and_time_kernel(stream_config,
+            //                                    kernel_welford_layernorm,
+            //                                    dim3(grid_size),
+            //                                    dim3(BlockSize),
+            //                                    0,
+            //                                    arg.p_e_grid_,
+            //                                    arg.p_mean_grid_,
+            //                                    arg.p_var_grid_,
+            //                                    arg.p_h_grid_,
+            //                                    arg.blkGroupSize_);
             return avg_time;
         };
@@ -604,7 +700,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     static auto MakeArgument(const void* p_a,
                              const void* p_b,
                              std::array<const void*, NumDTensor> p_ds,
-                             void* p_e,
+                             const void* p_gamma,
+                             const void* p_beta,
                              void* p_h,
                              index_t MRaw,
                              index_t NRaw,
@@ -612,8 +709,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                              index_t StrideA,
                              index_t StrideB,
                              std::array<index_t, NumDTensor> StrideDs,
-                             index_t StrideE,
                              index_t StrideH,
+                             AccDataType epsilon,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CDEElementwiseOperation cde_element_op,
@@ -622,7 +719,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         return Argument{p_a,
                         p_b,
                         p_ds,
-                        p_e,
+                        p_gamma,
+                        p_beta,
                         p_h,
                         MRaw,
                         NRaw,
@@ -630,8 +728,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                         StrideA,
                         StrideB,
                         StrideDs,
-                        StrideE,
                         StrideH,
+                        epsilon,
                         a_element_op,
                         b_element_op,
                         cde_element_op,
@@ -644,7 +742,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
                                                       std::array<const void*, NumDTensor> p_ds,
-                                                      void* p_e,
+                                                      const void* p_gamma,
+                                                      const void* p_beta,
                                                       void* p_h,
                                                       index_t MRaw,
                                                       index_t NRaw,
@@ -652,8 +751,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                       index_t StrideA,
                                                       index_t StrideB,
                                                       std::array<index_t, NumDTensor> StrideDs,
-                                                      index_t StrideE,
                                                       index_t StrideH,
+                                                      AccDataType epsilon,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CDEElementwiseOperation cde_element_op,
@@ -662,7 +761,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         return std::make_unique<Argument>(p_a,
                                           p_b,
                                           p_ds,
-                                          p_e,
+                                          p_gamma,
+                                          p_beta,
                                           p_h,
                                           MRaw,
                                           NRaw,
@@ -670,8 +770,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                           StrideA,
                                           StrideB,
                                           StrideDs,
-                                          StrideE,
                                           StrideH,
+                                          epsilon,
                                           a_element_op,
                                           b_element_op,
                                           cde_element_op,
......
@@ -78,9 +78,9 @@ template <typename ABDataType,
           index_t BBlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
-          typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
-          index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
-          index_t FGTransferScalarPerVector,
+          typename PostShuffleThreadClusterSize_M_N,
+          index_t PostShuffleScalarPerVector,
+          index_t MeanVarTransferScalarPerVector,
           LoopScheduler LoopSched>
 struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
 {
@@ -604,10 +604,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             static_cast<CShuffleDataType*>(p_shared),
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-        // auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-        //     static_cast<CShuffleDataType*>(p_shared),
-        //     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
             make_tuple(
@@ -711,8 +707,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
             false>{};
-        // LDS c_reduce_block_desc_mperblock_nperblock
-        constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
+        // LDS c_shuffle_block_desc_mperblock_nperblock
+        constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
             make_tuple(
                 make_freeze_transform(I0),
@@ -724,75 +720,79 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
-        static_assert(CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) *
-                              CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) ==
+        static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
+                              PostShuffleThreadClusterSize_M_N::At(I1) ==
                           BlockSize,
                       "wrong!");
         static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
-                              CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) ==
+                              PostShuffleThreadClusterSize_M_N::At(I0) ==
                               0 &&
                           (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
-                                  CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) ==
+                                  PostShuffleThreadClusterSize_M_N::At(I1) ==
                               0,
                       "wrong!");
-        constexpr index_t mreduce_per_thread =
+        constexpr index_t PostShuffleThreadSliceSize_M =
             (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
-            CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0);
-        constexpr index_t nreduce_per_thread =
+            PostShuffleThreadClusterSize_M_N::At(I0);
+        constexpr index_t PostShuffleThreadSliceSize_N =
             (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
-            CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1);
-        constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
-            Sequence<mreduce_per_thread, nreduce_per_thread>{};
+            PostShuffleThreadClusterSize_M_N::At(I1);
+        constexpr auto PostShuffleThreadSliceSize_M_N =
+            Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
-        // VGPR cde_reduce_thread_desc_mperblock_nperblock
-        constexpr auto cde_reduce_thread_desc_mperblock_nperblock =
-            make_naive_tensor_descriptor_packed(
-                make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
+        // VGPR post_shuffle_thread_desc_m_n
+        constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<PostShuffleThreadSliceSize_M>{},
+                       Number<PostShuffleThreadSliceSize_N>{}));
         auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
-            cde_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
+            post_shuffle_thread_desc_m_n.GetElementSpaceSize());
         // To apply D0, D1, ... and Welford.
         // threadwise copy from LDS to VGPR
-        constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
-            CDRThreadTransferClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
-        const auto c_reduce_thread_cluster_idx =
-            c_reduce_thread_cluster_desc.CalculateBottomIndex(
+        constexpr auto post_shuffle_thread_cluster_desc =
+            make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<1, 0>{});
+        const auto post_shuffle_thread_cluster_idx =
+            post_shuffle_thread_cluster_desc.CalculateBottomIndex(
                 make_multi_index(get_thread_local_1d_id()));
-        const auto c_reduce_thread_data_idx_begin =
-            c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
+        const auto post_shuffle_thread_data_idx_begin =
+            post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
         // To apply D0, D1, ... and Welford.
         // Copy c shuffle from LDS back to VGPR
-        auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-            CShuffleDataType,
-            AccDataType,
-            decltype(c_reduce_block_desc_mperblock_nperblock),
-            decltype(cde_reduce_thread_desc_mperblock_nperblock),
-            decltype(c_reduce_thread_lengths_mperblock_nperblock),
-            Sequence<0, 1>,
-            1,
-            CDEReduceThreadTransferScalarPerVector_NPerBlock,
-            1,
-            true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
+        auto post_shuffle_thread_copy_lds_to_vgpr =
+            ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
+                                             AccDataType,
+                                             decltype(c_shuffle_block_desc_mperblock_nperblock),
+                                             decltype(post_shuffle_thread_desc_m_n),
+                                             decltype(PostShuffleThreadSliceSize_M_N),
+                                             Sequence<0, 1>,
+                                             1,
+                                             PostShuffleScalarPerVector,
+                                             1,
+                                             true>{c_shuffle_block_desc_mperblock_nperblock,
+                                                   post_shuffle_thread_data_idx_begin};
         // D0, D1, ..., Dn
-        constexpr auto cde_reduce_thread_desc_I1_mperblock_I1_nperblock =
+        constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
             make_naive_tensor_descriptor_packed(
-                make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
+                make_tuple(I1,
+                           Number<PostShuffleThreadSliceSize_M>{},
+                           I1,
+                           Number<PostShuffleThreadSliceSize_N>{}));
         // FIXME: Decrease usage of VGPR
         // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
         auto ds_thread_buf = generate_tuple(
             [&](auto) {
                 return make_static_buffer<AddressSpaceEnum::Vgpr, CShuffleDataType>(
-                    cde_reduce_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
+                    post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
             },
             Number<NumDTensor>{});
@@ -804,58 +804,65 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                     DDataType,
                     AccDataType,
                     decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
-                    decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock),
-                    Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
+                    decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
+                    Sequence<I1,
+                             PostShuffleThreadSliceSize_M,
+                             I1,
+                             PostShuffleThreadSliceSize_N>,
                     Sequence<0, 1, 2, 3>,
                     3,
-                    CDEReduceThreadTransferScalarPerVector_NPerBlock,
+                    PostShuffleScalarPerVector,
                     1,
-                    true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
-                          make_multi_index(
-                              I0,
-                              m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
-                              I0,
-                              n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
+                    true>(
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
+                    make_multi_index(
+                        I0,
+                        m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
+                        I0,
+                        n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
             },
             Number<NumDTensor>{});
         auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
             AccDataType,
             EDataType,
-            decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock),
+            decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
             decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
             tensor_operation::element_wise::PassThrough,
-            Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
-            Sequence<0, 1, 2, 3>, // DimAccessOrder
-            3, // DstVectorDim
-            CDEReduceThreadTransferScalarPerVector_NPerBlock,
+            Sequence<I1,
+                     PostShuffleThreadSliceSize_M,
+                     I1,
+                     PostShuffleThreadSliceSize_N>, // SliceLengths
+            Sequence<0, 1, 2, 3>, // DimAccessOrder
+            3, // DstVectorDim
+            PostShuffleScalarPerVector,
             InMemoryDataOperationEnum::Set,
             1,
             true>{
             e_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(I0,
-                             m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
+                             m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
                              I0,
-                             n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
+                             n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
             tensor_operation::element_wise::PassThrough{}};
         // Welford
         constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
+            make_tuple(Number<PostShuffleThreadSliceSize_M>{},
+                       Number<PostShuffleThreadSliceSize_N>{}));
-        constexpr auto thread_welford_dst_desc_m =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
+        constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
         using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
                                                     decltype(thread_welford_src_desc_m_k),
                                                     decltype(thread_welford_dst_desc_m)>;
-        using BlockwiseWelford =
-            BlockwiseWelford<AccDataType,
-                             BlockSize,
-                             CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
-                             Sequence<0, 1>,
-                             false>;
+        using BlockwiseWelford = BlockwiseWelford<AccDataType,
+                                                  BlockSize,
+                                                  PostShuffleThreadClusterSize_M_N,
+                                                  Sequence<0, 1>,
+                                                  false>;
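For reference, the standard online Welford update that the ThreadwiseWelford/BlockwiseWelford helpers implement per element (textbook form, written here for orientation rather than copied from CK):

// count += 1;
// delta  = x - mean;
// mean  += delta / count;
// m2    += delta * (x - mean);   // variance = m2 / count; the device op's p_var_grid_ stores "variance * count"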
         constexpr int num_shuffleM =
             MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
@@ -870,14 +877,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
         static_for<0, num_shuffleM, 1>{}([&](auto i) {
             // TODO - padding
-            threadwise_welfords(i).max_count_ = nreduce_per_thread;
+            threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N;
             mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                 thread_welford_dst_desc_m.GetElementSpaceSize());
             var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                 thread_welford_dst_desc_m.GetElementSpaceSize());
-            static_for<0, mreduce_per_thread, 1>{}([&](auto j) {
+            static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                 mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
                 var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
             });
@@ -905,11 +912,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             block_sync_lds();
             // Get shuffle data from LDS to VGPR
-            c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
+            post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
                                                  c_shuffle_block_buf,
-                                                 cde_reduce_thread_desc_mperblock_nperblock,
+                                                 post_shuffle_thread_desc_m_n,
                                                  make_tuple(I0, I0),
                                                  e_thread_buf);
             // Global read D0, D1, ...
             static_for<0, NumDTensor, 1>{}([&](auto Id) {
@@ -917,7 +924,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                 d_thread_copy_global_to_vgpr.Run(
                     ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
                     ds_grid_buf[Id],
-                    cde_reduce_thread_desc_I1_mperblock_I1_nperblock,
+                    post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
                     make_tuple(I0, I0, I0, I0),
                     ds_thread_buf(Id));
@@ -931,19 +938,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             });
             // cde_element_op(e, c, d0, d1, ...);
-            static_for<0, cde_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
-                [&](auto i) {
-                    const auto c_ds_src_data_refs = concat_tuple_of_reference(
-                        tie(e_thread_buf[i]),
-                        generate_tie(
-                            [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
-                            Number<NumDTensor>{}));
-                    auto e_dst_data_refs = tie(e_thread_buf(i));
-                    unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
-                });
+            static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
+                const auto c_ds_src_data_refs = concat_tuple_of_reference(
+                    tie(e_thread_buf[i]),
+                    generate_tie(
+                        [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
+                        Number<NumDTensor>{}));
+                auto e_dst_data_refs = tie(e_thread_buf(i));
+                unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
+            });
             // Global write E
-            e_thread_copy_vgpr_to_global.Run(cde_reduce_thread_desc_I1_mperblock_I1_nperblock,
+            e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
                                              make_tuple(I0, I0, I0, I0),
                                              e_thread_buf,
                                              e_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -980,35 +986,35 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             auto& var_thread_buf = var_thread_bufs(i);
             int count = threadwise_welfords(i).cur_count_;
-            static_for<0, mreduce_per_thread, 1>{}([&](auto j) {
+            static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                 block_sync_lds();
                 BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count);
             });
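The blockwise reduction above combines per-thread partial results; conceptually it is the standard pairwise Welford merge (textbook form, shown for orientation rather than copied from CK):

// delta    = mean_b - mean_a;
// count_ab = count_a + count_b;
// mean_ab  = mean_a + delta * count_b / count_ab;
// m2_ab    = m2_a + m2_b + delta * delta * count_a * count_b / count_ab;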
             constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
-                make_tuple(I1, Number<mreduce_per_thread>{}, I1));
+                make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
             constexpr int shuffleMPerBlock =
                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
-            static_assert(mreduce_per_thread % FGTransferScalarPerVector == 0);
+            static_assert(PostShuffleThreadSliceSize_M % MeanVarTransferScalarPerVector == 0);
             auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                 AccDataType,
                 MeanDataType,
                 decltype(thread_welford_desc_I_m_I),
                 decltype(mean_grid_desc_mblock_mperblock_nblock),
                 tensor_operation::element_wise::PassThrough,
-                Sequence<1, mreduce_per_thread, 1>,
+                Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                 Sequence<0, 1, 2>,
                 1,
-                FGTransferScalarPerVector,
+                MeanVarTransferScalarPerVector,
                 InMemoryDataOperationEnum::Set,
                 1,
                 false>{mean_grid_desc_mblock_mperblock_nblock,
                        make_multi_index(block_work_idx[I0], // mblock
                                         shuffleMPerBlock * i +
-                                            c_reduce_thread_data_idx_begin[I0], // mperblock
+                                            post_shuffle_thread_data_idx_begin[I0], // mperblock
                                         block_work_idx[I1]), // nblock
                        tensor_operation::element_wise::PassThrough{}};
             auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
@@ -1017,17 +1023,17 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                 decltype(thread_welford_desc_I_m_I),
                 decltype(var_grid_desc_mblock_mperblock_nblock),
                 tensor_operation::element_wise::PassThrough,
-                Sequence<1, mreduce_per_thread, 1>,
+                Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                 Sequence<0, 1, 2>,
                 1,
-                FGTransferScalarPerVector,
+                MeanVarTransferScalarPerVector,
                 InMemoryDataOperationEnum::Set,
                 1,
                 false>{var_grid_desc_mblock_mperblock_nblock,
                        make_multi_index(block_work_idx[I0], // mblock
                                         shuffleMPerBlock * i +
-                                            c_reduce_thread_data_idx_begin[I0], // mperblock
+                                            post_shuffle_thread_data_idx_begin[I0], // mperblock
                                         block_work_idx[I1]), // nblock
                        tensor_operation::element_wise::PassThrough{}};
             mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
......