Commit bacf1572 authored by wsttiger's avatar wsttiger
Browse files

Resnet18 seems to run on GPU

parent 6ac569c4
...@@ -22,7 +22,7 @@ target_link_libraries(mnist migraph_cpu migraph_onnx) ...@@ -22,7 +22,7 @@ target_link_libraries(mnist migraph_cpu migraph_onnx)
add_executable(resnet18 resnet18.cpp) add_executable(resnet18 resnet18.cpp)
rocm_clang_tidy_check(resnet18) rocm_clang_tidy_check(resnet18)
target_link_libraries(resnet18 migraph_cpu migraph_onnx) target_link_libraries(resnet18 migraph_gpu migraph_onnx)
if(MIGRAPH_ENABLE_GPU) if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp) add_executable(verify_onnx verify_onnx.cpp)
......
...@@ -285,20 +285,19 @@ struct onnx_parser ...@@ -285,20 +285,19 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph) void parse_graph(const onnx::GraphProto& graph)
{ {
nodes = get_nodes(graph); nodes = get_nodes(graph);
std::unordered_map<std::string, size_t> initializer_data; std::unordered_map<std::string, onnx::TensorProto> initializer_data;
auto cnt = 0; auto cnt = 0;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
{ {
initializer_data[f.name()] = cnt++; initializer_data[f.name()] = f;
} }
for(auto&& input : graph.input()) for(auto&& input : graph.input())
{ {
const std::string& name = input.name(); const std::string& name = input.name();
// Does the input have an initializer? // Does the input have an initializer?
if(initializer_data.find(name) != initializer_data.end()) if(contains(initializer_data, name))
{ {
auto idx = initializer_data[name]; auto t = initializer_data[name];
auto t = graph.initializer()[idx];
instructions[name] = prog.add_literal(parse_tensor(t)); instructions[name] = prog.add_literal(parse_tensor(t));
} }
else else
......
...@@ -6,15 +6,26 @@ ...@@ -6,15 +6,26 @@
#include <migraph/onnx.hpp> #include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp> #include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraph/generate.hpp>
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
std::string file = argv[1]; std::string file = argv[1];
auto prog = migraph::parse_onnx(file); auto prog = migraph::parse_onnx(file);
prog.compile(migraph::cpu::cpu_target{});
// GPU target
prog.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}}; auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
auto input3 = migraph::generate_argument(s, 12345); m["output"] = migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
auto result = prog.eval({{"0", input3}}); m["0"] = migraph::gpu::to_gpu(migraph::generate_argument(s, 12345));
} auto result = migraph::gpu::from_gpu(prog.eval(m));
\ No newline at end of file
// // CPU target
// prog.compile(migraph::cpu::cpu_target{});
// auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
// auto input3 = migraph::generate_argument(s, 12345);
// auto result = prog.eval({{"0", input3}});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment