"tests/pipelines/vscode:/vscode.git/clone" did not exist on "4974b84564d25bd4b5c594db4e04cb885cc0a9ed"
Commit adc64a23 authored by Jing Zhang
Browse files

fixed reference gemm

parent 9212f569
...@@ -21,7 +21,7 @@ template <typename ADataType, ...@@ -21,7 +21,7 @@ template <typename ADataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename ComputeType = ADataType> typename ComputType = ADataType>
struct ReferenceGemm : public device::BaseOperator struct ReferenceGemm : public device::BaseOperator
{ {
// Argument // Argument
...@@ -65,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -65,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
ComputeType v_a; ComputType v_a;
ComputeType v_b; ComputType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
...@@ -89,7 +89,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -89,7 +89,8 @@ struct ReferenceGemm : public device::BaseOperator
arg.b_element_op_(v_b, arg.b_k_n_(k, n)); arg.b_element_op_(v_b, arg.b_k_n_(k, n));
} }
v_acc += type_convert<AccDataType>(v_a * v_b); v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
CDataType v_c; CDataType v_c;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment