"docker/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "06a40afedc3e7ba338e8d55dfb22a913db6fb0dd"
Commit d62cd096 authored by rocking's avatar rocking
Browse files

Add base class for gemm layernorm

parent 8572bc7c
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : H[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// H = layernorm(E)
// Assume:
// D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E)
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename HLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename HElementwiseOperation>
struct DeviceGemmMultipleDLayernorm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_h,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideH,
double epsilon,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -10,13 +10,13 @@ ...@@ -10,13 +10,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
namespace ck { namespace ck {
...@@ -236,7 +236,21 @@ template <typename ALayout, ...@@ -236,7 +236,21 @@ template <typename ALayout,
index_t LayernormBetaSrcVectorSize, index_t LayernormBetaSrcVectorSize,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
: public DeviceGemmMultipleDLayernorm<ALayout,
BLayout,
DsLayout,
HLayout,
ADataType,
BDataType,
DsDataType,
GammaDataType,
BetaDataType,
HDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
HElementwiseOperation>
{ {
using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle; using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle;
using ELayout = HLayout; using ELayout = HLayout;
...@@ -464,7 +478,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -464,7 +478,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideH, index_t StrideH,
AccDataType epsilon, double epsilon,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
...@@ -505,7 +519,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -505,7 +519,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
NRaw_{NRaw}, NRaw_{NRaw},
KRaw_{KRaw}, KRaw_{KRaw},
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{static_cast<AccDataType>(epsilon)}
{ {
gemm_mean_var_grid_desc_m_nblock_ = gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>( DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
...@@ -931,7 +945,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -931,7 +945,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideH, index_t StrideH,
AccDataType epsilon, double epsilon,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
...@@ -973,11 +987,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -973,11 +987,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideH, index_t StrideH,
AccDataType epsilon, double epsilon,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
HElementwiseOperation h_element_op) HElementwiseOperation h_element_op) override
{ {
return std::make_unique<Argument>(p_a, return std::make_unique<Argument>(p_a,
p_b, p_b,
...@@ -1000,7 +1014,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -1000,7 +1014,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
} }
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment