Commit 71f9ab23 authored by Paul's avatar Paul
Browse files

Formatting

parent 6b9e7689
...@@ -101,7 +101,10 @@ void mlir_print(F f, T x, std::ostream& os) ...@@ -101,7 +101,10 @@ void mlir_print(F f, T x, std::ostream& os)
struct mlir_program struct mlir_program
{ {
mlir_program() : ctx(mlirContextCreate()), location(mlirLocationUnknownGet(ctx.get())), mmodule(mlirModuleCreateEmpty(location)) mlir_program()
: ctx(mlirContextCreate()),
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{ {
mlirRegisterAllDialects(ctx.get()); mlirRegisterAllDialects(ctx.get());
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/); mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
...@@ -226,21 +229,32 @@ struct mlir_program ...@@ -226,21 +229,32 @@ struct mlir_program
return attr; return attr;
} }
using attribute_t = std::variant<std::nullptr_t, std::uint64_t, unsigned char, bool, double, std::string, value, std::vector<value>, MlirType>; using attribute_t = std::variant<std::nullptr_t,
std::uint64_t,
unsigned char,
bool,
double,
std::string,
value,
std::vector<value>,
MlirType>;
using named_attribute_t = std::pair<std::string_view, attribute_t>; using named_attribute_t = std::pair<std::string_view, attribute_t>;
MlirNamedAttribute name_attribute(const named_attribute_t& na) const MlirNamedAttribute name_attribute(const named_attribute_t& na) const
{ {
return name_attribute(na.first, std::visit([&](const auto& x) { return attribute(x); }, na.second)); return name_attribute(na.first,
std::visit([&](const auto& x) { return attribute(x); }, na.second));
} }
std::vector<MlirNamedAttribute> name_attributes(const std::vector<named_attribute_t>& named_attrs) const std::vector<MlirNamedAttribute>
name_attributes(const std::vector<named_attribute_t>& named_attrs) const
{ {
std::vector<MlirNamedAttribute> attributes; std::vector<MlirNamedAttribute> attributes;
attributes.reserve(named_attrs.size()); attributes.reserve(named_attrs.size());
std::transform(named_attrs.begin(), named_attrs.end(), std::back_inserter(attributes), [&](const named_attribute_t& a) { std::transform(named_attrs.begin(),
return name_attribute(a); named_attrs.end(),
}); std::back_inserter(attributes),
[&](const named_attribute_t& a) { return name_attribute(a); });
return attributes; return attributes;
} }
...@@ -256,9 +270,10 @@ struct mlir_program ...@@ -256,9 +270,10 @@ struct mlir_program
struct mlir_operation_state struct mlir_operation_state
{ {
mlir_operation_state(mlir_program& p, const std::string_view& name) mlir_operation_state(mlir_program& p, const std::string_view& name)
: prog(&p), op_state(mlirOperationStateGet(make_mlir_string_ref(name), p.location)) : prog(&p), op_state(mlirOperationStateGet(make_mlir_string_ref(name), p.location))
{} {
}
mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs) mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
{ {
...@@ -308,11 +323,10 @@ struct mlir_program ...@@ -308,11 +323,10 @@ struct mlir_program
mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data()); mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data());
mlir_operation op(mlirOperationCreate(&op_state)); mlir_operation op(mlirOperationCreate(&op_state));
// Release memory since mlir_operation owns it // Release memory since mlir_operation owns it
for(auto& r:regions) for(auto& r : regions)
r.release(); r.release();
regions.clear(); regions.clear();
return op; return op;
} }
mlir_program* prog; mlir_program* prog;
...@@ -329,7 +343,7 @@ struct mlir_program ...@@ -329,7 +343,7 @@ struct mlir_program
{ {
std::vector<MlirValue> result; std::vector<MlirValue> result;
mlir_operation op = ops.create_operation(); mlir_operation op = ops.create_operation();
auto weak_op = op.get(); auto weak_op = op.get();
mlirBlockInsertOwnedOperation(body, 0, op.release()); mlirBlockInsertOwnedOperation(body, 0, op.release());
auto n = mlirOperationGetNumResults(weak_op); auto n = mlirOperationGetNumResults(weak_op);
...@@ -340,27 +354,30 @@ struct mlir_program ...@@ -340,27 +354,30 @@ struct mlir_program
return result; return result;
} }
MlirBlock insert(MlirBlock body, const module& m, std::unordered_map<instruction_ref, MlirValue>& ins_map) MlirBlock
insert(MlirBlock body, const module& m, std::unordered_map<instruction_ref, MlirValue>& ins_map)
{ {
auto names = m.get_parameter_names(); auto names = m.get_parameter_names();
std::vector<shape> inputs; std::vector<shape> inputs;
std::transform(names.begin(), names.end(), std::back_inserter(inputs), [&](const std::string& name) { std::transform(names.begin(),
return m.get_parameter_shape(name); names.end(),
}); std::back_inserter(inputs),
[&](const std::string& name) { return m.get_parameter_shape(name); });
std::vector<shape> outputs = m.get_output_shapes(); std::vector<shape> outputs = m.get_output_shapes();
auto body_inputs = make_tensors(inputs); auto body_inputs = make_tensors(inputs);
mlir_region region = mlirRegionCreate(); mlir_region region = mlirRegionCreate();
mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data()); mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data());
MlirBlock result = fbody.get(); MlirBlock result = fbody.get();
mlirRegionAppendOwnedBlock(region.get(), fbody.release()); mlirRegionAppendOwnedBlock(region.get(), fbody.release());
auto ops = create_operation_state("builtin.func"); auto ops = create_operation_state("builtin.func");
ops.add_attributes({{"type", make_function_type(inputs, outputs)}, {"sym_name", "\"main\""}}); ops.add_attributes(
{{"type", make_function_type(inputs, outputs)}, {"sym_name", "\"main\""}});
ops.add_region(std::move(region)); ops.add_region(std::move(region));
insert(body, std::move(ops)); insert(body, std::move(ops));
for(auto i:range(names.size())) for(auto i : range(names.size()))
ins_map[m.get_parameter(names[i])] = mlirBlockGetArgument(result, i); ins_map[m.get_parameter(names[i])] = mlirBlockGetArgument(result, i);
return result; return result;
} }
...@@ -370,17 +387,16 @@ struct mlir_program ...@@ -370,17 +387,16 @@ struct mlir_program
auto mbody = mlirModuleGetBody(mmodule.get()); auto mbody = mlirModuleGetBody(mmodule.get());
std::unordered_map<instruction_ref, MlirValue> ins_map; std::unordered_map<instruction_ref, MlirValue> ins_map;
auto fbody = insert(mbody, m, ins_map); auto fbody = insert(mbody, m, ins_map);
for(auto ins:iterator_for(m)) for(auto ins : iterator_for(m))
{ {
auto name = "migraphx." + ins->name(); auto name = "migraphx." + ins->name();
auto ops = create_operation_state(name); auto ops = create_operation_state(name);
ops.add_attribute_value(ins->get_operator().to_value()); ops.add_attribute_value(ins->get_operator().to_value());
ops.add_results({ins->get_shape()}); ops.add_results({ins->get_shape()});
std::vector<MlirValue> inputs; std::vector<MlirValue> inputs;
transform(ins->inputs(), std::back_inserter(inputs), [&](auto i) { transform(
return ins_map.at(i); ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); });
});
ops.add_operands(inputs); ops.add_operands(inputs);
auto outputs = insert(fbody, std::move(ops)); auto outputs = insert(fbody, std::move(ops));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment