Commit 9a3932d4 authored by Paul's avatar Paul
Browse files

Add output alias function

parent f76e2490
......@@ -43,6 +43,9 @@ struct operation
* the same the `output` shape.
*/
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
int output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
......@@ -108,12 +111,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
......@@ -185,6 +208,12 @@ struct operation
return (*this).private_detail_te_get_handle().name();
}
int output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input);
}
shape compute_shape(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -217,6 +246,7 @@ struct operation
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
......@@ -254,8 +284,15 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); }
int output_alias(const std::vector<shape>& input) const override
{
return output_alias_op(private_detail_te_value, input);
}
shape compute_shape(const std::vector<shape>& input) const override
{
return private_detail_te_value.compute_shape(input);
}
......
......@@ -43,6 +43,9 @@ struct operation
* the same the `output` shape.
*/
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
int output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
......@@ -108,10 +111,30 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template<class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
template<class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes) -> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template<class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
<%
interface(
'operation',
virtual('name', returns = 'std::string', const = True),
virtual('output_alias', returns = 'int', input = 'const std::vector<shape>&', const = True, default = 'output_alias_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute',
returns = 'argument',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment