Commit 694057a7 authored by rocking

Rewrite the device interface and rename some variables

parent 982c85a3
@@ -38,7 +38,8 @@ using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using GammaDataType = F16;
using BetaDataType = F16;
using HDataType = F16;
// Layout
@@ -59,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| ELayout| HLayout| AData| BData| AccData| CShuffle| DsData| EData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| ReduceThreadTransfer| |
//######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock|
< ALayout, BLayout, DsLayout, ELayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4>;
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
// clang-format on
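// Reading aid (illustrative, not part of this commit): the comment table above stops
// at the PostShuffle columns, so the trailing template arguments map onto the new
// layernorm parameters of DeviceGemmMultipleDLayernorm_Xdl_CShuffle (see its template
// parameter list further below) roughly as follows:
//   S<64, 4> -> PostShuffleThreadClusterSize_M_N
//   4        -> PostShuffleScalarPerVector
//   S<8, 32> -> LayernormThreadClusterSize_M_N
//   S<1, 8>  -> LayernormThreadSliceSize_M_N
//   1        -> LayernormESrcHDstVectorDim
//   8, 8     -> LayernormESrcVectorSize, LayernormHDstVectorSize
//   8, 8     -> LayernormGammaSrcVectorSize, LayernormBetaSrcVectorSize
//   1        -> LayernormMeanVarSrcDstVectorSize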
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
@@ -96,33 +97,39 @@ int main()
ck::index_t StrideB = 1024;
ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = 1024;
ck::index_t StrideE = 1024;
ck::index_t StrideH = 1024;
// TODO - gamma and beta
float epsilon = 1e-5;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
d0_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
@@ -135,7 +142,8 @@ int main()
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
@@ -143,8 +151,8 @@ int main()
StrideA,
StrideB,
{StrideD0, StrideD1},
StrideE,
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
......
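// ---------------------------------------------------------------------------------
// Illustrative host-side reference (not part of this commit) of what the reworked
// interface computes: H = LayerNorm_N(cde(A*B, D0, D1), gamma, beta, epsilon).
// The former E output is now an internal tensor and the op returns the layernormed
// H. `cde` stands in for CDEElementOp; all names below are for illustration only.
#include <cmath>
#include <cstddef>
#include <vector>

template <typename F>
void reference_gemm_add_layernorm(const std::vector<float>& a,     // M x K, row-major
                                  const std::vector<float>& b,     // K x N, row-major
                                  const std::vector<float>& d0,    // N, broadcast over M
                                  const std::vector<float>& d1,    // M x N
                                  const std::vector<float>& gamma, // N
                                  const std::vector<float>& beta,  // N
                                  std::vector<float>& h,           // M x N
                                  std::size_t M, std::size_t N, std::size_t K,
                                  float epsilon, F cde)
{
    std::vector<float> e(M * N); // intermediate formerly exposed as E
    for(std::size_t m = 0; m < M; ++m)
    {
        // GEMM plus elementwise C/D fusion
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            e[m * N + n] = cde(acc, d0[n], d1[m * N + n]);
        }
        // LayerNorm along N
        float mean = 0.f, var = 0.f;
        for(std::size_t n = 0; n < N; ++n)
            mean += e[m * N + n];
        mean /= N;
        for(std::size_t n = 0; n < N; ++n)
        {
            const float diff = e[m * N + n] - mean;
            var += diff * diff;
        }
        var /= N;
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(std::size_t n = 0; n < N; ++n)
            h[m * N + n] = gamma[n] * (e[m * N + n] - mean) * inv_std + beta[n];
    }
}
// ---------------------------------------------------------------------------------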
@@ -13,7 +13,8 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
// #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
@@ -101,22 +102,23 @@ __global__ void
#endif
}
template <typename GridwiseWelfordLayernorm,
typename XDataType,
typename YDataType,
typename MeanDataType,
typename VarDataType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(const XDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
YDataType* __restrict__ p_y_grid)
{
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid);
}
// template <typename GridwiseWelfordLayernorm,
// typename EDataType,
// typename HDataType,
// typename MeanDataType,
// typename VarDataType>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// #endif
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
// const MeanDataType* __restrict__ p_mean_grid,
// const VarDataType* __restrict__ p_var_grid,
// HDataType* __restrict__ p_y_grid,
// index_t blkgroup_size)
// {
// GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
// }
} // namespace ck
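// ---------------------------------------------------------------------------------
// Illustrative sketch (not part of this commit) of the merge step the disabled
// second-half kernel is expected to perform. The first-half GEMM kernel writes one
// partial (mean, m2 = variance * count) per row and per N-block; the second pass
// combines blkgroup_size partials per row before applying layernorm. Names are
// illustrative only.
struct WelfordPartial
{
    float mean;
    float m2; // sum of squared deviations from the mean (variance * count)
    int count;
};

inline WelfordPartial welford_combine(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out;
    out.count         = a.count + b.count;
    const float delta = b.mean - a.mean;
    out.mean          = a.mean + delta * b.count / out.count;
    out.m2            = a.m2 + b.m2 + delta * delta * a.count * b.count / out.count;
    return out;
}
// After all partials of a row are combined: variance = m2 / count, and
//   h = gamma * (e - mean) / sqrt(variance + epsilon) + beta.
// ---------------------------------------------------------------------------------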
@@ -139,14 +141,14 @@ namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename HLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -180,12 +182,22 @@ template <typename ALayout,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock,
index_t ReduceThreadTransferScalarPerVector_NPerBlock,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
typename LayernormThreadClusterSize_M_N,
typename LayernormThreadSliceSize_M_N,
index_t LayernormESrcHDstVectorDim,
index_t LayernormESrcVectorSize,
index_t LayernormHDstVectorSize,
index_t LayernormGammaSrcVectorSize,
index_t LayernormBetaSrcVectorSize,
index_t LayernormMeanVarSrcDstVectorSize,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout;
using EDataType = HDataType;
using MeanDataType = CShuffleDataType;
using VarDataType = CShuffleDataType;
@@ -264,13 +276,64 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{});
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using VarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
static auto MakeDescriptor_M(index_t MRaw)
{
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return grid_desc_mraw;
}
};
static auto MakeDescriptor_N(index_t NRaw)
{
const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N
return transform_tensor_descriptor(grid_desc_nraw,
make_tuple(make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad N
return grid_desc_nraw;
}
};
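// Worked example (illustrative, not part of this commit): with NRaw = 1000 and
// NPerBlock = 128, N = integer_divide_ceil(1000, 128) * 128 = 1024 and NPad = 24,
// so under an N-padding GemmSpecialization the gamma/beta descriptor views a
// length-1024 tensor whose last 24 elements are right padding; otherwise the raw
// length-1000 descriptor is returned unchanged. MakeDescriptor_M pads MRaw to a
// multiple of MPerBlock in the same way.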
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanVarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
@@ -289,8 +352,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
MeanGridDesc_M_N,
VarGridDesc_M_N,
MeanVarGridDesc_M_N,
MeanVarGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -320,15 +383,34 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock,
ReduceThreadTransferScalarPerVector_NPerBlock,
PostShuffleThreadClusterSize_M_N,
PostShuffleScalarPerVector,
1,
LoopSched>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, HDataType, MeanDataType, VarDataType>;
// using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType,
// MeanDataType,
// VarDataType,
// AccDataType,
// HGridDesc_M_N,
// MeanGridDesc_M_N,
// GammaBetaGridDesc_N,
// MeanVarGridDesc_M,
// BlockSize,
// LayernormMThreadClusterSize,
// LayernormNThreadClusterSize,
// LayernormMThreadSliceSize,
// LayernormNThreadSliceSize,
// LayernormESrcHDstVectorDim,
// LayernormESrcVectorSize,
// LayernormHDstVectorSize,
// LayernormGammaSrcVectorSize,
// LayernormBetaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>;
// Argument
struct Argument : public BaseArgument
@@ -336,7 +418,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
const void* p_gamma_grid,
const void* p_beta_grid,
void* p_h_grid,
index_t MRaw,
index_t NRaw,
@@ -344,41 +427,48 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t StrideH,
AccDataType epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
BElementwiseOperation h_element_op)
HElementwiseOperation h_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_e_grid_{nullptr},
p_mean_grid_{nullptr},
p_var_grid_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_grid_desc_m_n_{},
var_grid_desc_m_n_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
h_element_op_{h_element_op},
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
mean_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
var_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
int welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * welford_size));
hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * welford_size));
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
int gemm_welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * gemm_welford_size));
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -440,6 +530,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
EDataType* p_e_grid_;
MeanDataType* p_mean_grid_; // mean
VarDataType* p_var_grid_; // variance * count
const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_;
// tensor descriptors for problem definiton
@@ -447,8 +539,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
MeanGridDesc_M_N mean_grid_desc_m_n_;
VarGridDesc_M_N var_grid_desc_m_n_;
MeanVarGridDesc_M_N mean_grid_desc_m_n_;
MeanVarGridDesc_M_N var_grid_desc_m_n_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
HGridDesc_M_N h_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
@@ -473,6 +567,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
HElementwiseOperation h_element_op_;
int blkGroupSize_;
AccDataType epsilon_;
};
// Invoker
@@ -524,12 +619,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
HDataType,
MeanDataType,
VarDataType>;
// const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType,
// HDataType,
// MeanDataType,
// VarDataType>;
avg_time +=
launch_and_time_kernel(stream_config,
@@ -554,15 +649,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.var_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_mean_grid_,
arg.p_var_grid_,
arg.p_h_grid_);
// avg_time += launch_and_time_kernel(stream_config,
// kernel_welford_layernorm,
// dim3(grid_size),
// dim3(BlockSize),
// 0,
// arg.p_e_grid_,
// arg.p_mean_grid_,
// arg.p_var_grid_,
// arg.p_h_grid_,
// arg.blkGroupSize_);
return avg_time;
};
@@ -604,7 +700,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
@@ -612,8 +709,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t StrideH,
AccDataType epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
@@ -622,7 +719,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return Argument{p_a,
p_b,
p_ds,
p_e,
p_gamma,
p_beta,
p_h,
MRaw,
NRaw,
@@ -630,8 +728,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
StrideA,
StrideB,
StrideDs,
StrideE,
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
@@ -644,7 +742,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
@@ -652,8 +751,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t StrideH,
AccDataType epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
@@ -662,7 +761,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
p_gamma,
p_beta,
p_h,
MRaw,
NRaw,
@@ -670,8 +770,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
StrideA,
StrideB,
StrideDs,
StrideE,
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
......
@@ -78,9 +78,9 @@ template <typename ABDataType,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
index_t FGTransferScalarPerVector,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
index_t MeanVarTransferScalarPerVector,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
@@ -604,10 +604,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// static_cast<CShuffleDataType*>(p_shared),
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
@@ -711,8 +707,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
false>{};
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
// LDS c_shuffle_block_desc_mperblock_nperblock
constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
@@ -724,75 +720,79 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
static_assert(CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) *
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) ==
static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
PostShuffleThreadClusterSize_M_N::At(I1) ==
BlockSize,
"wrong!");
static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) ==
PostShuffleThreadClusterSize_M_N::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) ==
PostShuffleThreadClusterSize_M_N::At(I1) ==
0,
"wrong!");
constexpr index_t mreduce_per_thread =
constexpr index_t PostShuffleThreadSliceSize_M =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0);
PostShuffleThreadClusterSize_M_N::At(I0);
constexpr index_t nreduce_per_thread =
constexpr index_t PostShuffleThreadSliceSize_N =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1);
PostShuffleThreadClusterSize_M_N::At(I1);
constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
Sequence<mreduce_per_thread, nreduce_per_thread>{};
constexpr auto PostShuffleThreadSliceSize_M_N =
Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};
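// Worked example (illustrative, not part of this commit), using the instance from the
// example above: BlockSize = 256, MPerXdl = NPerXdl = 32,
//   MWave = MPerBlock / (MXdlPerWave * MPerXdl) = 256 / (4 * 32) = 2,
//   NWave = NPerBlock / (NXdlPerWave * NPerXdl) = 128 / (2 * 32) = 2,
// and CShuffleMXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle = 1, so each
// shuffle iteration exposes a 64 x 64 tile in LDS. With
// PostShuffleThreadClusterSize_M_N = S<64, 4> (64 * 4 = 256 threads) the per-thread
// slice is PostShuffleThreadSliceSize_M = 64 / 64 = 1 and
// PostShuffleThreadSliceSize_N = 64 / 4 = 16.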
// VGPR cde_reduce_thread_desc_mperblock_nperblock
constexpr auto cde_reduce_thread_desc_mperblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR post_shuffle_thread_desc_m_n
constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
cde_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
post_shuffle_thread_desc_m_n.GetElementSpaceSize());
// To apply D0, D1, ... and Welford.
// threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
constexpr auto post_shuffle_thread_cluster_desc =
make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<1, 0>{});
const auto c_reduce_thread_cluster_idx =
c_reduce_thread_cluster_desc.CalculateBottomIndex(
const auto post_shuffle_thread_cluster_idx =
post_shuffle_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto c_reduce_thread_data_idx_begin =
c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
const auto post_shuffle_thread_data_idx_begin =
post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
// To apply D0, D1, ... and Welford.
// Copy c shuffle from LDS back to VGPR
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
CShuffleDataType,
AccDataType,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(cde_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CDEReduceThreadTransferScalarPerVector_NPerBlock,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto post_shuffle_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
AccDataType,
decltype(c_shuffle_block_desc_mperblock_nperblock),
decltype(post_shuffle_thread_desc_m_n),
decltype(PostShuffleThreadSliceSize_M_N),
Sequence<0, 1>,
1,
PostShuffleScalarPerVector,
1,
true>{c_shuffle_block_desc_mperblock_nperblock,
post_shuffle_thread_data_idx_begin};
// D0, D1, ..., Dn
constexpr auto cde_reduce_thread_desc_I1_mperblock_I1_nperblock =
constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
make_tuple(I1,
Number<PostShuffleThreadSliceSize_M>{},
I1,
Number<PostShuffleThreadSliceSize_N>{}));
// FIXME: Decrease usage of VGPR
// Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
auto ds_thread_buf = generate_tuple(
[&](auto) {
return make_static_buffer<AddressSpaceEnum::Vgpr, CShuffleDataType>(
cde_reduce_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
},
Number<NumDTensor>{});
@@ -804,58 +804,65 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
DDataType,
AccDataType,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>,
Sequence<0, 1, 2, 3>,
3,
CDEReduceThreadTransferScalarPerVector_NPerBlock,
PostShuffleScalarPerVector,
1,
true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
true>(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
EDataType,
decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock),
decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
CDEReduceThreadTransferScalarPerVector_NPerBlock,
Sequence<I1,
PostShuffleThreadSliceSize_M,
I1,
PostShuffleThreadSliceSize_N>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
PostShuffleScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
true>{
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};
// Welford
constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
make_tuple(Number<PostShuffleThreadSliceSize_M>{},
Number<PostShuffleThreadSliceSize_N>{}));
constexpr auto thread_welford_dst_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed(
make_tuple(Number<PostShuffleThreadSliceSize_M>{}));
using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
decltype(thread_welford_src_desc_m_k),
decltype(thread_welford_dst_desc_m)>;
using BlockwiseWelford =
BlockwiseWelford<AccDataType,
BlockSize,
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
Sequence<0, 1>,
false>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
PostShuffleThreadClusterSize_M_N,
Sequence<0, 1>,
false>;
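// Illustrative sketch (not part of this commit) of the running update ThreadwiseWelford
// applies per element of the post-shuffle slice:
//   count += 1
//   delta  = x - mean
//   mean  += delta / count
//   m2    += delta * (x - mean)   // m2 corresponds to variance * count
// BlockwiseWelford then merges the per-thread partials across the
// PostShuffleThreadClusterSize_M_N cluster, so each block contributes one
// (mean, variance * count) pair per row to the workspaces.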
constexpr int num_shuffleM =
MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
@@ -870,14 +877,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding
threadwise_welfords(i).max_count_ = nreduce_per_thread;
threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N;
mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, mreduce_per_thread, 1>{}([&](auto j) {
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
});
@@ -905,11 +912,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
block_sync_lds();
// Get shuffle data from LDS to VGPR
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
cde_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
e_thread_buf);
post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
post_shuffle_thread_desc_m_n,
make_tuple(I0, I0),
e_thread_buf);
// Global read D0, D1, ...
static_for<0, NumDTensor, 1>{}([&](auto Id) {
@@ -917,7 +924,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
d_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
ds_grid_buf[Id],
cde_reduce_thread_desc_I1_mperblock_I1_nperblock,
post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
ds_thread_buf(Id));
@@ -931,19 +938,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
});
// cde_element_op(e, c, d0, d1, ...);
static_for<0, cde_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
const auto c_ds_src_data_refs = concat_tuple_of_reference(
tie(e_thread_buf[i]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
Number<NumDTensor>{}));
auto e_dst_data_refs = tie(e_thread_buf(i));
unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
});
static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
const auto c_ds_src_data_refs = concat_tuple_of_reference(
tie(e_thread_buf[i]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
Number<NumDTensor>{}));
auto e_dst_data_refs = tie(e_thread_buf(i));
unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
});
// Global write E
e_thread_copy_vgpr_to_global.Run(cde_reduce_thread_desc_I1_mperblock_I1_nperblock,
e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
make_tuple(I0, I0, I0, I0),
e_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -980,35 +986,35 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto& var_thread_buf = var_thread_bufs(i);
int count = threadwise_welfords(i).cur_count_;
static_for<0, mreduce_per_thread, 1>{}([&](auto j) {
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count);
});
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1));
make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(mreduce_per_thread % FGTransferScalarPerVector == 0);
static_assert(PostShuffleThreadSliceSize_M % MeanVarTransferScalarPerVector == 0);
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
MeanDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread, 1>,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
FGTransferScalarPerVector,
MeanVarTransferScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
false>{mean_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
c_reduce_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
@@ -1017,17 +1023,17 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype(thread_welford_desc_I_m_I),
decltype(var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread, 1>,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
FGTransferScalarPerVector,
MeanVarTransferScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
false>{var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
c_reduce_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
......