Commit bf571e25 authored by Paul's avatar Paul
Browse files

Add output alias to fusion

parent 5b888363
...@@ -613,6 +613,7 @@ struct identity ...@@ -613,6 +613,7 @@ struct identity
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
int output_alias(const std::vector<shape>&) const { return 0; }
}; };
struct abs : unary struct abs : unary
......
...@@ -155,6 +155,7 @@ struct hip_triadd ...@@ -155,6 +155,7 @@ struct hip_triadd
device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2)); device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3); return args.at(3);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
struct hip_triadd_relu struct hip_triadd_relu
...@@ -170,6 +171,7 @@ struct hip_triadd_relu ...@@ -170,6 +171,7 @@ struct hip_triadd_relu
device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2)); device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3); return args.at(3);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
struct hip_add_relu struct hip_add_relu
...@@ -185,6 +187,7 @@ struct hip_add_relu ...@@ -185,6 +187,7 @@ struct hip_add_relu
device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1)); device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
return args.at(2); return args.at(2);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
struct find_add_relu struct find_add_relu
...@@ -271,6 +274,7 @@ struct miopen_conv_bias ...@@ -271,6 +274,7 @@ struct miopen_conv_bias
f.compile(ctx); f.compile(ctx);
return f.get_workspace(ctx); return f.get_workspace(ctx);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
struct miopen_conv_bias_relu struct miopen_conv_bias_relu
...@@ -314,6 +318,7 @@ struct miopen_conv_bias_relu ...@@ -314,6 +318,7 @@ struct miopen_conv_bias_relu
f.compile(ctx); f.compile(ctx);
return f.get_workspace(ctx); return f.get_workspace(ctx);
} }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
template <class... Ms> template <class... Ms>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment