Commit 3421b74c authored by Chao Liu's avatar Chao Liu
Browse files

use type_convert

parent 1deb01b2
...@@ -166,15 +166,15 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -166,15 +166,15 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType acc = AccDataType acc = ck::type_convert<AccDataType>(c_m_n(m, n)) +
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n)); ck::type_convert<AccDataType>(bias_n(n));
AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n)); AccDataType c1 = ck::type_convert<AccDataType>(c1_m_n(m, n));
c_element_op(acc, acc); c_element_op(acc, acc);
c1_element_op(c1, c1); c1_element_op(c1, c1);
acc += c1; acc += c1;
c_m_n(m, n) = static_cast<CDataType>(acc); c_m_n(m, n) = ck::type_convert<CDataType>(acc);
} }
// reduce_mean and reduce_square_mean // reduce_mean and reduce_square_mean
...@@ -208,12 +208,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -208,12 +208,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{ {
AccDataType out_acc = 0; AccDataType out_acc = 0;
layerNormInst(out_acc, layerNormInst(out_acc,
static_cast<AccDataType>(c_m_n(m, n)), ck::type_convert<AccDataType>(c_m_n(m, n)),
static_cast<AccDataType>(mean_m(m)), ck::type_convert<AccDataType>(mean_m(m)),
static_cast<AccDataType>(meanSquare_m(m)), ck::type_convert<AccDataType>(meanSquare_m(m)),
static_cast<AccDataType>(gamma_n(n)), ck::type_convert<AccDataType>(gamma_n(n)),
static_cast<AccDataType>(beta_n(n))); ck::type_convert<AccDataType>(beta_n(n)));
out_m_n(m, n) = static_cast<ReduceDataType>(out_acc); out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
} }
} }
} }
......
...@@ -135,9 +135,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -135,9 +135,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType v_b; AccDataType v_b;
arg.a_element_op_( arg.a_element_op_(
v_a, static_cast<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1))); v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_( arg.b_element_op_(
v_b, static_cast<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1))); v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
......
...@@ -134,9 +134,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -134,9 +134,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType v_b; AccDataType v_b;
arg.a_element_op_( arg.a_element_op_(
v_a, static_cast<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1))); v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_( arg.b_element_op_(
v_b, static_cast<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1))); v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment