Commit 12db5b6d authored by Anthony Chang

c0 bias/beta/gamma now have their own precision type

parent d08aa99e
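This commit threads a dedicated C0DataType (FloatC0 at the kernel level) through the example, the device op, the gridwise kernel, and the host reference, so the bias/gamma/beta tensors no longer have to share CDataType. A minimal sketch of a mixed-precision configuration this enables, assuming ck::half_t for F16 (the header path may vary by CK version):

```cpp
#include <ck/utility/data_type.hpp> // assumed header path for ck::half_t

using F16 = ck::half_t;
using F32 = float;

using ADataType   = F16;
using BDataType   = F16;
using CDataType   = F16; // GEMM / layernorm output precision
using C0DataType  = F32; // bias, gamma, beta may now differ from CDataType
using AccDataType = F32;
```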
@@ -29,6 +29,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using ADataType = F16;
 using BDataType = F16;
 using CDataType = F16;
+using C0DataType = F16;
 using AccDataType = F32;
 using ALayout = ck::tensor_layout::gemm::RowMajor;
@@ -38,21 +39,23 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
 using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+// using AccElementOp = ck::tensor_operation::element_wise::PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle
-//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
-//######| | | | Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
-//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < Row, Col, Row, F16, F16, F16, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
+//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
 // clang-format on
 using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
 BDataType,
 CDataType,
+C0DataType,
 AccDataType,
 AElementOp,
 BElementOp,
@@ -125,9 +128,9 @@ int main(int argc, char* argv[])
 Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
 Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
 Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-Tensor<CDataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
-Tensor<CDataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
-Tensor<CDataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+Tensor<C0DataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+Tensor<C0DataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+Tensor<C0DataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
 std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
 std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
@@ -152,19 +155,18 @@ int main(int argc, char* argv[])
 b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
 }
-// TODO ANT: test other init
-c0_n_bias.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
-c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<CDataType>{0, 2});
-c0_n_beta.GenerateTensorValue(GeneratorTensor_2<CDataType>{0, 5});
+c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
+c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 2});
+c0_n_beta.GenerateTensorValue(GeneratorTensor_2<C0DataType>{0, 5});
 c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
 acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
 DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
 DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
 DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-DeviceMem c0_bias_buf(sizeof(CDataType) * c0_n_bias.mDesc.GetElementSpace());
-DeviceMem c0_gamma_buf(sizeof(CDataType) * c0_n_gamma.mDesc.GetElementSpace());
-DeviceMem c0_beta_buf(sizeof(CDataType) * c0_n_beta.mDesc.GetElementSpace());
+DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace());
+DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace());
+DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace());
 a_device_buf.ToDevice(a_m_k.mData.data());
 b_device_buf.ToDevice(b_k_n.mData.data());
@@ -182,9 +184,9 @@ int main(int argc, char* argv[])
 auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
 static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
 static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-static_cast<CDataType*>(c0_bias_buf.GetDeviceBuffer()),
-static_cast<CDataType*>(c0_gamma_buf.GetDeviceBuffer()),
-static_cast<CDataType*>(c0_beta_buf.GetDeviceBuffer()),
+static_cast<C0DataType*>(c0_bias_buf.GetDeviceBuffer()),
+static_cast<C0DataType*>(c0_gamma_buf.GetDeviceBuffer()),
+static_cast<C0DataType*>(c0_beta_buf.GetDeviceBuffer()),
 M,
 N,
 K,
...
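The example above exercises a GEMM fused with layernorm: the 1xN c0 bias is applied to the GEMM output, each row is normalized, and the 1xN gamma/beta apply a per-column affine. Assuming that conventional ordering, which the c0 bias/gamma/beta naming implies but the diff does not spell out (nor does it show the epsilon value):

$$x_{mn} = \sum_{k} A_{mk} B_{kn} + \mathrm{bias}_n, \qquad \mu_m = \frac{1}{N}\sum_{n} x_{mn}, \qquad \sigma_m^2 = \frac{1}{N}\sum_{n}\left(x_{mn} - \mu_m\right)^2$$

$$C_{mn} = \gamma_n \, \frac{x_{mn} - \mu_m}{\sqrt{\sigma_m^2 + \varepsilon}} + \beta_n$$

With this commit, bias, gamma, and beta are held in C0DataType while C stays in CDataType.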
@@ -27,6 +27,7 @@ template <typename ALayout,
 typename ADataType,
 typename BDataType,
 typename CDataType,
+typename C0DataType,
 typename GemmAccDataType,
 typename CShuffleDataType,
 typename ReduceAccDataType,
@@ -375,6 +376,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 GemmAccDataType,
 CShuffleDataType,
 CDataType,
+C0DataType,
 ReduceAccDataType,
 AElementwiseOperation,
 BElementwiseOperation,
@@ -426,9 +428,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 Argument(const ADataType* p_a_grid,
 const BDataType* p_b_grid,
 CDataType* p_c_grid,
-const CDataType* p_c0_bias,
-const CDataType* p_c0_gamma,
-const CDataType* p_c0_beta,
+const C0DataType* p_c0_bias,
+const C0DataType* p_c0_gamma,
+const C0DataType* p_c0_beta,
 index_t MRaw,
 index_t NRaw,
 index_t KRaw,
@@ -474,9 +476,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 const ADataType* p_a_grid_;
 const BDataType* p_b_grid_;
 CDataType* p_c_grid_;
-const CDataType* p_c0_bias_;
-const CDataType* p_c0_gamma_;
-const CDataType* p_c0_beta_;
+const C0DataType* p_c0_bias_;
+const C0DataType* p_c0_gamma_;
+const C0DataType* p_c0_beta_;
 AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
 CGridDesc_M_N c_grid_desc_m_n_;
@@ -533,6 +535,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 GridwiseGemm,
 ADataType, // TODO: distiguish A/B datatype
 CDataType,
+C0DataType,
 AElementwiseOperation,
 BElementwiseOperation,
 CElementwiseOperation,
@@ -570,6 +573,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 GridwiseGemm,
 ADataType, // TODO: distiguish A/B datatype
 CDataType,
+C0DataType,
 AElementwiseOperation,
 BElementwiseOperation,
 CElementwiseOperation,
@@ -685,9 +689,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
 static_cast<const BDataType*>(p_b),
 static_cast<CDataType*>(p_c),
-static_cast<const CDataType*>(p_c0_bias),
-static_cast<const CDataType*>(p_c0_gamma),
-static_cast<const CDataType*>(p_c0_beta),
+static_cast<const C0DataType*>(p_c0_bias),
+static_cast<const C0DataType*>(p_c0_gamma),
+static_cast<const C0DataType*>(p_c0_beta),
 MRaw,
 NRaw,
 KRaw,
...
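The device op above forwards the C0 pointers to the gridwise kernel in their own precision, so the fused epilogue has to widen bias/gamma/beta to the reduction type before the layernorm arithmetic. A plain-C++ model of that widening step (CK performs it with its type_convert helper on register buffers; these names are illustrative, not the kernel's):

```cpp
// Hedged model: widen a C0 element to the reduction precision, then accumulate.
// The real kernel does this elementwise on VGPR buffers via ck::type_convert.
template <typename FloatReduceAcc, typename FloatC0>
FloatReduceAcc apply_bias(FloatReduceAcc acc, FloatC0 bias)
{
    return acc + static_cast<FloatReduceAcc>(bias);
}
```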
@@ -18,6 +18,7 @@ namespace ck {
 template <typename GridwiseGemm,
 typename FloatAB,
 typename FloatC,
+typename FloatC0,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
 typename CElementwiseOperation,
@@ -35,9 +36,9 @@ __global__ void
 const FloatAB* __restrict__ p_a_grid,
 const FloatAB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid, // MxN
-const FloatC* __restrict__ p_c0_bias_grid, // 1xN
-const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
-const FloatC* __restrict__ p_c0_beta_grid, // 1xN
+const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
+const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
+const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
 const AElementwiseOperation a_element_op,
 const BElementwiseOperation b_element_op,
 const CElementwiseOperation c_element_op,
@@ -94,6 +95,7 @@ template <typename FloatAB,
 typename FloatGemmAcc,
 typename FloatCShuffle,
 typename FloatC,
+typename FloatC0,
 typename FloatReduceAcc, // Data type after shuffle
 typename AElementwiseOperation,
 typename BElementwiseOperation,
@@ -369,9 +371,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 Run(const FloatAB* __restrict__ p_a_grid,
 const FloatAB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid,
-const FloatC* __restrict__ p_c0_bias_grid, // 1xN
-const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
-const FloatC* __restrict__ p_c0_beta_grid, // 1xN
+const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
+const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
+const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
 void* __restrict__ p_shared,
 const AElementwiseOperation& a_element_op,
 const BElementwiseOperation& b_element_op,
@@ -751,7 +753,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
-auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC>(
+auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC0>(
 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
 // TODO ANT: incorporate in singly defined p_shared. calculate proper total size in
@@ -808,8 +810,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 tensor_operation::element_wise::PassThrough{}};
 auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-FloatC,
-FloatC,
+FloatC0,
+FloatC0,
 decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
 decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
 Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
...
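In the last hunk above, the first two template arguments of ThreadwiseTensorSliceTransfer_v2 are the source and destination element types, so the global-to-VGPR copy of bias/gamma/beta now runs conversion-free in C0 precision. A toy stand-in for such a typed copy (the real transfer also handles descriptors, slicing, and vectorization):

```cpp
// Toy model of a typed element copy: with SrcData == DstData == FloatC0 the
// cast is a no-op, matching the FloatC0, FloatC0 arguments above.
template <typename SrcData, typename DstData>
void typed_copy(const SrcData* src, DstData* dst, int n)
{
    for(int i = 0; i < n; ++i)
        dst[i] = static_cast<DstData>(src[i]);
}
```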
@@ -12,6 +12,7 @@ namespace host {
 template <typename ADataType,
 typename BDataType,
 typename CDataType,
+typename C0DataType,
 typename AccDataType,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
@@ -88,9 +89,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
 {
 Argument(const Tensor<ADataType>& a_m_k,
 const Tensor<BDataType>& b_k_n,
-const Tensor<CDataType>& c0_n_bias, // 1xN
-const Tensor<CDataType>& c0_n_gamma, // 1xN
-const Tensor<CDataType>& c0_n_beta, // 1xN
+const Tensor<C0DataType>& c0_n_bias, // 1xN
+const Tensor<C0DataType>& c0_n_gamma, // 1xN
+const Tensor<C0DataType>& c0_n_beta, // 1xN
 Tensor<CDataType>& c_m_n,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
@@ -111,9 +112,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
 const Tensor<ADataType>& a_m_k_;
 const Tensor<BDataType>& b_k_n_;
-const Tensor<CDataType>& c0_n_bias_;
-const Tensor<CDataType>& c0_n_gamma_;
-const Tensor<CDataType>& c0_n_beta_;
+const Tensor<C0DataType>& c0_n_bias_;
+const Tensor<C0DataType>& c0_n_gamma_;
+const Tensor<C0DataType>& c0_n_beta_;
 Tensor<CDataType>& c_m_n_;
 AElementwiseOperation a_element_op_;
@@ -164,9 +165,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
 static auto MakeArgument(const Tensor<ADataType>& a_m_k,
 const Tensor<BDataType>& b_k_n,
-const Tensor<CDataType>& c0_n_bias, // 1xN
-const Tensor<CDataType>& c0_n_gamma, // 1xN
-const Tensor<CDataType>& c0_n_beta, // 1xN
+const Tensor<C0DataType>& c0_n_bias, // 1xN
+const Tensor<C0DataType>& c0_n_gamma, // 1xN
+const Tensor<C0DataType>& c0_n_beta, // 1xN
 Tensor<CDataType>& c_m_n,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
...
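For completeness, a self-contained host sketch of the computation ReferenceGemmLayernorm verifies, with the C0 precision kept distinct as in the templated reference above. The loop structure, flat row-major storage, and epsilon value are assumptions, not the library's actual reference code:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative host reference for layernorm(A*B + bias) * gamma + beta.
// CData/C0Data model CDataType/C0DataType; eps is an assumed constant.
template <typename AData, typename BData, typename CData, typename C0Data>
void reference_gemm_layernorm(const std::vector<AData>& a,      // M x K, row-major
                              const std::vector<BData>& b,      // K x N, row-major
                              const std::vector<C0Data>& bias,  // 1 x N
                              const std::vector<C0Data>& gamma, // 1 x N
                              const std::vector<C0Data>& beta,  // 1 x N
                              std::vector<CData>& c,            // M x N, row-major
                              std::size_t M, std::size_t N, std::size_t K)
{
    const float eps = 1e-5f; // assumed; the diff does not show its value
    for(std::size_t m = 0; m < M; ++m)
    {
        std::vector<float> x(N, 0.f);
        float mean = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            for(std::size_t k = 0; k < K; ++k)
                x[n] += float(a[m * K + k]) * float(b[k * N + n]);
            x[n] += float(bias[n]); // bias applied before the row statistics
            mean += x[n];
        }
        mean /= N;
        float var = 0.f;
        for(std::size_t n = 0; n < N; ++n)
            var += (x[n] - mean) * (x[n] - mean);
        var /= N;
        for(std::size_t n = 0; n < N; ++n) // normalize, then per-column affine
            c[m * N + n] =
                CData(float(gamma[n]) * (x[n] - mean) / std::sqrt(var + eps) + float(beta[n]));
    }
}
```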