Commit bf571e25 authored by Paul's avatar Paul
Browse files

Add output alias to fusion

parent 5b888363
......@@ -613,6 +613,7 @@ struct identity
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct abs : unary
......
......@@ -155,6 +155,7 @@ struct hip_triadd
device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct hip_triadd_relu
......@@ -170,6 +171,7 @@ struct hip_triadd_relu
device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct hip_add_relu
......@@ -185,6 +187,7 @@ struct hip_add_relu
device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
return args.at(2);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct find_add_relu
......@@ -271,6 +274,7 @@ struct miopen_conv_bias
f.compile(ctx);
return f.get_workspace(ctx);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct miopen_conv_bias_relu
......@@ -314,6 +318,7 @@ struct miopen_conv_bias_relu
f.compile(ctx);
return f.get_workspace(ctx);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
template <class... Ms>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment