#pragma once #include "host_tensor.hpp" template void host_gemm_mk_kn_mn(const Tensor& a_m_k, const Tensor& b_k_n, Tensor& c_m_n) { auto f_mk_kn_mn = [&](auto m, auto n) { const int K = a_m_k.mDesc.GetLengths()[1]; double v = 0; for(int k = 0; k < K; ++k) { v += static_cast(a_m_k(m, k)) * static_cast(b_k_n(k, n)); } c_m_n(m, n) = v; }; make_ParallelTensorFunctor(f_mk_kn_mn, c_m_n.mDesc.GetLengths()[0], c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); }