Commit e8e7135d authored by Paul's avatar Paul
Browse files

Fix code gen

parent 69bec4b6
...@@ -177,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( ...@@ -177,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
} }
template <class T> template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&) std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{ {
return -1; return -1;
} }
...@@ -190,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes) ...@@ -190,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
} }
template <class T> template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes) std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{ {
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
...@@ -241,7 +241,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -241,7 +241,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
* std::string name() const; * std::string name() const;
* bool is_context_free() const; * bool is_context_free() const;
* bool has_finalize() const; * bool has_finalize() const;
* int output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
...@@ -327,7 +327,7 @@ struct operation ...@@ -327,7 +327,7 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize(); return (*this).private_detail_te_get_handle().has_finalize();
} }
int output_alias(const std::vector<shape>& input) const std::ptrdiff_t output_alias(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input); return (*this).private_detail_te_get_handle().output_alias(input);
...@@ -382,10 +382,10 @@ struct operation ...@@ -382,10 +382,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual void virtual void
finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0; finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
...@@ -434,7 +434,7 @@ struct operation ...@@ -434,7 +434,7 @@ struct operation
bool has_finalize() const override { return has_finalize_op(private_detail_te_value); } bool has_finalize() const override { return has_finalize_op(private_detail_te_value); }
int output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{ {
return output_alias_op(private_detail_te_value, input); return output_alias_op(private_detail_te_value, input);
......
...@@ -177,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( ...@@ -177,7 +177,7 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
} }
template <class T> template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&) std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{ {
return -1; return -1;
} }
...@@ -190,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes) ...@@ -190,7 +190,7 @@ auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
} }
template <class T> template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes) std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{ {
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
...@@ -240,7 +240,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -240,7 +240,7 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'), virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'), virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'std::ptrdiff_t',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'output_alias_op'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment