Commit a472ec0f authored by Paul's avatar Paul
Browse files

Merge branch 'onnx-build'

parents abe4092b a9a9f126
......@@ -13,9 +13,9 @@ endif()
if(CMAKE_CXX_COMPILER MATCHES ".*hcc")
message(STATUS "Enable miopen backend")
set(MIGRAPH_ENABLE_MIOPEN On CACHE BOOL "")
set(MIGRAPH_ENABLE_GPU On CACHE BOOL "")
else()
set(MIGRAPH_ENABLE_MIOPEN Off CACHE BOOL "")
set(MIGRAPH_ENABLE_GPU Off CACHE BOOL "")
endif()
add_compile_options(-std=c++14)
......
......@@ -23,5 +23,5 @@ target_link_libraries(mnist migraph_cpu migraph_onnx)
if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_miopen)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu)
endif()
......@@ -2,11 +2,11 @@
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <miopen/miopen.h>
#include <migraph/miopen/miopen.hpp>
#include <migraph/gpu/miopen.hpp>
migraph::argument run_cpu(std::string file)
{
......@@ -24,15 +24,14 @@ migraph::argument run_gpu(std::string file)
auto p = migraph::parse_onnx(file);
p.compile(migraph::cpu::cpu_target{});
auto s = p.get_parameter_shape("Input3");
auto input3 = migraph::miopen::to_gpu(migraph::generate_argument(s));
auto input3 = migraph::gpu::to_gpu(migraph::generate_argument(s));
auto output =
migraph::miopen::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
auto handle = migraph::miopen::make_obj<migraph::miopen::miopen_handle>(&miopenCreate);
auto output = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
auto handle = migraph::gpu::make_obj<migraph::gpu::miopen_handle>(&miopenCreate);
auto out = p.eval({{"Input3", input3}, {"output", output}});
std::cout << p << std::endl;
return migraph::miopen::from_gpu(out);
return migraph::gpu::from_gpu(out);
}
int main(int argc, char const* argv[])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment