Commit 5aa4e686 authored by Paul's avatar Paul
Browse files

Unify verify program

parent 5883ffad
...@@ -27,6 +27,8 @@ struct program ...@@ -27,6 +27,8 @@ struct program
program& operator=(program&&) noexcept; program& operator=(program&&) noexcept;
~program() noexcept; ~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts> template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
...@@ -64,7 +66,7 @@ struct program ...@@ -64,7 +66,7 @@ struct program
shape get_parameter_shape(std::string name); shape get_parameter_shape(std::string name);
argument eval(std::unordered_map<std::string, argument> params) const; argument eval(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
......
...@@ -13,49 +13,64 @@ ...@@ -13,49 +13,64 @@
#include "test.hpp" #include "test.hpp"
#include "verify.hpp" #include "verify.hpp"
rtg::program create_program() template<class V>
rtg::argument run_cpu()
{ {
rtg::program p; V v;
auto input = p.add_parameter("x", rtg::shape{rtg::shape::float_type, {4, 3, 3, 3}}); auto p = v.create_program();
auto weights = p.add_parameter("w", rtg::shape{rtg::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(rtg::convolution{}, input, weights);
p.add_instruction(rtg::activation{"relu"}, conv);
return p;
}
std::vector<float> cpu()
{
std::vector<float> result;
auto p = create_program();
auto x = rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}});
auto w = rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}});
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto r = p.eval({{"x", x}, {"w", w}}); return p.eval(v.create_params());
auto output = r.get<float>();
result.assign(output.begin(), output.end());
return result;
} }
std::vector<float> gpu() template<class V>
rtg::argument run_gpu()
{ {
std::vector<float> result; V v;
auto p = create_program(); auto p = v.create_program();
auto x = rtg::miopen::to_gpu(rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}}));
auto w = rtg::miopen::to_gpu(rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}}));
p.compile(rtg::miopen::miopen_target{}); p.compile(rtg::miopen::miopen_target{});
auto y = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto m = v.create_params();
for(auto&& e:m)
{
e.second = rtg::miopen::to_gpu(e.second);
}
m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate); auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
auto r = p.eval( m["handle"] = {rtg::shape::any_type, handle.get()};
{{"x", x}, {"w", w}, {"output", y}, {"handle", {rtg::shape::any_type, handle.get()}}});
result = rtg::miopen::read_from_gpu<float>(r.data(), r.get_shape().elements()); return rtg::miopen::from_gpu(p.eval(m));
return result;
} }
void test1() template<class V>
void verify_program()
{ {
auto x = cpu(); auto cpu_arg = run_cpu<V>();
auto y = gpu(); auto gpu_arg = run_gpu<V>();
EXPECT(test::verify_range(x, y)); visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
EXPECT(test::verify_range(cpu, gpu));
});
} }
int main() { test1(); } struct test1
{
rtg::program create_program() const
{
rtg::program p;
auto input = p.add_parameter("x", rtg::shape{rtg::shape::float_type, {4, 3, 3, 3}});
auto weights = p.add_parameter("w", rtg::shape{rtg::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(rtg::convolution{}, input, weights);
p.add_instruction(rtg::activation{"relu"}, conv);
return p;
}
rtg::program::parameter_map create_params() const
{
rtg::program::parameter_map m;
m["x"] = rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}});
m["w"] = rtg::generate_argument({rtg::shape::float_type, {4, 3, 3, 3}});
return m;
}
};
int main() { verify_program<test1>(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment