#pragma once

#include <stdexcept>
#include <thread>

#include "host_tensor.hpp"
#include "gemm_common.hpp"

// Reference GEMM on the host: C(m, n) = sum_k A(m, k) * B(k, n), with the
// index order of A and B selected by `layout`. Accumulation is done in
// double regardless of the element types.
template <typename AType, typename BType, typename CType>
void host_gemm(const Tensor<AType>& a,
               const Tensor<BType>& b,
               Tensor<CType>& c,
               const GemmMatrixLayout layout)
{
    if(layout == GemmMatrixLayout::MK_KN_MN)
    {
        // A is M x K, B is K x N
        auto f_mk_kn_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[1];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += static_cast<double>(a(m, k)) * static_cast<double>(b(k, n));
            }

            c(m, n) = v;
        };

        make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::MK_NK_MN)
    {
        // A is M x K, B is N x K (B transposed)
        auto f_mk_nk_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[1];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += static_cast<double>(a(m, k)) * static_cast<double>(b(n, k));
            }

            c(m, n) = v;
        };

        make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::KM_KN_MN)
    {
        // A is K x M (A transposed), B is K x N
        auto f_km_kn_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[0];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += static_cast<double>(a(k, m)) * static_cast<double>(b(k, n));
            }

            c(m, n) = v;
        };

        make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::KM_NK_MN)
    {
        // A is K x M (A transposed), B is N x K (B transposed)
        auto f_km_nk_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[0];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += static_cast<double>(a(k, m)) * static_cast<double>(b(n, k));
            }

            c(m, n) = v;
        };

        make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("host_gemm: unsupported GemmMatrixLayout");
    }
}
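
// A minimal usage sketch, assuming Tensor<T> from host_tensor.hpp is
// constructible from its dimension lengths and exposes element access via
// operator() (the exact Tensor API may differ):
//
//   Tensor<float> a({M, K}); // row-major M x K
//   Tensor<float> b({K, N}); // row-major K x N
//   Tensor<float> c({M, N});
//   // ... fill a and b ...
//   host_gemm(a, b, c, GemmMatrixLayout::MK_KN_MN);
//
// The layout enum picks which operand indices are swapped, so the same
// routine also verifies transposed-A (KM_*) and transposed-B (*_NK) cases.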