Commit aa027054 authored by rocking's avatar rocking
Browse files

Use reduceAccDataType instead of explicitly float

parent b3812da1
...@@ -261,14 +261,15 @@ int main(int argc, char* argv[]) ...@@ -261,14 +261,15 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetIdentityValue(); ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue(); ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n)); ReduceAccDataType c_val =
float d0_val = 0; ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
float d1_val = 0; ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
......
...@@ -162,7 +162,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -162,7 +162,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ReduceAccDataType c_val = ck::type_convert<float>(c_m_n(m, n)); ReduceAccDataType c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
ReduceAccDataType square_c_val = 0; ReduceAccDataType square_c_val = 0;
UnarySquareElementOp{}(square_c_val, c_val); UnarySquareElementOp{}(square_c_val, c_val);
......
...@@ -173,6 +173,8 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -173,6 +173,8 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
BElementOp, BElementOp,
CElementOp>; CElementOp>;
using ReduceAccDataType = DDataType;
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
...@@ -184,26 +186,27 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -184,26 +186,27 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float acc = ReduceAccDataType acc = static_cast<ReduceAccDataType>(c_m_n_host_result(m, n)) +
static_cast<float>(c_m_n_host_result(m, n)) + static_cast<float>(bias_n(n)); static_cast<ReduceAccDataType>(bias_n(n));
float c1 = c1_m_n(m, n); ReduceAccDataType c1 = c1_m_n(m, n);
c_element_op(acc, acc); c_element_op(acc, acc);
c1_element_op(c1, c1); c1_element_op(c1, c1);
acc += static_cast<float>(c1); acc += static_cast<ReduceAccDataType>(c1);
c_m_n_host_result(m, n) = static_cast<CDataType>(acc); c_m_n_host_result(m, n) = static_cast<CDataType>(acc);
} }
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetIdentityValue(); ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue(); ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n)); ReduceAccDataType c_val =
float d0_val = 0; ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
float d1_val = 0; ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
......
...@@ -155,6 +155,8 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -155,6 +155,8 @@ bool profile_gemm_reduce_impl(int do_verification,
BElementOp, BElementOp,
CElementOp>; CElementOp>;
using ReduceAccDataType = DDataType;
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
...@@ -165,14 +167,15 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -165,14 +167,15 @@ bool profile_gemm_reduce_impl(int do_verification,
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetIdentityValue(); ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue(); ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n)); ReduceAccDataType c_val =
float d0_val = 0; ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
float d1_val = 0; ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment