Commit 6b9e7689 authored by Paul's avatar Paul
Browse files

Parse module

parent dd033c75
......@@ -175,6 +175,12 @@ void transform(Range&& r, Iterator it, F f)
std::transform(r.begin(), r.end(), it, f);
}
template <class Range1, class Range2, class Iterator, class F>
void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
{
std::transform(r1.begin(), r1.end(), r2.begin(), it, f);
}
template <class Range>
auto reverse(Range& r)
{
......
......@@ -10,7 +10,12 @@
#include <migraphx/manage_ptr.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>
#include <deque>
#include <variant>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -57,7 +62,7 @@ struct mlir_handle
T get() const { return handle.get().get(); }
T release() const { return handle.release().get(); }
T release() { return handle.release().get(); }
private:
std::unique_ptr<ptr, deleter> handle;
......@@ -96,7 +101,7 @@ void mlir_print(F f, T x, std::ostream& os)
struct mlir_program
{
mlir_program() : ctx(mlirContextCreate())
mlir_program() : ctx(mlirContextCreate()), location(mlirLocationUnknownGet(ctx.get())), mmodule(mlirModuleCreateEmpty(location))
{
mlirRegisterAllDialects(ctx.get());
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
......@@ -194,11 +199,7 @@ struct mlir_program
}
if(not v.front().get_key().empty())
{
std::vector<MlirNamedAttribute> attributes;
attributes.reserve(v.size());
std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
return name_attribute(x.get_key(), x.without_key());
});
std::vector<MlirNamedAttribute> attributes = name_attributes(v);
return mlirDictionaryAttrGet(ctx.get(), attributes.size(), attributes.data());
}
else
......@@ -214,6 +215,8 @@ struct mlir_program
MlirAttribute attribute(MlirType t) const { return mlirTypeAttrGet(t); }
MlirAttribute attribute(MlirAttribute a) const { return a; }
template <class T>
MlirNamedAttribute name_attribute(const std::string_view& key, const T& x) const
{
......@@ -223,7 +226,173 @@ struct mlir_program
return attr;
}
using attribute_t = std::variant<std::nullptr_t, std::uint64_t, unsigned char, bool, double, std::string, value, std::vector<value>, MlirType>;
using named_attribute_t = std::pair<std::string_view, attribute_t>;
MlirNamedAttribute name_attribute(const named_attribute_t& na) const
{
return name_attribute(na.first, std::visit([&](const auto& x) { return attribute(x); }, na.second));
}
std::vector<MlirNamedAttribute> name_attributes(const std::vector<named_attribute_t>& named_attrs) const
{
std::vector<MlirNamedAttribute> attributes;
attributes.reserve(named_attrs.size());
std::transform(named_attrs.begin(), named_attrs.end(), std::back_inserter(attributes), [&](const named_attribute_t& a) {
return name_attribute(a);
});
return attributes;
}
std::vector<MlirNamedAttribute> name_attributes(const value& v) const
{
std::vector<MlirNamedAttribute> attributes;
attributes.reserve(v.size());
std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](const value& x) {
return name_attribute(x.get_key(), x.without_key());
});
return attributes;
}
struct mlir_operation_state
{
mlir_operation_state(mlir_program& p, const std::string_view& name)
: prog(&p), op_state(mlirOperationStateGet(make_mlir_string_ref(name), p.location))
{}
mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
{
auto attributes = prog->name_attributes(named_attrs);
mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
return *this;
}
mlir_operation_state& add_attribute_value(const value& v)
{
auto attributes = prog->name_attributes(v);
mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
return *this;
}
mlir_operation_state& add_regions(std::vector<mlir_region> rs)
{
regions = std::move(rs);
return *this;
}
mlir_operation_state& add_region(mlir_region r)
{
regions.emplace_back(std::move(r));
return *this;
}
mlir_operation_state& add_results(const std::vector<shape>& outputs)
{
auto x = prog->make_tensors(outputs);
mlirOperationStateAddResults(&op_state, x.size(), x.data());
return *this;
}
mlir_operation_state& add_operands(const std::vector<MlirValue>& inputs)
{
mlirOperationStateAddOperands(&op_state, inputs.size(), inputs.data());
return *this;
}
mlir_operation create_operation()
{
std::vector<MlirRegion> mregions(regions.size());
std::transform(regions.begin(), regions.end(), mregions.begin(), [](const auto& r) {
return r.get();
});
mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data());
mlir_operation op(mlirOperationCreate(&op_state));
// Release memory since mlir_operation owns it
for(auto& r:regions)
r.release();
regions.clear();
return op;
}
mlir_program* prog;
MlirOperationState op_state;
std::vector<mlir_region> regions = {};
};
mlir_operation_state create_operation_state(const std::string_view& name)
{
return {*this, name};
}
std::vector<MlirValue> insert(MlirBlock body, mlir_operation_state ops)
{
std::vector<MlirValue> result;
mlir_operation op = ops.create_operation();
auto weak_op = op.get();
mlirBlockInsertOwnedOperation(body, 0, op.release());
auto n = mlirOperationGetNumResults(weak_op);
result.reserve(n);
transform(range(n), std::back_inserter(result), [&](auto i) {
return mlirOperationGetResult(weak_op, i);
});
return result;
}
MlirBlock insert(MlirBlock body, const module& m, std::unordered_map<instruction_ref, MlirValue>& ins_map)
{
auto names = m.get_parameter_names();
std::vector<shape> inputs;
std::transform(names.begin(), names.end(), std::back_inserter(inputs), [&](const std::string& name) {
return m.get_parameter_shape(name);
});
std::vector<shape> outputs = m.get_output_shapes();
auto body_inputs = make_tensors(inputs);
mlir_region region = mlirRegionCreate();
mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data());
MlirBlock result = fbody.get();
mlirRegionAppendOwnedBlock(region.get(), fbody.release());
auto ops = create_operation_state("builtin.func");
ops.add_attributes({{"type", make_function_type(inputs, outputs)}, {"sym_name", "\"main\""}});
ops.add_region(std::move(region));
insert(body, std::move(ops));
for(auto i:range(names.size()))
ins_map[m.get_parameter(names[i])] = mlirBlockGetArgument(result, i);
return result;
}
void parse(const module& m)
{
auto mbody = mlirModuleGetBody(mmodule.get());
std::unordered_map<instruction_ref, MlirValue> ins_map;
auto fbody = insert(mbody, m, ins_map);
for(auto ins:iterator_for(m))
{
auto name = "migraphx." + ins->name();
auto ops = create_operation_state(name);
ops.add_attribute_value(ins->get_operator().to_value());
ops.add_results({ins->get_shape()});
std::vector<MlirValue> inputs;
transform(ins->inputs(), std::back_inserter(inputs), [&](auto i) {
return ins_map.at(i);
});
ops.add_operands(inputs);
auto outputs = insert(fbody, std::move(ops));
assert(outputs.size() == 1);
ins_map[ins] = outputs.front();
}
}
mlir_context ctx;
MlirLocation location;
mlir_module mmodule;
std::deque<std::string> strings{};
};
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment