Commit a40d1b1d authored by Paul's avatar Paul
Browse files

Auto generate fallback code in type-erasure

parent 0045d0b7
...@@ -62,7 +62,9 @@ bool has_finalize(const operation& x); ...@@ -62,7 +62,9 @@ bool has_finalize(const operation& x);
#else #else
namespace operation_stream { namespace detail {
namespace operation_operators {
template <class T> template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
return os; return os;
} }
} // namespace operation_stream
namespace operation_equal {
template <class T, class U> template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{ {
...@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
return reflect_tie(x) == reflect_tie(yy); return reflect_tie(x) == reflect_tie(yy);
} }
} // namespace operation_equal } // namespace operation_operators
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<2>,
...@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( ...@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
} }
template <class T> template <class T>
std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&) std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{ {
return -1; return -1;
} }
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
template <class T> template <class T>
auto finalize_op( auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input) rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
...@@ -233,6 +218,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -233,6 +218,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
} // namespace detail
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
...@@ -396,6 +383,110 @@ struct operation ...@@ -396,6 +383,110 @@ struct operation
virtual bool operator==(const operation& y) const = 0; virtual bool operator==(const operation& y) const = 0;
}; };
template <class T>
static auto private_detail_te_default_is_context_free(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_context_free())
{
return private_detail_te_self.is_context_free();
}
template <class T>
static bool private_detail_te_default_is_context_free(float, T&& private_detail_te_self)
{
return detail::is_context_free_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.has_finalize())
{
return private_detail_te_self.has_finalize();
}
template <class T>
static bool private_detail_te_default_has_finalize(float, T&& private_detail_te_self)
{
return detail::has_finalize_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.output_alias(input))
{
return private_detail_te_self.output_alias(input);
}
template <class T>
static std::ptrdiff_t private_detail_te_default_output_alias(float,
T&& private_detail_te_self,
const std::vector<shape>& input)
{
return detail::output_alias_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.finalize(ctx, output, input))
{
private_detail_te_self.finalize(ctx, output, input);
}
template <class T>
static void private_detail_te_default_finalize(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
detail::finalize_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input)
-> decltype(private_detail_te_self.compute(ctx, output, input))
{
return private_detail_te_self.compute(ctx, output, input);
}
template <class T>
static argument private_detail_te_default_compute(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input)
{
return detail::compute_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input)
-> decltype(private_detail_te_self.compute(output, input))
{
return private_detail_te_self.compute(output, input);
}
template <class T>
static argument private_detail_te_default_compute(float,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input)
{
return detail::compute_op(private_detail_te_self, output, input);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -429,21 +520,26 @@ struct operation ...@@ -429,21 +520,26 @@ struct operation
bool is_context_free() const override bool is_context_free() const override
{ {
return is_context_free_op(private_detail_te_value); return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
} }
bool has_finalize() const override { return has_finalize_op(private_detail_te_value); } bool has_finalize() const override
{
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{ {
return output_alias_op(private_detail_te_value, input); return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
} }
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{ {
finalize_op(private_detail_te_value, ctx, output, input); private_detail_te_default_finalize(
char(0), private_detail_te_value, ctx, output, input);
} }
shape compute_shape(const std::vector<shape>& input) const override shape compute_shape(const std::vector<shape>& input) const override
...@@ -457,24 +553,26 @@ struct operation ...@@ -457,24 +553,26 @@ struct operation
const std::vector<argument>& input) const override const std::vector<argument>& input) const override
{ {
return compute_op(private_detail_te_value, ctx, output, input); return private_detail_te_default_compute(
char(0), private_detail_te_value, ctx, output, input);
} }
argument compute(const shape& output, const std::vector<argument>& input) const override argument compute(const shape& output, const std::vector<argument>& input) const override
{ {
return compute_op(private_detail_te_value, output, input); return private_detail_te_default_compute(
char(0), private_detail_te_value, output, input);
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
{ {
using migraphx::operation_stream::operator<<; using migraphx::detail::operation_operators::operator<<;
return os << private_detail_te_value; return os << private_detail_te_value;
} }
bool operator==(const operation& y) const override bool operator==(const operation& y) const override
{ {
using migraphx::operation_equal::operator==; using migraphx::detail::operation_operators::operator==;
return private_detail_te_value == y; return private_detail_te_value == y;
} }
...@@ -550,7 +648,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free(); ...@@ -550,7 +648,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free();
template <class T> template <class T>
bool is_context_free(const T& x) bool is_context_free(const T& x)
{ {
return is_context_free_op(x); return detail::is_context_free_op(x);
} }
inline bool has_finalize(const operation& op) { return op.has_finalize(); } inline bool has_finalize(const operation& op) { return op.has_finalize(); }
...@@ -558,7 +656,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); } ...@@ -558,7 +656,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T> template <class T>
bool has_finalize(const T& x) bool has_finalize(const T& x)
{ {
return has_finalize_op(x); return detail::has_finalize_op(x);
} }
#endif #endif
......
...@@ -248,6 +248,50 @@ struct target ...@@ -248,6 +248,50 @@ struct target
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
template <class T>
static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
-> decltype(private_detail_te_self.copy_to(input))
{
return private_detail_te_self.copy_to(input);
}
template <class T>
static argument
private_detail_te_default_copy_to(float, T&& private_detail_te_self, const argument& input)
{
return copy_to_target(private_detail_te_self, input);
}
template <class T>
static auto
private_detail_te_default_copy_from(char, T&& private_detail_te_self, const argument& input)
-> decltype(private_detail_te_self.copy_from(input))
{
return private_detail_te_self.copy_from(input);
}
template <class T>
static argument
private_detail_te_default_copy_from(float, T&& private_detail_te_self, const argument& input)
{
return copy_from_target(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_allocate(char, T&& private_detail_te_self, const shape& s)
-> decltype(private_detail_te_self.allocate(s))
{
return private_detail_te_self.allocate(s);
}
template <class T>
static argument
private_detail_te_default_allocate(float, T&& private_detail_te_self, const shape& s)
{
return target_allocate(private_detail_te_self, s);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -289,19 +333,19 @@ struct target ...@@ -289,19 +333,19 @@ struct target
argument copy_to(const argument& input) const override argument copy_to(const argument& input) const override
{ {
return copy_to_target(private_detail_te_value, input); return private_detail_te_default_copy_to(char(0), private_detail_te_value, input);
} }
argument copy_from(const argument& input) const override argument copy_from(const argument& input) const override
{ {
return copy_from_target(private_detail_te_value, input); return private_detail_te_default_copy_from(char(0), private_detail_te_value, input);
} }
argument allocate(const shape& s) const override argument allocate(const shape& s) const override
{ {
return target_allocate(private_detail_te_value, s); return private_detail_te_default_allocate(char(0), private_detail_te_value, s);
} }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
......
...@@ -62,7 +62,9 @@ bool has_finalize(const operation& x); ...@@ -62,7 +62,9 @@ bool has_finalize(const operation& x);
#else #else
namespace operation_stream { namespace detail {
namespace operation_operators {
template <class T> template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
return os; return os;
} }
} // namespace operation_stream
namespace operation_equal {
template <class T, class U> template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{ {
...@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
return reflect_tie(x) == reflect_tie(yy); return reflect_tie(x) == reflect_tie(yy);
} }
} // namespace operation_equal } // namespace operation_operators
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<2>,
...@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( ...@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
} }
template <class T> template <class T>
std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&) std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{ {
return -1; return -1;
} }
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
template <class T> template <class T>
auto finalize_op( auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input) rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
...@@ -233,22 +218,24 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -233,22 +218,24 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {}; return {};
} }
} // namespace detail
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'), virtual('is_context_free', returns = 'bool', const = True, default = 'detail::is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'), virtual('has_finalize', returns = 'bool', const = True, default = 'detail::has_finalize_op'),
virtual('output_alias', virtual('output_alias',
returns = 'std::ptrdiff_t', returns = 'std::ptrdiff_t',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'detail::output_alias_op'),
virtual('finalize', virtual('finalize',
ctx = 'context&', ctx = 'context&',
output = 'const shape&', output = 'const shape&',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
default = 'finalize_op'), default = 'detail::finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
...@@ -256,23 +243,23 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, ...@@ -256,23 +243,23 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
output = 'const shape&', output = 'const shape&',
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'compute_op'), default = 'detail::compute_op'),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
output = 'const shape&', output = 'const shape&',
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'compute_op'), default = 'detail::compute_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
op = 'const operation &', op = 'const operation &',
using = 'migraphx::operation_stream::operator<<'), using = 'migraphx::detail::operation_operators::operator<<'),
friend('operator==', friend('operator==',
returns = 'bool', returns = 'bool',
x = 'const operation &', x = 'const operation &',
y = 'const operation &', y = 'const operation &',
using = 'migraphx::operation_equal::operator==')) %> using = 'migraphx::detail::operation_operators::operator==')) %>
inline bool operator!=(const operation& x, const operation& y) inline bool operator!=(const operation& x, const operation& y)
{ {
...@@ -284,7 +271,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free(); ...@@ -284,7 +271,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free();
template <class T> template <class T>
bool is_context_free(const T& x) bool is_context_free(const T& x)
{ {
return is_context_free_op(x); return detail::is_context_free_op(x);
} }
inline bool has_finalize(const operation& op) { return op.has_finalize(); } inline bool has_finalize(const operation& op) { return op.has_finalize(); }
...@@ -292,7 +279,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); } ...@@ -292,7 +279,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T> template <class T>
bool has_finalize(const T& x) bool has_finalize(const T& x)
{ {
return has_finalize_op(x); return detail::has_finalize_op(x);
} }
#endif #endif
......
...@@ -88,6 +88,8 @@ private: ...@@ -88,6 +88,8 @@ private:
${pure_virtual_members} ${pure_virtual_members}
}; };
${default_members}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : struct private_detail_te_handle_type :
private_detail_te_handle_base_type private_detail_te_handle_base_type
...@@ -205,6 +207,21 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override ...@@ -205,6 +207,21 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member = string.Template( comment_member = string.Template(
'''* ${friend} ${return_type} ${name}(${params}) ${const};''') '''* ${friend} ${return_type} ${name}(${params}) ${const};''')
default_member = string.Template('''
template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
-> decltype(private_detail_te_self.${name}(${args}))
{
${return_} private_detail_te_self.${name}(${args});
}
template<class T>
static ${return_type} private_detail_te_default_${internal_name}(float, T&& private_detail_te_self ${comma} ${member_params})
{
${return_} ${default}(private_detail_te_self ${comma} ${args});
}
''')
def trim_type_name(name): def trim_type_name(name):
n = name.strip() n = name.strip()
...@@ -237,12 +254,8 @@ def generate_call(m, friend, indirect): ...@@ -237,12 +254,8 @@ def generate_call(m, friend, indirect):
if friend: if friend:
return string.Template('${name}(${args})').substitute(m) return string.Template('${name}(${args})').substitute(m)
if indirect: if indirect:
if m['args']: return string.Template(
return string.Template( 'private_detail_te_default_${internal_name}(char(0), private_detail_te_value ${comma} ${args})').substitute(m)
'${default}(private_detail_te_value, ${args})').substitute(m)
else:
return string.Template(
'${default}(private_detail_te_value)').substitute(m)
return string.Template( return string.Template(
'private_detail_te_value.${name}(${args})').substitute(m) 'private_detail_te_value.${name}(${args})').substitute(m)
...@@ -314,6 +327,7 @@ def convert_member(d, struct_name): ...@@ -314,6 +327,7 @@ def convert_member(d, struct_name):
member['params'] = ','.join(params) member['params'] = ','.join(params)
member['params'] = ','.join(params) member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params) member['member_params'] = ','.join(member_params)
member['comma'] = ',' if len(args) > 0 else ''
member['call'] = generate_call(member, friend, indirect) member['call'] = generate_call(member, friend, indirect)
return member return member
return None return None
...@@ -324,15 +338,19 @@ def generate_form(name, members): ...@@ -324,15 +338,19 @@ def generate_form(name, members):
pure_virtual_members = [] pure_virtual_members = []
virtual_members = [] virtual_members = []
comment_members = [] comment_members = []
default_members = []
for member in members: for member in members:
m = convert_member(member, name) m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m)) nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m))
comment_members.append(comment_member.substitute(m)) comment_members.append(comment_member.substitute(m))
if 'default' in m:
default_members.append(default_member.substitute(m))
return form.substitute(nonvirtual_members=''.join(nonvirtual_members), return form.substitute(nonvirtual_members=''.join(nonvirtual_members),
pure_virtual_members=''.join(pure_virtual_members), pure_virtual_members=''.join(pure_virtual_members),
virtual_members=''.join(virtual_members), virtual_members=''.join(virtual_members),
default_members=''.join(default_members),
comment_members='\n'.join(comment_members), comment_members='\n'.join(comment_members),
struct_name=name) struct_name=name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment