Commit 2a0ff223 authored by Paul's avatar Paul
Browse files

Add return

parent 1ac17a13
...@@ -396,6 +396,23 @@ struct mlir_program ...@@ -396,6 +396,23 @@ struct mlir_program
return result; return result;
} }
static std::string get_name(instruction_ref ins)
{
if (ins->name() == "@return")
return "std.return";
return "migraphx." + ins->name();
}
static shape get_shape(instruction_ref ins)
{
if (ins->name() == "@return")
{
assert(ins->inputs().size() == 1);
return ins->inputs().front()->get_shape();
}
return ins->get_shape();
}
void parse(const module& m) void parse(const module& m)
{ {
auto mbody = mlirModuleGetBody(mmodule.get()); auto mbody = mlirModuleGetBody(mmodule.get());
...@@ -405,10 +422,10 @@ struct mlir_program ...@@ -405,10 +422,10 @@ struct mlir_program
{ {
if(ins->name() == "@param") if(ins->name() == "@param")
continue; continue;
auto name = "migraphx." + ins->name(); auto name = get_name(ins);
auto ops = create_operation_state(name); auto ops = create_operation_state(name);
ops.add_attribute_value(ins->get_operator().to_value()); ops.add_attribute_value(ins->get_operator().to_value());
ops.add_results({ins->get_shape()}); ops.add_results({get_shape(ins)});
std::vector<MlirValue> inputs; std::vector<MlirValue> inputs;
transform( transform(
......
...@@ -35,13 +35,15 @@ module { ...@@ -35,13 +35,15 @@ module {
func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> { func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> {
%0 = "migraphx.convolution"(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 : %0 = "migraphx.convolution"(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 :
si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = return %0 : tensor<1x2x2x2xf32>
} }
} }
)__migraphx__"; )__migraphx__";
migraphx::module m; migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}}); auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}});
m.add_instruction(migraphx::make_op("convolution"), x, w); auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
m.add_return({conv});
auto s = migraphx::gpu::dump_mlir(m); auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled // Skip test if MLIR is not enabled
if(s.empty()) if(s.empty())
...@@ -57,6 +59,7 @@ module { ...@@ -57,6 +59,7 @@ module {
%0 = "migraphx.convolution"(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 : si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = "migraphx.convolution"(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 : si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = "migraphx.add"(%0, %arg2) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %1 = "migraphx.add"(%0, %arg2) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = "migraphx.relu"(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %2 = "migraphx.relu"(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%3 = return %2 : tensor<1x2x2x2xf32>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -66,7 +69,8 @@ module { ...@@ -66,7 +69,8 @@ module {
auto b = m.add_parameter("b", {migraphx::shape::float_type, {1, 2, 2, 2}}); auto b = m.add_parameter("b", {migraphx::shape::float_type, {1, 2, 2, 2}});
auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w); auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
auto add = m.add_instruction(migraphx::make_op("add"), conv, b); auto add = m.add_instruction(migraphx::make_op("add"), conv, b);
m.add_instruction(migraphx::make_op("relu"), add); auto relu = m.add_instruction(migraphx::make_op("relu"), add);
m.add_return({relu});
auto s = migraphx::gpu::dump_mlir(m); auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled // Skip test if MLIR is not enabled
if(s.empty()) if(s.empty())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment