Commit df3749cd authored by Paul's avatar Paul
Browse files

Add code to insert memrefs

parent 60ab44c7
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
#define MIGRAPHX_GUARD_RTGLIB_GPU_MLIR_HPP #define MIGRAPHX_GUARD_RTGLIB_GPU_MLIR_HPP
#include <string> #include <string>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -12,6 +14,7 @@ namespace gpu { ...@@ -12,6 +14,7 @@ namespace gpu {
std::string dump_mlir(const module& m); std::string dump_mlir(const module& m);
code_object_op compile_mlir(const module& m); code_object_op compile_mlir(const module& m);
instruction_ref insert_mlir(module& m, instruction_ref ins, const module& mmlir, const std::vector<instruction_ref>& inputs);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <deque> #include <deque>
#include <variant> #include <variant>
...@@ -509,6 +510,50 @@ code_object_op compile_mlir(const module& m) ...@@ -509,6 +510,50 @@ code_object_op compile_mlir(const module& m)
return mp.compile(); return mp.compile();
} }
instruction_ref insert_mlir(module& m, instruction_ref ins, const module& mmlir, const std::vector<instruction_ref>& inputs)
{
auto co = compile_mlir(mmlir);
std::vector<instruction_ref> refs;
refs.reserve(inputs.size() * 15);
std::unordered_map<uint64_t, instruction_ref> literal_map{};
auto get_literal = [&](uint64_t value) {
auto fi = literal_map.find(value);
if(fi != literal_map.end())
return fi->second;
auto lit = m.add_literal(value);
literal_map.emplace(value, lit);
return lit;
};
for(auto input:inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
// dim sizes
std::transform(s.lens().begin(),
s.lens().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
refs.push_back(get_literal(1)); // G
// dim strides
std::transform(s.strides().begin(),
s.strides().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
refs.push_back(get_literal(1)); // G
}
co.expected_inputs = to_shapes(refs);
co.output = mmlir.get_output_shapes().front();
return m.insert_instruction(ins, co, refs);
}
#else #else
std::string dump_mlir(const module&) { return {}; } std::string dump_mlir(const module&) { return {}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment