"vscode:/vscode.git/clone" did not exist on "ab7680838fdfb997ee96464d21f8ac7057e4e619"
Commit cff83eda authored by Paul
Browse files

Use blaze to compute matrix multiply

parent d1481b13
......@@ -313,6 +313,8 @@ struct reshape
struct gemm
{
float alpha = 1.0;
float beta = 0.0;
std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -3,6 +3,10 @@ add_library(migraph_cpu
cpu_target.cpp
cpu_lowering.cpp
)
# Locate the directory that contains blaze/Blaze.h.
# find_path stores the containing directory, which is what an include path
# needs; find_file would store the full path of the header file itself.
find_path(BLAZE_INCLUDE blaze/Blaze.h)
rocm_clang_tidy_check(migraph_cpu)
target_link_libraries(migraph_cpu migraph)
target_include_directories(migraph_cpu PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraph_cpu PRIVATE ${BLAZE_INCLUDE})
......@@ -5,6 +5,7 @@
#include <migraph/operators.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/iterator_for.hpp>
#include <blaze/Blaze.h>
#include <unordered_map>
namespace migraph {
......@@ -226,37 +227,43 @@ struct cpu_gemm
std::string name() const { return "cpu::gemm"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
auto m = amat.get_shape().lens()[0];
auto n = bmat.get_shape().lens()[1];
auto k = bmat.get_shape().lens()[0];
// Alias for blaze::CustomMatrix over element type T, instantiated with the
// blaze::unaligned and blaze::unpadded options.
template<class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT
auto a = amat.data();
auto b = bmat.data();
auto c = cmat.data();
for(int ii = 0; ii < m; ii++)
{
for(int jj = 0; jj < n; jj++)
// Build a blaze matrix<T> from a tensor_view: the matrix is constructed from
// x.data() and the two entries of the shape's lens().
// Asserts that the shape has exactly two dimensions and that it is packed.
// For a transposed shape the two lens are passed in swapped order.
template<class T>
static auto make_mat(tensor_view<T> x)
{
// NOTE(review): the next statement uses c, ii, n and jj, none of which are
// declared in this function -- it looks like residue from the removed loop
// implementation; confirm it does not belong here.
c[ii * n + jj] = 0;
const auto& s = x.get_shape();
assert(s.lens().size() == 2);
assert(s.packed());
// Transposed: dimensions are given as (lens[1], lens[0]).
if(s.transposed())
return matrix<T>{x.data(), s.lens()[1], s.lens()[0]};
return matrix<T>{x.data(), s.lens()[0], s.lens()[1]};
}
}
for(int ii = 0; ii < m; ii++)
{
for(int kk = 0; kk < k; kk++)
{
auto aik = a[ii * k + kk];
auto* bkj = &b[kk * n];
auto* cij = &c[ii * n];
for(int jj = 0; jj < n; jj++, cij++, bkj++)
template<class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
*cij += aik * (*bkj);
}
}
auto mat = make_mat(x);
if(x.get_shape().transposed())
f(blaze::trans(mat));
else
f(mat);
}
// Compute the gemm of args[0] and args[1] using blaze expressions.
//
// A result argument is created from output_shape, all three arguments are
// visited together, and the product is assigned into the matrix built over the
// result. When op.alpha is 1.0 and op.beta is 0.0 the plain product is
// assigned; otherwise the product is scaled by op.alpha and op.beta times the
// current output matrix is added.
//
// Returns the result argument.
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
    argument result{output_shape};
    visit_all(result, args[0], args[1])([&](auto out_view, auto lhs_view, auto rhs_view) {
        visit_mat(lhs_view, [&](const auto& lhs) {
            visit_mat(rhs_view, [&](const auto& rhs) {
                auto out = make_mat(out_view);
                const bool plain_product = (op.alpha == 1.0) && (op.beta == 0.0);
                if(!plain_product)
                {
                    // NOTE(review): `out` is read here before this function has
                    // written to it; its contents depend on how
                    // `argument{output_shape}` sets up its buffer -- confirm.
                    out = (lhs * rhs) * op.alpha + op.beta * out;
                    return;
                }
                out = lhs * rhs;
            });
        });
    });
    return result;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment