Commit 5c89193d authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Use element-wise ops in the reference gemm

parent 0a0cfc4e
...@@ -67,14 +67,13 @@ struct ReferenceMXGemm : public device::BaseOperator ...@@ -67,14 +67,13 @@ struct ReferenceMXGemm : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA, using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA,
ComputeTypeB, ComputeTypeB,
CDataType, CDataType,
AccDataType, AccDataType,
PassThrough, AElementwiseOperation,
PassThrough, BElementwiseOperation,
PassThrough, CElementwiseOperation,
ComputeTypeA, ComputeTypeA,
ComputeTypeB>; ComputeTypeB>;
...@@ -111,9 +110,9 @@ struct ReferenceMXGemm : public device::BaseOperator ...@@ -111,9 +110,9 @@ struct ReferenceMXGemm : public device::BaseOperator
auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled, auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled,
b_k_n_scaled, b_k_n_scaled,
arg.c_m_n_, arg.c_m_n_,
PassThrough{}, arg.a_element_op_,
PassThrough{}, arg.b_element_op_,
PassThrough{}); arg.c_element_op_);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment