Commit adc64a23 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed reference gemm

parent 9212f569
......@@ -21,7 +21,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeType = ADataType>
typename ComputType = ADataType>
struct ReferenceGemm : public device::BaseOperator
{
// Argument
......@@ -65,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k)
{
ComputeType v_a;
ComputeType v_b;
ComputType v_a;
ComputType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
......@@ -89,7 +89,8 @@ struct ReferenceGemm : public device::BaseOperator
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
v_acc += type_convert<AccDataType>(v_a * v_b);
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment