Commit 2d91fd12 authored by Anthony Chang

initial layernorm implementation

parent 57b7dca7
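For reference, the fused computation added by this commit (see the kernel comment D = Layernorm(A * B + broadcast(bias)) * broadcast(gamma) + broadcast(beta) in the gridwise header below) can be sketched on the host as follows. This is a minimal illustrative sketch, not part of the commit: the function and variable names are made up, row-major layouts and float data are assumed, and the layernorm reduction runs over the N dimension of the MxN GEMM output, matching the kernel's row-wise mean/variance.

#include <cmath>
#include <vector>

// Illustrative reference only; assumes row-major A (MxK), B (KxN), D (MxN).
void reference_gemm_layernorm(const float* A, const float* B,
                              const float* bias, const float* gamma, const float* beta,
                              float* D, int M, int N, int K, float epsilon = 1e-5f)
{
    std::vector<float> acc(N);
    for(int m = 0; m < M; ++m)
    {
        // GEMM row plus broadcast bias
        for(int n = 0; n < N; ++n)
        {
            float sum = 0.f;
            for(int k = 0; k < K; ++k)
                sum += A[m * K + k] * B[k * N + n];
            acc[n] = sum + bias[n];
        }
        // row-wise mean and mean of squares (one-pass variance, as in the kernel)
        float mean = 0.f, mean_sq = 0.f;
        for(int n = 0; n < N; ++n)
        {
            mean += acc[n] / N;
            mean_sq += acc[n] * acc[n] / N;
        }
        const float variance = mean_sq - mean * mean;
        // normalize, then scale by gamma and shift by beta
        for(int n = 0; n < N; ++n)
            D[m * N + n] = (acc[n] - mean) / std::sqrt(variance + epsilon) * gamma[n] + beta[n];
    }
}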
...@@ -7,7 +7,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
...@@ -25,6 +25,7 @@ template <typename ALayout,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
...@@ -58,11 +59,14 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmLayerNorm_Xdl_CShuffle
: public BaseOperator
{
using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
...@@ -332,16 +336,69 @@ struct DeviceGemm_Xdl_CShuffle
}
}
// assuming packed tensor
static auto MakeGridDescriptor_M(index_t MRaw)
{
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return grid_desc_mraw;
}
}
static auto MakeGridDescriptor_N(index_t NRaw)
{
const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N
return transform_tensor_descriptor(grid_desc_nraw,
make_tuple(make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad N
return grid_desc_nraw;
}
}
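// Example of the padding logic above (illustrative numbers, not from the commit): with
// NRaw = 120 and NPerBlock = 64, N = 128 and NPad = 8, so the length-120 vector is
// right-padded to 128; without an N-padding GemmSpec the raw descriptor is returned unchanged.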
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_N = decltype(MakeGridDescriptor_N(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
ReduceAccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
...@@ -349,6 +406,7 @@ struct DeviceGemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
C0GridDesc_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
...@@ -380,6 +438,9 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
// Argument
...@@ -388,6 +449,9 @@ struct DeviceGemm_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const CShuffleDataType* p_c0_bias,
const CShuffleDataType* p_c0_gamma,
const CShuffleDataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
...@@ -400,10 +464,15 @@ struct DeviceGemm_Xdl_CShuffle
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_bias_{p_c0_bias},
p_c0_gamma_{p_c0_gamma},
p_c0_beta_{p_c0_beta},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_nblock_nperblock_{},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
...@@ -416,6 +485,9 @@ struct DeviceGemm_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
c0_grid_desc_nblock_nperblock_ = GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
// TODO ANT: adopt tensile style workgroup mapping
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
...@@ -424,11 +496,17 @@ struct DeviceGemm_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const CShuffleDataType* p_c0_bias_;
const CShuffleDataType* p_c0_gamma_;
const CShuffleDataType* p_c0_beta_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_N c0_grid_desc_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock
c0_grid_desc_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
...@@ -474,16 +552,18 @@ struct DeviceGemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
CShuffleDataType, // intermediate data type
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
...@@ -496,28 +576,35 @@ struct DeviceGemm_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_bias_,
arg.p_c0_gamma_,
arg.p_c0_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
CShuffleDataType, // intermediate data type
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
...@@ -527,12 +614,16 @@ struct DeviceGemm_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_bias_,
arg.p_c0_gamma_,
arg.p_c0_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_nblock_nperblock_,
arg.block_2_ctile_map_);
}
...@@ -568,6 +659,9 @@ struct DeviceGemm_Xdl_CShuffle
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const CShuffleDataType* p_c0_bias,
const CShuffleDataType* p_c0_gamma,
const CShuffleDataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
...@@ -581,6 +675,9 @@ struct DeviceGemm_Xdl_CShuffle
return Argument{p_a,
p_b,
p_c,
p_c0_bias,
p_c0_gamma,
p_c0_beta,
MRaw,
NRaw,
KRaw,
...@@ -594,10 +691,12 @@ struct DeviceGemm_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0_bias,
const void* p_c0_gamma,
const void* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
...@@ -607,11 +706,14 @@ struct DeviceGemm_Xdl_CShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const CShuffleDataType*>(p_c0_bias),
static_cast<const CShuffleDataType*>(p_c0_gamma),
static_cast<const CShuffleDataType*>(p_c0_beta),
MRaw,
NRaw,
KRaw,
...@@ -623,8 +725,7 @@ struct DeviceGemm_Xdl_CShuffle
c_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
...@@ -635,7 +736,7 @@ struct DeviceGemm_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmLayerNorm_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
...
...@@ -8,27 +8,36 @@
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
#include "reduction_functions_blockwise.hpp"
#include "element_wise_operation.hpp"
namespace ck {
// D = Layernorm(A * B + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
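// Concretely, with acc(m, n) = (A * B)(m, n) + bias(n) and the reduction taken over the
// N dimension of each row m:
//   mean(m)   = (1/N) * sum_n acc(m, n)
//   meansq(m) = (1/N) * sum_n acc(m, n)^2
//   var(m)    = meansq(m) - mean(m)^2
//   D(m, n)   = (acc(m, n) - mean(m)) / sqrt(var(m) + epsilon) * gamma(n) + beta(n)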
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatCShuffle,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_NBlock_NPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, // MxN
const FloatCShuffle* __restrict__ p_c0_bias_grid, // 1xN
const FloatCShuffle* __restrict__ p_c0_gamma_grid, // 1xN
const FloatCShuffle* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
...@@ -36,14 +45,20 @@ __global__ void
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock
c0_grid_desc_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// TODO ANT: separate into MMA + Epilogue
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_bias_grid,
p_c0_gamma_grid,
p_c0_beta_grid,
p_shared,
a_element_op,
b_element_op,
...@@ -51,17 +66,24 @@ __global__ void
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_nblock_nperblock,
block_2_ctile_map);
// TODO ANT: Run layernorm epilogue here
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_c0_bias_grid;
ignore = p_c0_gamma_grid;
ignore = p_c0_beta_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_nblock_nperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...@@ -70,6 +92,7 @@ template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatReduceAcc, // Data type after shuffle
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
...@@ -77,6 +100,7 @@ template <typename FloatAB,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename C0GridDesc_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
...@@ -108,8 +132,11 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
...@@ -155,9 +182,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{}, // 1 * MWave * 32
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{})); // 1 * NWave * 32
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
...@@ -209,6 +236,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
return false;
// in order to reduce the N dim without elaborate sync across CUs in a single kernel, one
// workgroup must span the entire N extent
if(math::integer_divide_ceil(N, NPerBlock) > 1)
{
return false;
}
// static check: all waves in the workgroup combined must cover the whole N extent in order
// to have an efficient N-dim reduction
static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, "condition not met for efficient layernorm");
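// Practical consequence of the two checks above: only problems with N <= NPerBlock are
// reported as supported, so a single workgroup owns an entire output row and the row-wise
// mean/variance can be formed with one in-block reduction (no cross-CU communication).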
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
...@@ -258,6 +296,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// for broadcasting bias, beta, gamma
// __host__ __device__ static constexpr auto
// MakeC0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock)
// {
// const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0);
// const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
// c0_grid_desc_nblock_nperblock,
// make_tuple(make_insert_transform(I1),
// make_insert_transform(I1),
// make_pass_through_transform(NBlock),
// make_pass_through_transform(NPerBlock)),
// make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// return c0_grid_desc_mblock_mperblock_nblock_nperblock;
// }
// for bias, beta, gamma
__host__ __device__ static constexpr auto
MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n)
{
const auto N = c0_grid_desc_n.GetLength(I0);
const auto NBlock = N / NPerBlock;
const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor(
c0_grid_desc_n,
make_tuple(make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return c0_grid_desc_nblock_nperblock;
}
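// The unmerge above views the length-N bias/gamma/beta vector as an (NBlock, NPerBlock) grid;
// with the N <= NPerBlock restriction enforced by the validity check above, NBlock is
// effectively 1 and each workgroup reads the whole vector.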
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
...@@ -301,6 +373,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using C0GridDescriptor_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
...@@ -308,6 +383,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatCShuffle* __restrict__ p_c0_bias_grid, // 1xN
const FloatCShuffle* __restrict__ p_c0_gamma_grid, // 1xN
const FloatCShuffle* __restrict__ p_c0_beta_grid, // 1xN
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
...@@ -316,6 +394,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock&
c0_grid_desc_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -324,6 +404,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_bias_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
auto c0_gamma_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
// if (hipThreadIdx_x == 0 && hipBlockIdx_x == 0) c_grid_desc_mblock_mperblock_nblock_nperblock.Print();
/*
{TensorDescriptor,
transforms: {Embed, up_lengths_ {MultiIndex, size 2,256 128 }coefficients_ {MultiIndex, size 2,128 1 }}LowerDimensionIds:{size 1, 0 }UpperDimensionIds:{size 2, 1 2 }
transforms: {UnMerge, up_lengths_{MultiIndex, size 2,1 256 }up_lengths_scan_{MultiIndex, size 2,256 1 }}LowerDimensionIds:{size 1, 1 }UpperDimensionIds:{size 2, 3 4 }
transforms: {UnMerge, up_lengths_{MultiIndex, size 2,1 128 }up_lengths_scan_{MultiIndex, size 2,128 1 }}LowerDimensionIds:{size 1, 2 }UpperDimensionIds:{size 2, 5 6 }
}
{size 4, 3 4 5 6 }
*/
// divide block work by [M, N]
const auto block_work_idx =
...@@ -609,6 +705,164 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// add bias: load bias to vgpr buffer, add to LDS
const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0);
// for broadcasting bias, beta, gamma
const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c0_grid_desc_nblock_nperblock,
make_tuple(make_insert_transform(I1),
make_insert_transform(I1),
make_pass_through_transform(NBlock),
make_pass_through_transform(NPerBlock)),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c0_grid_desc_mperblock_nperblock = transform_tensor_descriptor(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c0_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c0_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
BlockSize,
"wrong!");
static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
0,
"wrong!");
constexpr index_t mreduce_per_thread =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);
constexpr index_t nreduce_per_thread =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);
constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
Sequence<mreduce_per_thread, nreduce_per_thread>{};
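// The BlockSize threads are arranged as a CReduceThreadClusterLengths_MPerBlock_NPerBlock
// grid over the shuffled C tile, so each thread owns an mreduce_per_thread x
// nreduce_per_thread sub-tile. Illustrative numbers (not a tuned configuration): a 128x128
// shuffle tile with 256 threads arranged as 16x16 gives an 8x8 sub-tile per thread.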
// pytorch default
// https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
static constexpr FloatReduceAcc epsilon = 1e-5;
// VGPR c_reduce_thread_desc_mperblock_nperblock
constexpr auto c_reduce_thread_desc_mperblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// TODO ANT: incorporate into the singly defined p_shared. calculate the proper total size in
// GetSharedMemoryNumberOfByte() and shift the pointer as appropriate
__shared__ FloatReduceAcc p_d_reduce_work_buffer[BlockSize];
auto d_reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_d_reduce_work_buffer, BlockSize);
// Sum thread workspace
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// Squared sum thread workspace
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
const auto c_reduce_thread_cluster_idx =
c_reduce_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto c_reduce_thread_data_idx_begin =
c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatReduceAcc,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
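// This transfer reads each thread's sub-tile of the shuffled C block out of LDS into VGPRs,
// converting FloatCShuffle to FloatReduceAcc so the reduction below runs at accumulator
// precision.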
auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatCShuffle,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_block_desc_mperblock_nperblock),
tensor_operation::element_wise::PassThrough,
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin, tensor_operation::element_wise::PassThrough{}};
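// The reverse transfer: after normalization the FloatReduceAcc results are written back to
// the same LDS locations as FloatCShuffle, so the existing CShuffle LDS-to-global copy can
// store the layernorm output without further changes.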
auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatReduceAcc,
decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(
I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
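// A single global-to-VGPR transfer object is reused for bias, gamma and beta below, since
// all three are 1xN vectors that share c0_grid_desc_mblock_mperblock_nblock_nperblock.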
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...@@ -637,6 +891,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
// __syncthreads();
block_sync_lds();
// each thread writes its data from VGPR to LDS
...@@ -647,8 +902,175 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_shuffle_block_buf);
// make sure it's safe to read from LDS
// debug::print_shared(c_shuffle_block_buf.p_data_, c_shuffle_block_buf.element_space_size_);
// __syncthreads();
block_sync_lds();
// layernorm
{
// load from LDS and global, add bias
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_bias_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
// auto thread_slice_desc = make_cluster_descriptor(
// Sequence<mreduce_per_thread, nreduce_per_thread>{});
// auto thread_slice_idx = thread_slice_desc.CalculateBottomIndex(make_multi_index(i));
// printf("tid %zd, access_id %d, im, in %d, %d, c0 = %f, c = %f\n",
// hipThreadIdx_x,
// access_id.value,
// thread_slice_idx[I0],
// thread_slice_idx[I1],
// c0_thread_buf(i),
// c_reduce_thread_buf(i));
c_reduce_thread_buf(i) += c0_thread_buf(i);
});
// static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
// static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
// constexpr auto offset =
// Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
// make_tuple(im, in))>{};
// c_reduce_thread_buf(offset) += c0_thread_buf(offset);
// // printf("tid %zd, access_id %d, im, in %d, %d, c0 = %f, c+c0 = %f\n",
// // hipThreadIdx_x,
// // access_id.value,
// // im.value,
// // in.value,
// // c0_thread_buf(offset),
// // c_reduce_thread_buf(offset));
// });
// });
using ThreadwiseReduceD0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::Add<FloatReduceAcc>,
false>;
using ThreadwiseReduceD1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::SquaredAdd<FloatReduceAcc>,
false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetReductionZeroVal();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d1_thread_buf(i) = d1_zeroVal; });
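// Two-stage reduction: first each thread reduces its own nreduce_per_thread columns into
// per-row partial sums (d0, via reduce::Add) and partial sums of squares (d1, via
// reduce::SquaredAdd); then PartitionedBlockwiseReduction combines the partials of all
// threads covering the same rows through the LDS workspace d_reduce_work_buf.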
// reduce sum in VGPR
ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf);
// reduce squared sum in VGPR
ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// reduce across the workgroup
using BlockwiseReduce = PartitionedBlockwiseReduction<FloatReduceAcc,
BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
Sequence<1, 0>, // ThreadClusterArrangeOrder
reduce::Add<FloatReduceAcc>,
false>;
static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf, d0_thread_buf(i)); // blockwise reduced sum
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf, d1_thread_buf(i)); // blockwise reduced squared sum
// printf("tid %zd, access_id %d, mreduce_idx %d, sum = %f, sq sum = %f\n",
// hipThreadIdx_x,
// access_id.value,
// i.value,
// d0_thread_buf(i),
// d1_thread_buf(i));
});
// normalize
const index_t NRaw = c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0].GetUpperLengths()[I1]; // TODO: handle this properly
// if(hipThreadIdx_x == 0) printf("NRaw = %d\n", NRaw);
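// Per-element normalization using the one-pass variance formula:
//   mean = d0 / NRaw, E[x^2] = d1 / NRaw, var = E[x^2] - mean^2
//   y = (x - mean) / sqrt(var + epsilon)
// gamma scaling and beta shift are applied right after this loop.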
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto dst_offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
constexpr auto src_offset =
Number<d_reduce_thread_desc_mperblock.CalculateOffset(
make_tuple(im))>{};
FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
FloatReduceAcc denom = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, FloatReduceAcc>{}(divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt;
// printf("tid %zd, access_id %d, reduce_idx %d %d, avg_sum = %f, avg sq sum = %f, final = %f\n",
// hipThreadIdx_x,
// access_id.value,
// im.value,
// in.value,
// avg_sum,
// avg_squared_sum,
// c_reduce_thread_buf(dst_offset));
});
});
// scaling
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_gamma_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
c_reduce_thread_buf(i) *= c0_thread_buf(i); // * gamma
});
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_beta_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta
});
// __syncthreads();
block_sync_lds();
c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf,
c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf);
} // end layernorm
// __syncthreads();
block_sync_lds();
// debug::print_shared<32>(c_shuffle_block_buf.p_data_, c_shuffle_block_buf.element_space_size_);
// each block copies its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
...@@ -663,6 +1085,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C0 bias
c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
...