"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "91db978bb14ab3ef818c0b75036465fdb3a2fa0b"
Commit bcdacc1f authored by rocking's avatar rocking
Browse files

1. Declare e inside the host_gemm_layernorm()

2. Prevent implicit cast in reference code
parent 29ad7a36
...@@ -87,8 +87,7 @@ auto f_host_tensor_descriptor2d = ...@@ -87,8 +87,7 @@ auto f_host_tensor_descriptor2d =
} }
}; };
void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n, void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
Tensor<HDataType>& h_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const Tensor<D0DataType>& bias_n, const Tensor<D0DataType>& bias_n,
...@@ -119,6 +118,7 @@ void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n, ...@@ -119,6 +118,7 @@ void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n,
2, 2,
1>; 1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N}); Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemm{}; auto ref_gemm = ReferenceGemm{};
...@@ -129,9 +129,17 @@ void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n, ...@@ -129,9 +129,17 @@ void host_gemm_layernorm(Tensor<EMeanVarDataType>& e_m_n,
ref_gemm_invoker.Run(ref_gemm_argument); ref_gemm_invoker.Run(ref_gemm_argument);
for(int m = 0; m < M; ++m) for(int n = 0; n < N; ++n)
for(int n = 0; n < N; ++n) {
cde_element_op(e_m_n(m, n), c_m_n(m, n), bias_n(n), d1_m_n(m, n)); AccDataType bias = static_cast<AccDataType>(bias_n(n));
for(int m = 0; m < M; ++m)
{
AccDataType e = static_cast<AccDataType>(e_m_n(m, n));
AccDataType d1 = static_cast<AccDataType>(d1_m_n(m, n));
cde_element_op(e, c_m_n(m, n), bias, d1);
e_m_n(m, n) = static_cast<EMeanVarDataType>(e);
}
}
ReferenceLayernorm ref_layernorm; ReferenceLayernorm ref_layernorm;
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
...@@ -230,11 +238,8 @@ int main() ...@@ -230,11 +238,8 @@ int main()
if(do_verification) if(do_verification)
{ {
Tensor<EMeanVarDataType> e_m_n_host(HostTensorDescriptor{M, N});
Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N}); Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
host_gemm_layernorm(h_m_n_host,
host_gemm_layernorm(e_m_n_host,
h_m_n_host,
a_m_k, a_m_k,
b_k_n, b_k_n,
d0_n, d0_n,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment