Commit 3421b74c authored by Chao Liu's avatar Chao Liu
Browse files

use type_convert

parent 1deb01b2
......@@ -166,15 +166,15 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n)
{
AccDataType acc =
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));
AccDataType acc = ck::type_convert<AccDataType>(c_m_n(m, n)) +
ck::type_convert<AccDataType>(bias_n(n));
AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n));
AccDataType c1 = ck::type_convert<AccDataType>(c1_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
c_m_n(m, n) = static_cast<CDataType>(acc);
c_m_n(m, n) = ck::type_convert<CDataType>(acc);
}
// reduce_mean and reduce_square_mean
......@@ -208,12 +208,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{
AccDataType out_acc = 0;
layerNormInst(out_acc,
static_cast<AccDataType>(c_m_n(m, n)),
static_cast<AccDataType>(mean_m(m)),
static_cast<AccDataType>(meanSquare_m(m)),
static_cast<AccDataType>(gamma_n(n)),
static_cast<AccDataType>(beta_n(n)));
out_m_n(m, n) = static_cast<ReduceDataType>(out_acc);
ck::type_convert<AccDataType>(c_m_n(m, n)),
ck::type_convert<AccDataType>(mean_m(m)),
ck::type_convert<AccDataType>(meanSquare_m(m)),
ck::type_convert<AccDataType>(gamma_n(n)),
ck::type_convert<AccDataType>(beta_n(n)));
out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
}
}
}
......
......@@ -135,9 +135,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType v_b;
arg.a_element_op_(
v_a, static_cast<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, static_cast<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
......
......@@ -134,9 +134,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType v_b;
arg.a_element_op_(
v_a, static_cast<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, static_cast<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment