Commit e89422a8 authored by rocking

use reference layernorm

parent 180290ba
@@ -16,6 +16,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
 #include "ck/library/utility/check_err.hpp"
 
 template <ck::index_t... Is>
@@ -98,63 +99,45 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
                          CDEElementOp cde_element_op,
                          int M,
                          int N,
-                         float epsilon = 1e-5)
+                         AccDataType epsilon = 1e-5)
 {
-    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
-                                                                            BDataType,
-                                                                            AccDataType,
-                                                                            AccDataType,
-                                                                            AElementOp,
-                                                                            BElementOp,
-                                                                            PassThrough>;
-    using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
+    using ReferenceGemm = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                    BDataType,
+                                                                    AccDataType,
+                                                                    AccDataType,
+                                                                    AElementOp,
+                                                                    BElementOp,
+                                                                    PassThrough>;
+    using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<HDataType,
+                                                                              GammaDataType,
+                                                                              BetaDataType,
+                                                                              HDataType,
+                                                                              AccDataType,
+                                                                              HElementOp,
+                                                                              2,
+                                                                              1>;
 
     Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
 
-    auto ref_gemm    = ReferenceGemmInstance{};
-    auto ref_invoker = ref_gemm.MakeInvoker();
+    auto ref_gemm         = ReferenceGemm{};
+    auto ref_gemm_invoker = ref_gemm.MakeInvoker();
 
-    auto ref_argument =
+    auto ref_gemm_argument =
         ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
 
-    ref_invoker.Run(ref_argument);
+    ref_gemm_invoker.Run(ref_gemm_argument);
 
     for(int m = 0; m < M; ++m)
         for(int n = 0; n < N; ++n)
-        {
             cde_element_op(e_m_n(m, n), c_m_n(m, n), bias_n(n), d1_m_n(m, n));
-        }
 
-    // LayerNorm
-    Tensor<AccDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
-    Tensor<AccDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
-    auto layerNormInst = NormalizeFunctor{epsilon};
-
-    for(int m = 0; m < M; ++m)
-    {
-        AccDataType mean       = 0;
-        AccDataType meanSquare = 0;
-        for(int n = 0; n < N; ++n)
-        {
-            auto e_val = ck::type_convert<AccDataType>(e_m_n(m, n));
-            mean += e_val;
-            meanSquare += e_val * e_val;
-        }
-        mean /= N;
-        meanSquare /= N;
-        for(int n = 0; n < N; ++n)
-        {
-            AccDataType h_val     = 0;
-            AccDataType e_val     = ck::type_convert<AccDataType>(e_m_n(m, n));
-            AccDataType gamma_val = ck::type_convert<AccDataType>(gamma_n(n));
-            AccDataType beta_val  = ck::type_convert<AccDataType>(beta_n(n));
-            layerNormInst(h_val, e_val, mean, meanSquare, gamma_val, beta_val);
-            h_m_n(m, n) = ck::type_convert<HDataType>(h_val);
-        }
-    }
+    ReferenceLayernorm ref_layernorm;
+    auto ref_layernorm_invoker  = ref_layernorm.MakeInvoker();
+    auto ref_layernorm_argument = ref_layernorm.MakeArgument(
+        e_m_n, gamma_n, beta_n, h_m_n, HElementOp{}, {M, N}, {1}, epsilon);
+    ref_layernorm_invoker.Run(ref_layernorm_argument);
 }
 
 int main()
...
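For context, the deleted host loop computed plain per-row layernorm: accumulate the mean and mean-of-squares over the N dimension, then normalize each element with per-column gamma and beta. In the new MakeArgument call, {M, N} appears to be the 2D tensor lengths and {1} the reduced dimension, matching the trailing 2, 1 (rank, number of reduce dimensions) template arguments of ReferenceLayernorm; that reading is inferred from the diff, not stated in it. Below is a minimal standalone sketch (a hypothetical helper, not CK code) of the computation the old loop performed and the reference op is expected to reproduce:

// Hypothetical standalone sketch, not part of this commit: the per-row
// layernorm the deleted host loop computed.
#include <cmath>
#include <vector>

void naive_layernorm_row(const std::vector<float>& e,     // one row of e_m_n
                         const std::vector<float>& gamma, // per-column scale
                         const std::vector<float>& beta,  // per-column shift
                         std::vector<float>& h,           // output row
                         float epsilon = 1e-5f)
{
    const int N      = static_cast<int>(e.size());
    float mean       = 0.f;
    float meanSquare = 0.f;
    for(int n = 0; n < N; ++n)
    {
        mean += e[n];
        meanSquare += e[n] * e[n];
    }
    mean /= N;
    meanSquare /= N;

    // Variance formed as E[x^2] - E[x]^2, the same mean/meanSquare pair the
    // old code handed to ck::tensor_operation::element_wise::Normalize.
    const float rstd = 1.f / std::sqrt(meanSquare - mean * mean + epsilon);

    for(int n = 0; n < N; ++n)
        h[n] = (e[n] - mean) * rstd * gamma[n] + beta[n];
}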