Commit 238bfadd authored by Paul's avatar Paul
Browse files

Add simple fallback for now

parent 0b5fa390
......@@ -29,6 +29,7 @@ struct tensor_view
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
......@@ -36,6 +37,7 @@ struct tensor_view
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs)
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
......
#include <migraph/cpu/gemm.hpp>
#include <migraph/dfor.hpp>
#include <migraph/requires.hpp>
#include <blaze/math/CustomMatrix.h>
......@@ -50,10 +51,7 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
if(alpha == 1.0 and beta == 0.0)
c = a * b;
else
c = (a * b) * alpha + beta * c;
c = (a * b) * alpha + beta * c;
});
});
}
......@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat,
float beta,
std::false_type)
{
(void)cmat;
(void)amat;
(void)bmat;
(void)alpha;
(void)beta;
assert(true && "TODO");
auto m = cmat.get_shape().lens()[0];
auto n = cmat.get_shape().lens()[1];
auto k = amat.get_shape().lens()[1];
assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
assert(m == amat.get_shape().lens()[0]);
assert(n == bmat.get_shape().lens()[1]);
dfor(m, n)([&](auto ii, auto jj)
{
double s = cmat(ii, jj) * beta;
dfor(k)([&](auto kk)
{
s += amat(ii, kk) * bmat(kk, jj);
});
cmat(ii, jj) = alpha * s;
});
}
template <class T>
......
......@@ -242,14 +242,15 @@ void reshape_test()
}
}
template<class T>
void gemm_test()
{
migraph::program p;
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
std::vector<T> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
......@@ -264,7 +265,7 @@ void gemm_test()
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
std::vector<T> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
......@@ -276,14 +277,14 @@ void gemm_test()
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraph::shape a_shape{migraph::shape::float_type, {4, 5}};
migraph::shape a_shape{migraph::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraph::literal{a_shape, a});
migraph::shape b_shape{migraph::shape::float_type, {5, 3}};
migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraph::literal{b_shape, b});
p.add_instruction(migraph::gemm{}, al, bl);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(12);
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6;
for(int i = 0; i < results_vector.size(); i++)
......@@ -656,7 +657,8 @@ int main()
add_broadcast_test();
sub_test();
mul_test();
gemm_test();
gemm_test<float>();
gemm_test<double>();
reshape_test();
transpose_test();
contiguous_test();
......
......@@ -7,6 +7,7 @@
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/type_name.hpp>
#include <miopen/miopen.h>
......@@ -48,7 +49,11 @@ void verify_program()
{
auto cpu_arg = run_cpu<V>();
auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { EXPECT(test::verify_range(cpu, gpu)); });
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
if(not test::verify_range(cpu, gpu)) {
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
}
});
}
struct test_literals
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment