#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP #define REFERENCE_GEMM_BIAS_BIAS_2D_HPP #include #include #include "device_base.hpp" #include "host_tensor.hpp" namespace ck { namespace tensor_operation { namespace host { template struct ReferenceGemmBias2D : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument { Argument(const Tensor& a_m_k, const Tensor& b_k_n, const Tensor& c0_m_n, Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) : a_m_k_{a_m_k}, b_k_n_{b_k_n}, c0_m_n_{c0_m_n}, c_m_n_{c_m_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { } const Tensor& a_m_k_; const Tensor& b_k_n_; const Tensor& c0_m_n_; Tensor& c_m_n_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; }; // Invoker struct Invoker : public device::BaseInvoker { using Argument = ReferenceGemmBias2D::Argument; float Run(const Argument& arg) { auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_m_k_.mDesc.GetLengths()[1]; AccDataType a = 0; AccDataType b = 0; AccDataType acc = 0; for(int k = 0; k < K; ++k) { arg.a_element_op_(a, arg.a_m_k_(m, k)); arg.b_element_op_(b, arg.b_k_n_(k, n)); acc += a * b; } CDataType cast_acc = static_cast(acc); arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n)); }; make_ParallelTensorFunctor( f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( std::thread::hardware_concurrency()); return 0; } float Run(const device::BaseArgument* p_arg, int) override { return Run(*dynamic_cast(p_arg)); } }; static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check return true; } bool IsSupportedArgument(const device::BaseArgument*) override { return true; } static auto MakeArgument(const Tensor& a_m_k, const Tensor& b_k_n, const Tensor& c0_m_n, Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op}; } static auto MakeInvoker() { return Invoker{}; } virtual std::unique_ptr MakeInvokerPointer() { return std::make_unique(Invoker{}); } std::string GetTypeString() const override { auto str = std::stringstream(); // clang-format off str << "ReferenceGemmBias2D" << std::endl; // clang-format on return str.str(); } }; } // namespace host } // namespace tensor_operation } // namespace ck #endif