Commit ddc76a8b authored by Anthony Chang
Browse files

AccElemOp for gemm outputs prior to feeding to layernorm

parent 12db5b6d
...@@ -36,20 +36,29 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; ...@@ -36,20 +36,29 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
// Element-wise rectifier functor: writes x to y when x is strictly positive,
// and writes 0 to y otherwise.
struct Relu
{
    // y: destination element, overwritten on every call.
    // x: source element; compared against 0 and never modified.
    template <typename OutT, typename InT>
    __host__ __device__ void operator()(OutT& y, const InT& x) const
    {
        if(x > 0)
        {
            // Strictly positive input is forwarded unchanged.
            y = x;
        }
        else
        {
            // Everything else (zero, negative, or a failed comparison) yields 0.
            y = 0;
        }
    }
};
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using AccElementOp = Relu;
// using AccElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = Relu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType, using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
...@@ -59,6 +68,7 @@ using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADa ...@@ -59,6 +68,7 @@ using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADa
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
AccElementOp,
CElementOp>; CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -176,6 +186,7 @@ int main(int argc, char* argv[]) ...@@ -176,6 +186,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto acc_element_op = AccElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM // do GEMM
...@@ -195,6 +206,7 @@ int main(int argc, char* argv[]) ...@@ -195,6 +206,7 @@ int main(int argc, char* argv[])
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
c_element_op); c_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
...@@ -235,6 +247,7 @@ int main(int argc, char* argv[]) ...@@ -235,6 +247,7 @@ int main(int argc, char* argv[])
c_m_n_host_result, c_m_n_host_result,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
c_element_op); c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -33,6 +33,7 @@ template <typename ALayout, ...@@ -33,6 +33,7 @@ template <typename ALayout,
typename ReduceAccDataType, typename ReduceAccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
...@@ -380,6 +381,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -380,6 +381,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
ReduceAccDataType, ReduceAccDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
...@@ -439,6 +441,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -439,6 +441,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
...@@ -455,6 +458,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -455,6 +458,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
block_2_ctile_map_{}, block_2_ctile_map_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
...@@ -489,6 +493,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -489,6 +493,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
}; };
...@@ -538,6 +543,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -538,6 +543,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
C0DataType, C0DataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
...@@ -560,6 +566,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -560,6 +566,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
arg.p_c0_beta_, arg.p_c0_beta_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.acc_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
...@@ -576,6 +583,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -576,6 +583,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
C0DataType, C0DataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
...@@ -597,6 +605,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -597,6 +605,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
arg.p_c0_beta_, arg.p_c0_beta_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.acc_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
...@@ -648,6 +657,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -648,6 +657,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, return Argument{p_a,
...@@ -664,6 +674,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -664,6 +674,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
c_element_op}; c_element_op};
} }
...@@ -683,6 +694,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -683,6 +694,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) index_t /* KBatch */ = 1)
{ {
...@@ -700,6 +712,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -700,6 +712,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
c_element_op); c_element_op);
} }
......
...@@ -21,6 +21,7 @@ template <typename GridwiseGemm, ...@@ -21,6 +21,7 @@ template <typename GridwiseGemm,
typename FloatC0, typename FloatC0,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
...@@ -41,6 +42,7 @@ __global__ void ...@@ -41,6 +42,7 @@ __global__ void
const FloatC0* __restrict__ p_c0_beta_grid, // 1xN const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
...@@ -62,6 +64,7 @@ __global__ void ...@@ -62,6 +64,7 @@ __global__ void
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -79,6 +82,7 @@ __global__ void ...@@ -79,6 +82,7 @@ __global__ void
ignore = p_c0_beta_grid; ignore = p_c0_beta_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = acc_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
...@@ -99,6 +103,7 @@ template <typename FloatAB, ...@@ -99,6 +103,7 @@ template <typename FloatAB,
typename FloatReduceAcc, // Data type after shuffle typename FloatReduceAcc, // Data type after shuffle
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
...@@ -377,6 +382,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -377,6 +382,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const AccElementwiseOperation& acc_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
...@@ -630,7 +636,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -630,7 +636,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatCShuffle, FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough, AccElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle, Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
I1, I1,
...@@ -654,7 +660,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -654,7 +660,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
m_thread_data_on_block_idx[I3], m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4], m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]), n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}}; acc_element_op};
// shuffle: blockwise copy C from LDS to global // shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
......
...@@ -16,6 +16,7 @@ template <typename ADataType, ...@@ -16,6 +16,7 @@ template <typename ADataType,
typename AccDataType, typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator struct ReferenceGemmLayernorm : public device::BaseOperator
{ {
...@@ -25,7 +26,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -25,7 +26,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
AccDataType, AccDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; AccElementwiseOperation>;
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta) // D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
template <typename InDataType, typename OutDataType, typename ComputeDataType> template <typename InDataType, typename OutDataType, typename ComputeDataType>
...@@ -95,6 +96,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -95,6 +96,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor<CDataType>& c_m_n, Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5) const CDataType epsilon = 1e-5)
: a_m_k_{a_m_k}, : a_m_k_{a_m_k},
...@@ -105,6 +107,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -105,6 +107,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
c_m_n_{c_m_n}, c_m_n_{c_m_n},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
epsilon_{epsilon} epsilon_{epsilon}
{ {
...@@ -119,7 +122,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -119,7 +122,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
const CDataType epsilon_; const CDataType epsilon_;
}; };
...@@ -140,11 +145,16 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -140,11 +145,16 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
acc_m_n, acc_m_n,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_); arg.acc_element_op_);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_); RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_);
arg.c_m_n_.ForEach([&](auto& self, auto idx){
arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
});
return 0; return 0;
} }
...@@ -171,6 +181,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -171,6 +181,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor<CDataType>& c_m_n, Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5) const CDataType epsilon = 1e-5)
{ {
...@@ -182,6 +193,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -182,6 +193,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
c_m_n, c_m_n,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
c_element_op, c_element_op,
epsilon}; epsilon};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment