Commit dd1da8aa authored by Paul's avatar Paul
Browse files

Add gpu support

parent 4d08193a
...@@ -6,6 +6,10 @@ ...@@ -6,6 +6,10 @@
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace py = pybind11; namespace py = pybind11;
...@@ -81,11 +85,23 @@ PYBIND11_MODULE(migraphx, m) ...@@ -81,11 +85,23 @@ PYBIND11_MODULE(migraphx, m)
m.def("get_target", [](const std::string& name) -> migraphx::target { m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu") if(name == "cpu")
return migraphx::cpu::target{}; return migraphx::cpu::target{};
#ifdef HAVE_GPU
if(name == "gpu")
return migraphx::gpu::target{};
#endif
throw std::runtime_error("Target not found: " + name); throw std::runtime_error("Target not found: " + name);
}); });
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
m.def("from_gpu", &migraphx::gpu::from_gpu);
m.def("gpu_sync", &migraphx::gpu::gpu_sync);
m.def("copy_to_gpu", &migraphx::gpu::copy_to_gpu);
#endif
#ifdef VERSION_INFO #ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO; m.attr("__version__") = VERSION_INFO;
#else #else
......
...@@ -15,4 +15,5 @@ endfunction() ...@@ -15,4 +15,5 @@ endfunction()
add_dependencies(tests migraphx_py) add_dependencies(tests migraphx_py)
add_dependencies(check migraphx_py) add_dependencies(check migraphx_py)
add_py_test(test test.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(cpu cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("gpu"))
print(p)
params = {}
for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r = migraphx.from_gpu(p.run(params))
print(r)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment