Commit 2afa47f1 authored by Paul's avatar Paul
Browse files

Add test for literals

parent 0a9d6a09
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
namespace migraph { namespace migraph {
migraph::argument generate_argument(migraph::shape s, std::mt19937::result_type seed) argument generate_argument(shape s, std::mt19937::result_type seed)
{ {
migraph::argument result; argument result;
s.visit_type([&](auto as) { s.visit_type([&](auto as) {
using type = typename decltype(as)::type; using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed); auto v = generate_tensor_data<type>(s, seed);
...@@ -13,4 +13,15 @@ migraph::argument generate_argument(migraph::shape s, std::mt19937::result_type ...@@ -13,4 +13,15 @@ migraph::argument generate_argument(migraph::shape s, std::mt19937::result_type
return result; return result;
} }
literal generate_literal(shape s, std::mt19937::result_type seed)
{
literal result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed);
result = {s, v};
});
return result;
}
} // namespace migraph } // namespace migraph
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/literal.hpp>
#include <random> #include <random>
namespace migraph { namespace migraph {
...@@ -16,7 +17,9 @@ std::vector<T> generate_tensor_data(migraph::shape s, std::mt19937::result_type ...@@ -16,7 +17,9 @@ std::vector<T> generate_tensor_data(migraph::shape s, std::mt19937::result_type
return result; return result;
} }
migraph::argument generate_argument(migraph::shape s, std::mt19937::result_type seed = 0); argument generate_argument(shape s, std::mt19937::result_type seed = 0);
literal generate_literal(shape s, std::mt19937::result_type seed = 0);
} // namespace migraph } // namespace migraph
......
...@@ -49,6 +49,24 @@ void verify_program() ...@@ -49,6 +49,24 @@ void verify_program()
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { EXPECT(test::verify_range(cpu, gpu)); }); visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { EXPECT(test::verify_range(cpu, gpu)); });
} }
struct test_literals
{
migraph::program create_program() const
{
migraph::program p;
auto input = p.add_literal(generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto weights = p.add_literal(generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraph::convolution{}, input, weights);
p.add_instruction(migraph::activation{"relu"}, conv);
return p;
}
migraph::program::parameter_map create_params() const
{
return {};
}
};
struct test_add struct test_add
{ {
migraph::program create_program() const migraph::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment