Commit a652e90c authored by Paul's avatar Paul
Browse files

Fix device name

parent 13418e23
...@@ -4,4 +4,4 @@ blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -D ...@@ -4,4 +4,4 @@ blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -D
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969 half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
# jungpark-mlir/llvm-project-mlir@263a3b89c7ce83bef23c7530217184abcfd79956 -DBUILD_MIXR_TARGET=On # ROCmSoftwarePlatform/llvm-project-mlir,jungpark-mlir/llvm-project-mlir@263a3b89c7ce83bef23c7530217184abcfd79956 -DBUILD_MIXR_TARGET=On
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <deque> #include <deque>
#include <variant> #include <variant>
...@@ -461,11 +462,14 @@ struct mlir_program ...@@ -461,11 +462,14 @@ struct mlir_program
// 1st pipeline to call // 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get()); mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call // 2nd pipeline to call
const char* deviceName = "gfx908"; std::string tname = get_device_name();
mlirMIGraphXAddBackendPipeline(pm.get(), deviceName); // HACK: Since MLIR can't handle the full target name
auto hacked_tname = tname.substr(0, tname.find(":"));
mlirMIGraphXAddBackendPipeline(pm.get(), hacked_tname.c_str());
mlirPassManagerRun(pm.get(), mmodule.get()); mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op; code_object_op op;
op.symbol_name = "main";
op.code_object = get_binary(); op.code_object = get_binary();
std::tie(op.global, op.local) = get_launch_params(); std::tie(op.global, op.local) = get_launch_params();
return op; return op;
...@@ -554,7 +558,6 @@ instruction_ref insert_mlir(module& m, ...@@ -554,7 +558,6 @@ instruction_ref insert_mlir(module& m,
} }
co.expected_inputs = to_shapes(refs); co.expected_inputs = to_shapes(refs);
co.output = mmlir.get_output_shapes().front(); co.output = mmlir.get_output_shapes().front();
co.symbol_name = "main";
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment