Commit 038a4c52 authored by wsttiger's avatar wsttiger
Browse files

Merged from master still debugging resnet

parents 06cc4f8f 905d4ab0
......@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
shape compute_shape(std::vector<shape> input) const;
shape compute_shape(const std::vector<shape>& input) const;
/**
* @brief This performs the operation's computation
*
......@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
*/
argument compute(context& ctx, shape output, std::vector<argument> input) const;
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
......@@ -56,7 +56,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input)
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
......@@ -64,8 +65,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
<%
interface('operation',
virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True, default='compute_op'),
virtual('compute_shape', returns='shape', input='const std::vector<shape>&', const=True),
virtual('compute', returns='argument', ctx='context&', output='const shape&', input='const std::vector<argument>&', const=True, default='compute_op'),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<')
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment