Commit 1851e975 authored by Paul's avatar Paul
Browse files

Register dialect

parent e6f8a2cf
......@@ -122,6 +122,8 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
MlirDialectHandle mixrHandle = mlirGetDialectHandle__migraphx__();
mlirDialectHandleRegisterDialect(mixrHandle, ctx.get());
mlirRegisterAllDialects(ctx.get());
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
}
......
......@@ -33,7 +33,7 @@ TEST_CASE(conv)
const std::string mlir_output = R"__migraphx__(
module {
func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> {
%0 = "migraphx.convolution"(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 :
%0 = migraphx.convolution(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 :
si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = return %0 : tensor<1x2x2x2xf32>
}
......@@ -57,9 +57,9 @@ TEST_CASE(conv_add_relu)
const std::string mlir_output = R"__migraphx__(
module {
func @main(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> {
%0 = "migraphx.convolution"(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 : si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = "migraphx.add"(%0, %arg2) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = "migraphx.relu"(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%0 = migraphx.convolution(%arg0, %arg1) {dilation = [1 : si64, 1 : si64], group = 1 : si64, padding = [0 : si64, 0 : si64], padding_mode = 0 : si64, stride = [1 : si64, 1 : si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%3 = return %2 : tensor<1x2x2x2xf32>
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment