Commit 644df335 authored by rocking

Merge branch 'develop' into gemm_layernorm_instance

parents d99640ab 7494c1c6
......@@ -13,10 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank,
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation>
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
struct DeviceReduce : public BaseOperator
{
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
......@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
......@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank,
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation>
using DeviceReducePtr = std::unique_ptr<
DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>>;
} // namespace device
} // namespace tensor_operation
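For illustration only: with the base class now fully typed, a reduction instance pointer carries its data types, reduce operation, and NaN/index behaviour in its alias. A hypothetical instantiation (types and configuration chosen for this sketch, not taken from the commit; assumes the device_reduce and reduction-operator headers are included):
// Hypothetical alias using the expanded DeviceReducePtr signature.
using HalfAddReducePtr = ck::tensor_operation::device::DeviceReducePtr<
    ck::half_t,                                       // InDataType
    float,                                            // AccDataType
    ck::half_t,                                       // OutDataType
    4,                                                // Rank
    1,                                                // NumReduceDim
    ck::reduce::Add,                                  // ReduceOperation
    ck::tensor_operation::element_wise::PassThrough,  // InElementwiseOperation
    ck::tensor_operation::element_wise::PassThrough,  // AccElementwiseOperation
    false,                                            // PropagateNan
    false>;                                           // OutputIndex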
......
......@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] alpha Alpha scaling value, passed by value as a double
// @param[in] beta Beta scaling value, passed by value as a double
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor
......@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
double alpha,
double beta,
const void* in_dev,
void* out_dev,
InElementwiseOp in_elementwise_op,
......
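A minimal caller-side sketch of the revised softmax interface, where alpha and beta are now passed by value as double instead of through typeless host pointers. The instance pointer, device buffers, and element-wise operators below are hypothetical placeholders, and the trailing operator arguments follow the Argument constructor shown later in this diff:
// Hypothetical call against the new MakeArgumentPointer signature.
std::vector<ck::index_t> in_lengths{8, 128, 2048};
std::vector<ck::index_t> in_strides{128 * 2048, 2048, 1};
std::vector<int> reduce_dims{2};
auto arg_ptr = softmax_instance_ptr->MakeArgumentPointer(in_lengths,
                                                         in_strides,
                                                         reduce_dims,
                                                         1.0, // alpha, plain double
                                                         0.0, // beta, plain double
                                                         p_in_dev,
                                                         p_out_dev,
                                                         in_elementwise_op,
                                                         acc_elementwise_op);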
......@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
......@@ -26,10 +26,10 @@ template <typename InDataTypeTuple,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n;
......
......@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
......@@ -25,8 +25,8 @@ template <typename InDataTypeTuple,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise
: public DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
struct DeviceElementwiseImpl
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
......
......@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> reduceDims,
XElementwiseOperation x_elementwise_op,
YElementwiseOperation y_elementwise_op,
AccDataType epsilon,
double epsilon,
const std::array<const void*, NumInput> in_dev_buffers,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: epsilon_(epsilon),
p_gamma_(p_gamma),
: p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
x_elementwise_op_(x_elementwise_op),
y_elementwise_op_(y_elementwise_op)
{
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
for(int i = 0; i < NumInput; i++)
......@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
double epsilon,
const std::array<const void*, NumInput> in_dev_buffers,
const void* p_gamma,
const void* p_beta,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename DsGridDesc_M0_M10_M11_N0_N10_N11,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_multiple_d(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx1030__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
__shared__ ABDataType p_shared[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
ds_grid_desc_m0_m10_m11_n0_n10_n11,
e_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1;
ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = block_2_ctile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_Dl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
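As a quick numeric check of the MNPadding branch above (numbers are illustrative, not from the commit): PadM rounds M up to the next multiple of MPerBlock, and the outer modulo keeps the pad at zero when M already divides evenly.
// Illustrative check of the right-pad formula with a hypothetical MPerBlock of 128.
static_assert((128 - 1000 % 128) % 128 == 24, "M = 1000 is padded by 24 up to 1024");
static_assert((128 - 1024 % 128) % 128 == 0, "an exact multiple needs no padding");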
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
ADataType,
AccDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
EGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using DsGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
using EGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
e_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
a_grid_desc_k0_m_k1_ =
DeviceGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
// D desc
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
});
e_grid_desc_m_n_ =
DeviceGemmMultipleD_Dl::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
ds_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_);
e_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmMultipleD_Dl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(
arg.e_grid_desc_m_n_.GetLength(I0), arg.e_grid_desc_m_n_.GetLength(I1));
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
const auto kernel =
kernel_gemm_dl_multiple_d<GridwiseGemm,
ADataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
DeviceOp::EGridDesc_M0_M10_M11_N0_N10_N11,
DefaultBlock2CTileMap,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.e_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
}
else
{
return false;
}
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmMultipleD_Dl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
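For reference, the new DL multiple-D device op is driven through the usual argument/invoker pattern exposed above; a rough host-side sketch (the instance alias, pointers, sizes, and element-wise operator objects are placeholders, not taken from the commit):
// Hypothetical usage of a concrete DeviceGemmMultipleD_Dl instantiation named GemmInstance.
auto gemm     = GemmInstance{};
auto argument = gemm.MakeArgument(p_a, p_b, p_ds, p_e,
                                  M, N, K,
                                  StrideA, StrideB, StrideDs, StrideE,
                                  a_element_op, b_element_op, cde_element_op);
if(!GemmInstance::IsSupportedArgument(argument))
{
    throw std::runtime_error("problem size/layout not supported by the DL kernel");
}
auto invoker   = gemm.MakeInvoker();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); // time the launch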
......@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
......@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
......
......@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......
......@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
static_for<0, NumReduction, 1>{}([&](auto I) {
using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
flag =
flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value;
flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value;
});
return flag;
......@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas,
const std::array<double, NumReduction>& alphas,
const std::array<double, NumReduction>& betas,
const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++)
{
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = static_cast<AccDataType>(betas[i]);
};
in_dev_ = static_cast<const InDataType*>(in_dev);
......@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const std::array<double, NumReduction> alphas,
const std::array<double, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......
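With this change the per-reduction scale factors are passed by value, so a caller builds plain arrays of double rather than arrays of host pointers that are dereferenced as AccDataType. A hypothetical sketch for two simultaneous reductions (values are illustrative):
// Hypothetical scale factors for NumReduction = 2; forwarded to MakeArgumentPointer(...)
// and cast internally to AccDataType, as the updated constructor above shows.
std::array<double, 2> alphas{1.0, 0.5};
std::array<double, 2> betas{0.0, 0.0};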
......@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas,
const std::array<double, NumReduction>& alphas,
const std::array<double, NumReduction>& betas,
const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++)
{
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = static_cast<AccDataType>(betas[i]);
};
in_dev_ = static_cast<const InDataType*>(in_dev);
......@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const std::array<double, NumReduction> alphas,
const std::array<double, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......
......@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon,
double epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: epsilon_(epsilon),
p_x_(p_x),
: p_x_(p_x),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
acc_elementwise_op_(acc_elementwise_op)
{
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
......@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
double epsilon,
const void* p_x,
const void* p_gamma,
const void* p_beta,
......
......@@ -40,8 +40,16 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlock
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value,
static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value,
"The OutDataType must support the specified OutMemoryDataOperation!");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
......@@ -209,8 +217,8 @@ struct DeviceReduceMultiBlock
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const InDataType* in_dev,
const IndexDataType* in_index_dev,
OutDataType* out_dev,
......@@ -494,8 +502,8 @@ struct DeviceReduceMultiBlock
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
......
......@@ -35,8 +35,17 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceThreadWise
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
......@@ -156,8 +165,8 @@ struct DeviceReduceThreadWise
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_index_dev,
......@@ -332,8 +341,8 @@ struct DeviceReduceThreadWise
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
......
......@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
double alpha,
double beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
: alpha_{alpha},
beta_{beta},
in_dev_{in_dev},
: in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<AccDataType>(beta);
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
......@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
static auto MakeArgument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const AccDataType alpha,
const AccDataType beta,
double alpha,
double beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
......@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
double alpha,
double beta,
const void* in_dev,
void* out_dev,
InElementwiseOp in_elementwise_op,
......@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
in_elementwise_op,
......
......@@ -12,7 +12,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"
namespace ck {
namespace tensor_operation {
......@@ -24,16 +24,17 @@ template <typename EmbType,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename EmbElementwiseOperation,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock,
ck::index_t RowPerBlock,
ck::index_t DimThreadSize,
ck::index_t RowVectorSize>
struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
{
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
{
return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
......@@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
struct Argument : public BaseArgument
{
Argument(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings>& p_embs,
const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const ck::index_t NumRows,
const ck::index_t EmbeddingDim,
const ck::index_t IndexLength,
const AccDataType epsilon)
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
: p_out_(p_out),
p_emb_a_(p_emb_a),
p_emb_b_(p_emb_b),
p_emb_c_(p_emb_c),
p_index_a_(p_index_a),
p_index_b_(p_index_b),
p_index_c_(p_index_c),
p_embs_(p_embs),
p_indexs_(p_indexs),
p_gamma_(p_gamma),
p_beta_(p_beta),
NumRows_(NumRows),
EmbeddingDim_(EmbeddingDim),
IndexLength_(IndexLength),
epsilon_(epsilon)
epsilon_(epsilon),
emb_elementwise_op_(emb_elementwise_op)
{
grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
}
OutType* p_out_;
const EmbType* p_emb_a_;
const EmbType* p_emb_b_;
const EmbType* p_emb_c_;
const IndexType* p_index_a_;
const IndexType* p_index_b_;
const IndexType* p_index_c_;
ck::Array<EmbType*, NumEmbeddings> p_embs_;
ck::Array<IndexType*, NumEmbeddings> p_indexs_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
ck::index_t NumRows_;
ck::index_t EmbeddingDim_;
ck::index_t IndexLength_;
AccDataType epsilon_;
EmbElementwiseOperation emb_elementwise_op_;
size_t grid_size_;
};
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
const void* p_emb_a,
const void* p_emb_b,
const void* p_emb_c,
const void* p_index_a,
const void* p_index_b,
const void* p_index_c,
const void* p_gamma,
const void* p_beta,
ck::index_t NumRows,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon)
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_out,
const ck::Array<EmbType*, NumEmbeddings>& p_embs,
const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
const void* p_gamma,
const void* p_beta,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
reinterpret_cast<const EmbType*>(p_emb_a),
reinterpret_cast<const EmbType*>(p_emb_b),
reinterpret_cast<const EmbType*>(p_emb_c),
reinterpret_cast<const IndexType*>(p_index_a),
reinterpret_cast<const IndexType*>(p_index_b),
reinterpret_cast<const IndexType*>(p_index_c),
p_embs,
p_indexs,
reinterpret_cast<const GammaDataType*>(p_gamma),
reinterpret_cast<const BetaDataType*>(p_beta),
NumRows,
EmbeddingDim,
IndexLength,
epsilon);
epsilon,
emb_elementwise_op);
}
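With the variadic-embedding interface, callers gather the embedding tables and index buffers into ck::Array instead of passing three fixed pointers. A hypothetical sketch for NumEmbeddings = 3 (the pointer names and device-op instance are placeholders, and ck::Array is assumed to be brace-constructible as in ck/utility/array.hpp):
// Hypothetical argument preparation for a three-embedding instance.
ck::Array<EmbType*, 3>   p_embs{p_emb_a, p_emb_b, p_emb_c};
ck::Array<IndexType*, 3> p_indexs{p_index_a, p_index_b, p_index_c};
auto arg_ptr = device_op.MakeArgumentPointer(p_out,
                                             p_embs,
                                             p_indexs,
                                             p_gamma,
                                             p_beta,
                                             EmbeddingDim,
                                             IndexLength,
                                             epsilon,
                                             emb_elementwise_op);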
using GridwiseSparseEmbedding =
GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(MakeOutputDescriptor(1, 1)),
EmbElementwiseOperation,
BlockSize,
DimClusterSize,
RowClusterSize,
DimPerBlock,
RowPerBlock,
DimThreadSize,
RowVectorSize>;
RowVectorSize,
NumEmbeddings>;
struct Invoker : public BaseInvoker
{
......@@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
{
auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
const auto kernel_main =
kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
kernel_sparse_embeddings_forward_layernorm<GridwiseSparseEmbedding,
EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(out_desc)>;
decltype(out_desc),
EmbElementwiseOperation,
NumEmbeddings>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
......@@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
dim3(BlockSize),
0,
arg.p_out_,
arg.p_emb_a_,
arg.p_emb_b_,
arg.p_emb_c_,
arg.p_index_a_,
arg.p_index_b_,
arg.p_index_c_,
arg.p_embs_,
arg.p_indexs_,
arg.p_gamma_,
arg.p_beta_,
out_desc,
arg.epsilon_);
arg.epsilon_,
arg.emb_elementwise_op_);
return (avg_time);
}
......@@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
static bool IsSupportedArgument(const Argument* p_arg)
{
return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
return (RowPerBlock == p_arg->EmbeddingDim_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
auto str = std::stringstream();
// clang-format off
str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" <<
str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" <<
DimClusterSize << "x" << RowClusterSize << "_" <<
DimPerBlock << "x" << RowPerBlock << "_" <<
DimThreadSize << "x" << RowVectorSize;
......
......@@ -172,6 +172,42 @@ struct AddAdd
}
};
// C = A * B
// E = (C + D0) x D1
struct AddMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (type_convert<half_t>(c) + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = (c + d0) * d1;
e = y;
}
};
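To make the E = (C + D0) * D1 formula concrete, a minimal host-side use of the float/half_t specialization could look as follows; the namespace is assumed from the surrounding element-wise operation header and the values are chosen only for illustration:
// Hypothetical numeric check of AddMultiply.
ck::tensor_operation::element_wise::AddMultiply add_multiply{};
float e = 0.f;
const float      c  = 2.f;                                 // C = A * B accumulator
const ck::half_t d0 = ck::type_convert<ck::half_t>(1.f);
const ck::half_t d1 = ck::type_convert<ck::half_t>(3.f);
add_multiply(e, c, d0, d1);                                // e = (2 + 1) * 3 = 9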
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
......@@ -278,6 +314,40 @@ struct Normalize
double epsilon_;
};
// used by BatchNorm inference
// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
// The data type of mean and variance is used as the accumulation data type (AccDataType)
struct NormalizeInInfer
{
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& variance,
const T3& gamma,
const T4& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
type_convert<T2>(gamma) +
type_convert<T2>(beta);
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
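A quick numeric check of the batch-norm inference formula above, with every argument a float and a negligible epsilon (namespace assumed from the surrounding element-wise operation header, values purely illustrative):
// Hypothetical evaluation of NormalizeInInfer.
ck::tensor_operation::element_wise::NormalizeInInfer normalize{1e-12};
float y = 0.f;
normalize(y, /*x=*/3.f, /*mean=*/1.f, /*variance=*/4.f, /*gamma=*/2.f, /*beta=*/0.5f);
// y = 2 * (3 - 1) / sqrt(4 + 1e-12) + 0.5 ≈ 2.5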
template <typename Y, typename X>
struct UnaryTypeConvert;
......
......@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;
// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
{
// TODO: improve applicability
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep>
static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
// sync for Load threads()
block_sync_lds();
// global read i + 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// sync with math threads()
block_sync_lds();
// LDS write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
}
}
};
template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;
// 1-stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename ABDataType,
typename FloatGemmAcc,
typename EDataTypeShuffle,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename EElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t TileLoadThreadGroupSize,
index_t TileMathThreadGroupSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
struct TileLoadThreadGroup
{
__device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }
__device__ static constexpr bool IsBelong()
{
return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
}
__device__ static index_t GetThreadId()
{
return get_thread_local_1d_id() - TileMathThreadGroupSize;
}
};
struct TileMathThreadGroup
{
__device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }
__device__ static constexpr bool IsBelong()
{
return get_thread_local_1d_id() < TileMathThreadGroupSize;
}
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
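To make the split concrete: assuming (hypothetically) a 512-thread workgroup with TileMathThreadGroupSize = TileLoadThreadGroupSize = 256, ids [0, 256) form the math group and ids [256, 512) form the load group, with load-local ids obtained by subtracting the math-group size.
// Illustrative id arithmetic only; group sizes are hypothetical and assumed equal.
static_assert(!(300 < 256), "global thread id 300 is not in the math group");
static_assert(300 >= 256, "global thread id 300 belongs to the load group");
static_assert(300 - 256 == 44, "and maps to load-wave-local id 44");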
using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;
// load-wave and math+store-wave pipelines
// TODO: build pipeline blocks that schedule parallel tasks
using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup, NumGemmKPrefetchStage>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ABDataType),
c_block_size * sizeof(EDataTypeShuffle));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& /*block_2_etile_map*/)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K == b_grid_desc_n_k.GetLength(I1)))
{
return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmMath::IsSupported(num_k_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M01)),
make_unmerge_transform(make_tuple(N0, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
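// split K into (AK0, AK1); K is assumed to be a multiple of AK1 here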
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const EElementwiseOperation& e_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
// build the LoadWave and MathWave pipelines;
// LoadWave and MathWave are synchronized through LDS
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
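// AK0 * AK1 equals the full K extent, so this computes K / KPerBlock; readfirstlane
// keeps the loop count wave-uniform in a scalar register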
// divide block work by [M, N]
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
if(TileLoadThreadGroup::IsBelong())
{
// LoadWave
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<TileLoadThreadGroup,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<TileLoadThreadGroup,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
ABDataType,
ABDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
num_k_block_main_loop);
block_sync_lds();
block_sync_lds();
}
else if(TileMathThreadGroup::IsBelong())
{
// branch early for math wave
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
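// KPack: the K chunk consumed per blockwise-GEMM step, taken as the larger of the
// A/B LDS vector lcm and the mfma instruction's k_per_blk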
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize,
ABDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// TODO: re-architect the LDS + math stages
// Writing data to GMEM: only the math wave does the work in the cshuffle stage
GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads and kept in registers
// sanity check
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<EDataTypeShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 5, 6>{},
Sequence<>{},
Sequence<1, 3, 7>{}));
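// the freeze transforms drop the MBlock/NBlock dimensions (the shuffle buffer is per
// block, so those coordinates are fixed to 0), while the unmerges expose the per-wave
// and per-xdlops sub-dimensions of the shuffle tile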
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
EDataTypeShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
CShuffleBlockTransferThreadGroup, // ThreadGroup
EElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
EDataTypeShuffle, // typename SrcData,
EDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
e_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
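// both space filling curves advance one CShuffle slice per access, so their access
// counts must match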
// Different way of getting coalesced writes:
// We could get rid of the cshuffle stage altogether. Instead of reading A rows in a
// contiguous manner, read them interleaved; the mfma then produces a nicer c-mat
// layout, as below:
//
// TODO
// No LDS swizzle would be needed to align global writes to cache lines:
//   v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN elements
//                                   (N is the vertical, i.e. strided, dimension)
//   v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1 elements
//                                   (M is the coalescing dimension); by enumerating
//                                   the M index in amat/bmat, the cmat register(s)
//                                   can be aligned to contiguous M elements
// for example:
//   1st mfma instruction output space : 0 4 8 12 16 ....
//   2nd mfma instruction output space : 1 5 9 13 17 ....
//   3rd mfma instruction output space : 2 6 10 14 18 ....
//   4th mfma instruction output space : 3 7 11 15 19 ....
// The output space of these 4 registers can be packed into 2WORD and written to
// global memory directly (no LDS swizzling required).
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
}
};
} // namespace ck
......@@ -17,33 +17,24 @@ template <typename GridwiseSparseEmbedding,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename OutGridDesc>
typename OutGridDesc,
typename EmbElementwiseOperation,
ck::index_t NumEmbeddings>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon)
__global__ void kernel_sparse_embeddings_forward_layernorm(
OutType* p_out,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
GridwiseSparseEmbedding::Run(p_out,
p_emb_a,
p_emb_b,
p_emb_c,
p_index_a,
p_index_b,
p_index_c,
p_gamma,
p_beta,
out_grid_desc,
epsilon);
GridwiseSparseEmbedding::Run(
p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
}
template <typename EmbType,
......@@ -53,14 +44,16 @@ template <typename EmbType,
typename AccDataType,
typename OutType,
typename OutGridDesc,
typename EmbElementwiseOperation,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock, // Row x Dim, along Dim
ck::index_t RowPerBlock, // Row x Dim, along Row
ck::index_t DimThreadSize, // this is actually not vector, but number of registers
ck::index_t RowVectorSize>
struct GridwiseSparseEmbedding3ForwardLayernorm
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct GridwiseSparseEmbeddingsForwardLayernorm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -97,23 +90,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
__device__ static void Run(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc,
const AccDataType epsilon)
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
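// overall flow: load the lookup indices and gamma/beta, gather one row from each
// embedding table per index, combine them with emb_elementwise_op into an accumulator,
// run a threadwise + blockwise Welford pass for mean/variance, then normalize,
// scale by gamma, add beta and store the result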
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
// const auto index_length = out_grid_desc.GetLength(I0);
// const auto emb_dim = out_grid_desc.GetLength(I1);
constexpr auto thread_cluster_desc =
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
......@@ -141,13 +128,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto gamma_beta_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>,
NumEmbeddings>
in_thread_bufs;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings>
index_bufs;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
......@@ -160,42 +145,31 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
using src_vector_t = typename decltype(emb_vector_a)::type;
ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
auto emb_a = emb_vectors[0];
using src_vector_t = typename decltype(emb_a)::type;
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
IndexType index_a = index_buf_a[Number<current_dim>{}];
IndexType index_b = index_buf_b[Number<current_dim>{}];
IndexType index_c = index_buf_c[Number<current_dim>{}];
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(EmbType) * RowVectorSize;
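// gather one RowVectorSize-wide chunk from each embedding table, at the row selected
// by that table's index, via a raw buffer load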
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
int32x4_t emb_res_a =
make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
int32x4_t emb_res_b =
make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
int32x4_t emb_res_c =
make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
emb_vector_a.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
emb_vector_b.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
emb_vector_c.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
p_embs[i_embedding_] + index * RowPerBlock);
emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
});
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
in_thread_buf_a(Number<register_offset>{}) =
emb_vector_a.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_b(Number<register_offset>{}) =
emb_vector_b.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_c(Number<register_offset>{}) =
emb_vector_c.template AsType<EmbType>()[i_row_vec_];
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
ck::type_convert<AccDataType>(
emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
});
});
});
};
......@@ -205,14 +179,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
AccDataType va =
ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
AccDataType vb =
ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
AccDataType vc =
ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
auto in_data_refs = generate_tie(
[&](auto i_embedding_) -> const auto& {
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
},
Number<NumEmbeddings>{});
auto out_data_refs = generate_tie(
[&](auto output_index_) -> auto& {
return acc_thread_buf(Number<register_offset>{});
},
Number<1>{});
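// unpack2 expands both tuples, i.e. this calls
// emb_elementwise_op(acc, in_0, ..., in_{NumEmbeddings-1}) to combine the gathered
// embedding values into the accumulator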
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
});
});
};
......@@ -242,7 +219,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto mean_var_offset =
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
auto divisor =
1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
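// divisor = 1 / sqrt(var + epsilon) is computed once per (i_dim_sub_, i_dim_vec_) and
// reused for every element of the row vector below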
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
......@@ -250,9 +228,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
auto acc_val = acc_thread_buf[Number<register_offset>{}];
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
beta_thread_buf[Number<gamma_beta_offset>{}];
out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
......@@ -273,9 +250,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
// first load index
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
// prefer to use s_load
index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
index_bufs(i_embedding_)(i_idx_) =
p_indexes[i_embedding_][index_start + i_idx_.value];
});
});
// load gamma/beta
......@@ -329,7 +307,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
});
......