Commit 9a3932d4 authored by Paul's avatar Paul
Browse files

Add output alias function

parent f76e2490
...@@ -43,6 +43,9 @@ struct operation ...@@ -43,6 +43,9 @@ struct operation
* the same the `output` shape. * the same the `output` shape.
*/ */
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const; argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
int output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not /// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name. /// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
...@@ -108,12 +111,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -108,12 +111,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
* struct operation * struct operation
* { * {
* std::string name() const; * std::string name() const;
* int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
...@@ -185,6 +208,12 @@ struct operation ...@@ -185,6 +208,12 @@ struct operation
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
int output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input);
}
shape compute_shape(const std::vector<shape>& input) const shape compute_shape(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -217,6 +246,7 @@ struct operation ...@@ -217,6 +246,7 @@ struct operation
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
...@@ -254,8 +284,15 @@ struct operation ...@@ -254,8 +284,15 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
int output_alias(const std::vector<shape>& input) const override
{
return output_alias_op(private_detail_te_value, input);
}
shape compute_shape(const std::vector<shape>& input) const override shape compute_shape(const std::vector<shape>& input) const override
{ {
return private_detail_te_value.compute_shape(input); return private_detail_te_value.compute_shape(input);
} }
......
...@@ -43,6 +43,9 @@ struct operation ...@@ -43,6 +43,9 @@ struct operation
* the same the `output` shape. * the same the `output` shape.
*/ */
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const; argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
int output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not /// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name. /// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
...@@ -108,10 +111,30 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -108,10 +111,30 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
template<class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
template<class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes) -> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template<class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('output_alias', returns = 'int', input = 'const std::vector<shape>&', const = True, default = 'output_alias_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment