"vscode:/vscode.git/clone" did not exist on "dcf597788fa99ec027c7bb1e2db5c1cb185da26f"
Commit df3749cd authored by Paul's avatar Paul
Browse files

Add code to insert memrefs

parent 60ab44c7
......@@ -2,8 +2,10 @@
#define MIGRAPHX_GUARD_RTGLIB_GPU_MLIR_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -12,6 +14,7 @@ namespace gpu {
std::string dump_mlir(const module& m);
code_object_op compile_mlir(const module& m);
instruction_ref insert_mlir(module& m, instruction_ref ins, const module& mmlir, const std::vector<instruction_ref>& inputs);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,6 +18,7 @@
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/iterator_for.hpp>
#include <deque>
#include <variant>
......@@ -509,6 +510,50 @@ code_object_op compile_mlir(const module& m)
return mp.compile();
}
instruction_ref insert_mlir(module& m, instruction_ref ins, const module& mmlir, const std::vector<instruction_ref>& inputs)
{
auto co = compile_mlir(mmlir);
std::vector<instruction_ref> refs;
refs.reserve(inputs.size() * 15);
std::unordered_map<uint64_t, instruction_ref> literal_map{};
auto get_literal = [&](uint64_t value) {
auto fi = literal_map.find(value);
if(fi != literal_map.end())
return fi->second;
auto lit = m.add_literal(value);
literal_map.emplace(value, lit);
return lit;
};
for(auto input:inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
// dim sizes
std::transform(s.lens().begin(),
s.lens().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
refs.push_back(get_literal(1)); // G
// dim strides
std::transform(s.strides().begin(),
s.strides().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
refs.push_back(get_literal(1)); // G
}
co.expected_inputs = to_shapes(refs);
co.output = mmlir.get_output_shapes().front();
return m.insert_instruction(ins, co, refs);
}
#else
std::string dump_mlir(const module&) { return {}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment