Commit 647d7908 authored by Paul's avatar Paul
Browse files

Add missing header

parent 238bfadd
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string>
namespace migraph {
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
{
static std::string name;
if(name.empty())
{
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
}
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return migraph::get_type_name<T>();
}
} // namespace migraph
#endif
......@@ -51,7 +51,7 @@ void migemm_impl(tensor_view<T> cmat,
visit_mat(amat, [&](const auto& a) {
visit_mat(bmat, [&](const auto& b) {
auto c = make_mat(cmat);
c = (a * b) * alpha + beta * c;
c = (a * b) * alpha + beta * c;
});
});
}
......@@ -72,13 +72,9 @@ void migemm_impl(tensor_view<T> cmat,
assert(m == amat.get_shape().lens()[0]);
assert(n == bmat.get_shape().lens()[1]);
dfor(m, n)([&](auto ii, auto jj)
{
dfor(m, n)([&](auto ii, auto jj) {
double s = cmat(ii, jj) * beta;
dfor(k)([&](auto kk)
{
s += amat(ii, kk) * bmat(kk, jj);
});
dfor(k)([&](auto kk) { s += amat(ii, kk) * bmat(kk, jj); });
cmat(ii, jj) = alpha * s;
});
}
......
......@@ -242,41 +242,41 @@ void reshape_test()
}
}
template<class T>
template <class T>
void gemm_test()
{
migraph::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<T> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<T> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraph::shape a_shape{migraph::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraph::literal{a_shape, a});
migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}};
......
......@@ -49,8 +49,9 @@ void verify_program()
{
auto cpu_arg = run_cpu<V>();
auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
if(not test::verify_range(cpu, gpu)) {
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
if(not test::verify_range(cpu, gpu))
{
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
}
});
......
#include <migraph/type_name.hpp>
#include "test.hpp"
struct global_class
{
struct inner_class
{
};
};
namespace foo {
struct ns_class
{
struct inner_class
{
};
};
} // namespace foo
int main()
{
EXPECT(migraph::get_type_name<global_class>() == "global_class");
EXPECT(migraph::get_type_name<global_class::inner_class>() == "global_class::inner_class");
EXPECT(migraph::get_type_name<foo::ns_class>() == "foo::ns_class");
EXPECT(migraph::get_type_name<foo::ns_class::inner_class>() == "foo::ns_class::inner_class");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment