Commit ef7aa280 authored by Paul's avatar Paul
Browse files

Update onnx verification

parent fe91009b
...@@ -12,9 +12,12 @@ migraph::argument run_cpu(std::string file) ...@@ -12,9 +12,12 @@ migraph::argument run_cpu(std::string file)
{ {
auto p = migraph::parse_onnx(file); auto p = migraph::parse_onnx(file);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto s = p.get_parameter_shape("Input3"); migraph::program::parameter_map m;
auto input3 = migraph::generate_argument(s); for(auto&& x : p.get_parameter_shapes())
auto out = p.eval({{"Input3", input3}}); {
m[x.first] = migraph::generate_argument(x.second);
}
auto out = p.eval(m);
std::cout << p << std::endl; std::cout << p << std::endl;
return out; return out;
} }
...@@ -22,14 +25,14 @@ migraph::argument run_cpu(std::string file) ...@@ -22,14 +25,14 @@ migraph::argument run_cpu(std::string file)
migraph::argument run_gpu(std::string file) migraph::argument run_gpu(std::string file)
{ {
auto p = migraph::parse_onnx(file); auto p = migraph::parse_onnx(file);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::gpu::target{});
auto s = p.get_parameter_shape("Input3");
auto input3 = migraph::gpu::to_gpu(migraph::generate_argument(s));
auto output = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output"))); migraph::program::parameter_map m;
auto handle = migraph::gpu::make_obj<migraph::gpu::miopen_handle>(&miopenCreate); for(auto&& x : p.get_parameter_shapes())
{
auto out = p.eval({{"Input3", input3}, {"output", output}}); m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto out = migraph::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl; std::cout << p << std::endl;
return migraph::gpu::from_gpu(out); return migraph::gpu::from_gpu(out);
} }
......
...@@ -138,7 +138,7 @@ struct miopen_pooling ...@@ -138,7 +138,7 @@ struct miopen_pooling
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(0)});
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment