Commit 5c89193d authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Use element-wise ops in the reference gemm

parent 0a0cfc4e
......@@ -67,14 +67,13 @@ struct ReferenceMXGemm : public device::BaseOperator
float Run(const Argument& arg)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA,
ComputeTypeB,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputeTypeA,
ComputeTypeB>;
......@@ -111,9 +110,9 @@ struct ReferenceMXGemm : public device::BaseOperator
auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled,
b_k_n_scaled,
arg.c_m_n_,
PassThrough{},
PassThrough{},
PassThrough{});
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ref_invoker.Run(ref_argument);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment