Unverified Commit 68189043 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Use current device when constructng context (#1294)

parent 2781ccd8
...@@ -33,6 +33,8 @@ namespace gpu { ...@@ -33,6 +33,8 @@ namespace gpu {
std::string get_device_name(); std::string get_device_name();
int get_device_id();
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include <migraphx/gpu/compile_ops.hpp> #include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_mlir.hpp> #include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
...@@ -162,7 +163,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -162,7 +163,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
std::string target::name() const { return "gpu"; } std::string target::name() const { return "gpu"; }
migraphx::context target::get_context() const { return context{}; } migraphx::context target::get_context() const { return context(gpu::get_device_id()); }
argument target::copy_to(const argument& arg) const { return gpu::to_gpu(arg); } argument target::copy_to(const argument& arg) const { return gpu::to_gpu(arg); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment