Commit 2c952efd authored by Paul's avatar Paul
Browse files

Dont provide output for return instruction

parent 032af369
...@@ -183,11 +183,15 @@ struct mlir_program ...@@ -183,11 +183,15 @@ struct mlir_program
MlirAttribute attribute(std::int64_t i) const MlirAttribute attribute(std::int64_t i) const
{ {
return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx.get(), 64), i); if (i < 0)
MIGRAPHX_THROW("MLIR cant handle negative values since they are ambiguous");
return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
} }
MlirAttribute attribute(std::uint64_t i) const MlirAttribute attribute(std::uint64_t i) const
{ {
return mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx.get(), 64), i); if (i > (std::numeric_limits<std::uint64_t>::max() / 2))
MIGRAPHX_THROW("MLIR cant handle large integer values since they are ambiguous");
return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
} }
MlirAttribute attribute(unsigned char i) const { return attribute(std::uint64_t(i)); } MlirAttribute attribute(unsigned char i) const { return attribute(std::uint64_t(i)); }
MlirAttribute attribute(bool b) const { return mlirBoolAttrGet(ctx.get(), b ? 1 : 0); } MlirAttribute attribute(bool b) const { return mlirBoolAttrGet(ctx.get(), b ? 1 : 0); }
...@@ -433,6 +437,7 @@ struct mlir_program ...@@ -433,6 +437,7 @@ struct mlir_program
auto name = get_name(ins); auto name = get_name(ins);
auto ops = create_operation_state(name); auto ops = create_operation_state(name);
ops.add_attribute_value(ins->get_operator().to_value()); ops.add_attribute_value(ins->get_operator().to_value());
if(ins->name() != "@return")
ops.add_results({get_shape(ins)}); ops.add_results({get_shape(ins)});
std::vector<MlirValue> inputs; std::vector<MlirValue> inputs;
...@@ -441,10 +446,13 @@ struct mlir_program ...@@ -441,10 +446,13 @@ struct mlir_program
ops.add_operands(inputs); ops.add_operands(inputs);
auto outputs = insert(fbody, std::move(ops)); auto outputs = insert(fbody, std::move(ops));
if(ins->name() != "@return")
{
assert(outputs.size() == 1); assert(outputs.size() == 1);
ins_map[ins] = outputs.front(); ins_map[ins] = outputs.front();
} }
} }
}
code_object_op compile() code_object_op compile()
{ {
......
...@@ -33,9 +33,8 @@ TEST_CASE(conv) ...@@ -33,9 +33,8 @@ TEST_CASE(conv)
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> { func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> {
%0 = migraphx.convolution(%arg0, %arg1) {dilation = [1 : ui64, 1 : ui64], group = 1 : si64, padding = [0 : ui64, 0 : ui64, 0 : ui64, 0 : ui64], padding_mode = 0 : ui64, stride %0 = migraphx.convolution(%arg0, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
= [1 : ui64, 1 : ui64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> return %0 : tensor<1x2x2x2xf32>
%1 = return %0 : tensor<1x2x2x2xf32>
} }
} }
)__migraphx__"; )__migraphx__";
...@@ -58,10 +57,10 @@ TEST_CASE(conv_add_relu) ...@@ -58,10 +57,10 @@ TEST_CASE(conv_add_relu)
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> { func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> {
%0 = migraphx.convolution(%arg0, %arg1) {dilation = [1 : ui64, 1 : ui64], group = 1 : si64, padding = [0 : ui64, 0 : ui64, 0 : ui64, 0 : ui64], padding_mode = 0 : ui64, stride = [1 : ui64, 1 : ui64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution(%arg0, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %1 = migraphx.add(%0, %arg2) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%3 = return %2 : tensor<1x2x2x2xf32> return %2 : tensor<1x2x2x2xf32>
} }
} }
)__migraphx__"; )__migraphx__";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment