"test/vscode:/vscode.git/clone" did not exist on "54e1dfd1fa43fc9cce39869a692db17cb0fa3a3a"
Commit bacf1572 authored by wsttiger's avatar wsttiger
Browse files

Resnet18 seems to run on GPU

parent 6ac569c4
......@@ -22,7 +22,7 @@ target_link_libraries(mnist migraph_cpu migraph_onnx)
add_executable(resnet18 resnet18.cpp)
rocm_clang_tidy_check(resnet18)
target_link_libraries(resnet18 migraph_cpu migraph_onnx)
target_link_libraries(resnet18 migraph_gpu migraph_onnx)
if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp)
......
......@@ -285,20 +285,19 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph)
{
nodes = get_nodes(graph);
std::unordered_map<std::string, size_t> initializer_data;
std::unordered_map<std::string, onnx::TensorProto> initializer_data;
auto cnt = 0;
for(auto&& f : graph.initializer())
{
initializer_data[f.name()] = cnt++;
initializer_data[f.name()] = f;
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// Does the input have an initializer?
if(initializer_data.find(name) != initializer_data.end())
if(contains(initializer_data, name))
{
auto idx = initializer_data[name];
auto t = graph.initializer()[idx];
auto t = initializer_data[name];
instructions[name] = prog.add_literal(parse_tensor(t));
}
else
......
......@@ -6,15 +6,26 @@
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
int main(int argc, char const* argv[])
{
std::string file = argv[1];
auto prog = migraph::parse_onnx(file);
prog.compile(migraph::cpu::cpu_target{});
// GPU target
prog.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
auto input3 = migraph::generate_argument(s, 12345);
auto result = prog.eval({{"0", input3}});
}
\ No newline at end of file
m["output"] = migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
m["0"] = migraph::gpu::to_gpu(migraph::generate_argument(s, 12345));
auto result = migraph::gpu::from_gpu(prog.eval(m));
// // CPU target
// prog.compile(migraph::cpu::cpu_target{});
// auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
// auto input3 = migraph::generate_argument(s, 12345);
// auto result = prog.eval({{"0", input3}});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment