Commit 742e6a4b authored by Scott Thornton
Browse files

Added gemm operator and cpu target

parent dc2b0abf
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
#include <array>
#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
......@@ -218,6 +219,33 @@ struct reshape
}
};
/// Descriptor for a general matrix-matrix multiply (GEMM) operator.
///
/// This struct only validates inputs and derives the output shape; its own
/// compute() always throws, so the actual arithmetic has to be supplied by
/// some other implementation.
struct gemm
{
    /// Operator name, also used when streaming the operator.
    std::string name() const { return "gemm"; }

    // Leading-dimension style parameters. Nothing in this struct reads them.
    std::size_t lda = 1;
    std::size_t ldb = 1;
    std::size_t ldc = 1;

    /// Derives the output shape of A * B.
    ///
    /// @param inputs  the operand shapes; inputs.at(0) is A, inputs.at(1) is B.
    /// @return a shape with A's type and lengths {A.lens()[0], B.lens()[1]}.
    /// @throws via RTG_THROW when A.lens()[1] != B.lens()[0]; check_shapes
    ///         presumably also throws when its constraints are violated.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // A same_dims() constraint is deliberately not applied here: A is
        // (M x K) and B is (K x N), so demanding identical dimensions would
        // reject every non-square product before the inner-dimension check.
        check_shapes{inputs}.has(2).same_type().only_dims(2);
        const shape& A = inputs.at(0);
        const shape& B = inputs.at(1);
        if(A.lens()[1] != B.lens()[0])
            RTG_THROW("Inner dimensions do not match");
        return {A.type(), {A.lens()[0], B.lens()[1]}};
    }

    /// Always throws: this descriptor carries no implementation.
    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }

    /// Streams the operator as "gemm[]" and returns the stream for chaining.
    friend std::ostream& operator<<(std::ostream& os, const gemm& op)
    {
        os << op.name() << "[";
        os << "]";
        // The stream must be returned; flowing off the end of a function
        // with a non-void return type is undefined behaviour.
        return os;
    }
};
} // namespace rtg
#endif
......@@ -47,6 +47,45 @@ struct cpu_convolution
}
};
/// CPU implementation of the gemm operator: computes C = A * B with a plain
/// triple loop.
struct cpu_gemm
{
    gemm op;

    std::string name() const { return "cpu::gemm"; }

    /// Forwards shape inference to the wrapped gemm operator.
    /// Declared const to match gemm::compute_shape and this struct's other members.
    shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }

    /// Multiplies args[0] (A) by args[1] (B) into a newly constructed argument.
    ///
    /// @param output_shape  shape used to construct the result argument.
    /// @param args          operands; args[0] is A and args[1] is B.
    /// @return the argument holding the product.
    ///
    /// NOTE(review): the indexing below treats all three buffers as densely
    /// packed row-major matrices and assumes A has as many columns as B has
    /// rows (gemm::compute_shape checks the latter) — confirm callers never
    /// pass strided shapes.
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
            const std::size_t M = amat.get_shape().lens()[0];
            const std::size_t N = bmat.get_shape().lens()[1];
            const std::size_t K = bmat.get_shape().lens()[0];
            auto a = amat.data();
            auto b = bmat.data();
            auto c = cmat.data();
            for(std::size_t ii = 0; ii < M; ii++)
            {
                auto* c_row = &c[ii * N];
                // Clear the output row before accumulating into it.
                for(std::size_t jj = 0; jj < N; jj++)
                {
                    c_row[jj] = 0;
                }
                // i-k-j ordering: the innermost loop walks a row of B and a
                // row of C with unit stride.
                for(std::size_t kk = 0; kk < K; kk++)
                {
                    auto aik    = a[ii * K + kk];
                    auto* b_row = &b[kk * N];
                    for(std::size_t jj = 0; jj < N; jj++)
                    {
                        c_row[jj] += aik * b_row[jj];
                    }
                }
            }
        });
        // The result must be returned; without this, control flows off the
        // end of a non-void function, which is undefined behaviour.
        return result;
    }
};
struct relu
{
std::string name() const { return "cpu::relu"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment