Commit 238bfadd authored by Paul's avatar Paul
Browse files

Add simple fallback for now

parent 0b5fa390
...@@ -29,6 +29,7 @@ struct tensor_view ...@@ -29,6 +29,7 @@ struct tensor_view
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T)); assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
...@@ -36,6 +37,7 @@ struct tensor_view ...@@ -36,6 +37,7 @@ struct tensor_view
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T)); assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
......
#include <migraph/cpu/gemm.hpp> #include <migraph/cpu/gemm.hpp>
#include <migraph/dfor.hpp>
#include <migraph/requires.hpp> #include <migraph/requires.hpp>
#include <blaze/math/CustomMatrix.h> #include <blaze/math/CustomMatrix.h>
...@@ -50,9 +51,6 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -50,9 +51,6 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) { visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) { visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat); auto c = make_mat(cmat);
if(alpha == 1.0 and beta == 0.0)
c = a * b;
else
c = (a * b) * alpha + beta * c; c = (a * b) * alpha + beta * c;
}); });
}); });
...@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -66,12 +64,23 @@ void migemm_impl(tensor_view<T> cmat,
float beta, float beta,
std::false_type) std::false_type)
{ {
(void)cmat; auto m = cmat.get_shape().lens()[0];
(void)amat; auto n = cmat.get_shape().lens()[1];
(void)bmat; auto k = amat.get_shape().lens()[1];
(void)alpha;
(void)beta; assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
assert(true && "TODO"); assert(m == amat.get_shape().lens()[0]);
assert(n == bmat.get_shape().lens()[1]);
dfor(m, n)([&](auto ii, auto jj)
{
double s = cmat(ii, jj) * beta;
dfor(k)([&](auto kk)
{
s += amat(ii, kk) * bmat(kk, jj);
});
cmat(ii, jj) = alpha * s;
});
} }
template <class T> template <class T>
......
...@@ -242,14 +242,15 @@ void reshape_test() ...@@ -242,14 +242,15 @@ void reshape_test()
} }
} }
template<class T>
void gemm_test() void gemm_test()
{ {
migraph::program p; migraph::program p;
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632, -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814}; -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01, std::vector<T> b = {6.09568541e-01,
-6.10527007e-01, -6.10527007e-01,
3.66646462e-01, 3.66646462e-01,
1.18951101e-01, 1.18951101e-01,
...@@ -264,7 +265,7 @@ void gemm_test() ...@@ -264,7 +265,7 @@ void gemm_test()
1.53027987e+00, 1.53027987e+00,
-3.81407415e-04, -3.81407415e-04,
-3.29650255e-01}; -3.29650255e-01};
std::vector<float> c = {-1.56327541e+00, std::vector<T> c = {-1.56327541e+00,
-7.09570140e-01, -7.09570140e-01,
-5.37424982e-01, -5.37424982e-01,
-2.22994831e-01, -2.22994831e-01,
...@@ -276,14 +277,14 @@ void gemm_test() ...@@ -276,14 +277,14 @@ void gemm_test()
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
migraph::shape a_shape{migraph::shape::float_type, {4, 5}}; migraph::shape a_shape{migraph::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraph::literal{a_shape, a});
migraph::shape b_shape{migraph::shape::float_type, {5, 3}}; migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraph::literal{b_shape, b}); auto bl = p.add_literal(migraph::literal{b_shape, b});
p.add_instruction(migraph::gemm{}, al, bl); p.add_instruction(migraph::gemm{}, al, bl);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
for(int i = 0; i < results_vector.size(); i++) for(int i = 0; i < results_vector.size(); i++)
...@@ -656,7 +657,8 @@ int main() ...@@ -656,7 +657,8 @@ int main()
add_broadcast_test(); add_broadcast_test();
sub_test(); sub_test();
mul_test(); mul_test();
gemm_test(); gemm_test<float>();
gemm_test<double>();
reshape_test(); reshape_test();
transpose_test(); transpose_test();
contiguous_test(); contiguous_test();
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraph/gpu/miopen.hpp> #include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/manage_ptr.hpp> #include <migraph/manage_ptr.hpp>
#include <migraph/type_name.hpp>
#include <miopen/miopen.h> #include <miopen/miopen.h>
...@@ -48,7 +49,11 @@ void verify_program() ...@@ -48,7 +49,11 @@ void verify_program()
{ {
auto cpu_arg = run_cpu<V>(); auto cpu_arg = run_cpu<V>();
auto gpu_arg = run_gpu<V>(); auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { EXPECT(test::verify_range(cpu, gpu)); }); visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
if(not test::verify_range(cpu, gpu)) {
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
}
});
} }
struct test_literals struct test_literals
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment