Commit 2d91fd12 authored by Anthony Chang's avatar Anthony Chang
Browse files

initial layernorm implementation

parent 57b7dca7
......@@ -7,7 +7,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
......@@ -25,6 +25,7 @@ template <typename ALayout,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
......@@ -58,11 +59,14 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemm_Xdl_CShuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
struct DeviceGemmLayerNorm_Xdl_CShuffle
: public BaseOperator
{
using DeviceOp = DeviceGemm_Xdl_CShuffle;
using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -332,16 +336,69 @@ struct DeviceGemm_Xdl_CShuffle
}
}
// assuming packed tensor
static auto MakeGridDescriptor_M(index_t MRaw)
{
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return grid_desc_mraw;
}
}
static auto MakeGridDescriptor_N(index_t NRaw)
{
const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw));
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad N
return transform_tensor_descriptor(grid_desc_nraw,
make_tuple(make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad N
return grid_desc_nraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_N = decltype(MakeGridDescriptor_N(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
ReduceAccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -349,6 +406,7 @@ struct DeviceGemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
C0GridDesc_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -380,6 +438,9 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
// Argument
......@@ -388,6 +449,9 @@ struct DeviceGemm_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const CShuffleDataType* p_c0_bias,
const CShuffleDataType* p_c0_gamma,
const CShuffleDataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -400,10 +464,15 @@ struct DeviceGemm_Xdl_CShuffle
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_bias_{p_c0_bias},
p_c0_gamma_{p_c0_gamma},
p_c0_beta_{p_c0_beta},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_nblock_nperblock_{},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
......@@ -416,6 +485,9 @@ struct DeviceGemm_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
c0_grid_desc_nblock_nperblock_ = GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
// TODO ANT: adopt tensile style workgroup mapping
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
......@@ -424,11 +496,17 @@ struct DeviceGemm_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const CShuffleDataType* p_c0_bias_;
const CShuffleDataType* p_c0_gamma_;
const CShuffleDataType* p_c0_beta_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_N c0_grid_desc_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock
c0_grid_desc_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -474,16 +552,18 @@ struct DeviceGemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
CShuffleDataType, // intermediate data type
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
......@@ -496,28 +576,35 @@ struct DeviceGemm_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_bias_,
arg.p_c0_gamma_,
arg.p_c0_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_nblock_nperblock_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
CShuffleDataType, // intermediate data type
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
......@@ -527,12 +614,16 @@ struct DeviceGemm_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_bias_,
arg.p_c0_gamma_,
arg.p_c0_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_nblock_nperblock_,
arg.block_2_ctile_map_);
}
......@@ -568,6 +659,9 @@ struct DeviceGemm_Xdl_CShuffle
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const CShuffleDataType* p_c0_bias,
const CShuffleDataType* p_c0_gamma,
const CShuffleDataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -581,6 +675,9 @@ struct DeviceGemm_Xdl_CShuffle
return Argument{p_a,
p_b,
p_c,
p_c0_bias,
p_c0_gamma,
p_c0_beta,
MRaw,
NRaw,
KRaw,
......@@ -594,10 +691,12 @@ struct DeviceGemm_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0_bias,
const void* p_c0_gamma,
const void* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -607,11 +706,14 @@ struct DeviceGemm_Xdl_CShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
index_t /* KBatch */ = 1)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const CShuffleDataType*>(p_c0_bias),
static_cast<const CShuffleDataType*>(p_c0_gamma),
static_cast<const CShuffleDataType*>(p_c0_beta),
MRaw,
NRaw,
KRaw,
......@@ -623,8 +725,7 @@ struct DeviceGemm_Xdl_CShuffle
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
......@@ -635,7 +736,7 @@ struct DeviceGemm_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGemm_Xdl_CShuffle"
str << "DeviceGemmLayerNorm_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment