Commit a40d1b1d authored by Paul's avatar Paul
Browse files

Auto generate fallback code in type-erasure

parent 0045d0b7
......@@ -62,7 +62,9 @@ bool has_finalize(const operation& x);
#else
namespace operation_stream {
namespace detail {
namespace operation_operators {
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
......@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
return os;
}
} // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
......@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
} // namespace operation_operators
template <class T>
auto compute_op(rank<2>,
......@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
}
template <class T>
std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{
return -1;
}
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
template <class T>
auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
......@@ -233,6 +218,8 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
} // namespace detail
/*
* Type-erased interface for:
*
......@@ -396,6 +383,110 @@ struct operation
virtual bool operator==(const operation& y) const = 0;
};
template <class T>
static auto private_detail_te_default_is_context_free(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_context_free())
{
return private_detail_te_self.is_context_free();
}
template <class T>
static bool private_detail_te_default_is_context_free(float, T&& private_detail_te_self)
{
return detail::is_context_free_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.has_finalize())
{
return private_detail_te_self.has_finalize();
}
template <class T>
static bool private_detail_te_default_has_finalize(float, T&& private_detail_te_self)
{
return detail::has_finalize_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_output_alias(char,
T&& private_detail_te_self,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.output_alias(input))
{
return private_detail_te_self.output_alias(input);
}
template <class T>
static std::ptrdiff_t private_detail_te_default_output_alias(float,
T&& private_detail_te_self,
const std::vector<shape>& input)
{
return detail::output_alias_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_finalize(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.finalize(ctx, output, input))
{
private_detail_te_self.finalize(ctx, output, input);
}
template <class T>
static void private_detail_te_default_finalize(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<shape>& input)
{
detail::finalize_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input)
-> decltype(private_detail_te_self.compute(ctx, output, input))
{
return private_detail_te_self.compute(ctx, output, input);
}
template <class T>
static argument private_detail_te_default_compute(float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input)
{
return detail::compute_op(private_detail_te_self, ctx, output, input);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input)
-> decltype(private_detail_te_self.compute(output, input))
{
return private_detail_te_self.compute(output, input);
}
template <class T>
static argument private_detail_te_default_compute(float,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input)
{
return detail::compute_op(private_detail_te_self, output, input);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -429,21 +520,26 @@ struct operation
bool is_context_free() const override
{
return is_context_free_op(private_detail_te_value);
return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
}
bool has_finalize() const override { return has_finalize_op(private_detail_te_value); }
bool has_finalize() const override
{
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
{
return output_alias_op(private_detail_te_value, input);
return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
}
void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
{
finalize_op(private_detail_te_value, ctx, output, input);
private_detail_te_default_finalize(
char(0), private_detail_te_value, ctx, output, input);
}
shape compute_shape(const std::vector<shape>& input) const override
......@@ -457,24 +553,26 @@ struct operation
const std::vector<argument>& input) const override
{
return compute_op(private_detail_te_value, ctx, output, input);
return private_detail_te_default_compute(
char(0), private_detail_te_value, ctx, output, input);
}
argument compute(const shape& output, const std::vector<argument>& input) const override
{
return compute_op(private_detail_te_value, output, input);
return private_detail_te_default_compute(
char(0), private_detail_te_value, output, input);
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using migraphx::operation_stream::operator<<;
using migraphx::detail::operation_operators::operator<<;
return os << private_detail_te_value;
}
bool operator==(const operation& y) const override
{
using migraphx::operation_equal::operator==;
using migraphx::detail::operation_operators::operator==;
return private_detail_te_value == y;
}
......@@ -550,7 +648,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free();
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
return detail::is_context_free_op(x);
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
......@@ -558,7 +656,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
return detail::has_finalize_op(x);
}
#endif
......
......@@ -248,6 +248,50 @@ struct target
virtual argument allocate(const shape& s) const = 0;
};
template <class T>
static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
-> decltype(private_detail_te_self.copy_to(input))
{
return private_detail_te_self.copy_to(input);
}
template <class T>
static argument
private_detail_te_default_copy_to(float, T&& private_detail_te_self, const argument& input)
{
return copy_to_target(private_detail_te_self, input);
}
template <class T>
static auto
private_detail_te_default_copy_from(char, T&& private_detail_te_self, const argument& input)
-> decltype(private_detail_te_self.copy_from(input))
{
return private_detail_te_self.copy_from(input);
}
template <class T>
static argument
private_detail_te_default_copy_from(float, T&& private_detail_te_self, const argument& input)
{
return copy_from_target(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_allocate(char, T&& private_detail_te_self, const shape& s)
-> decltype(private_detail_te_self.allocate(s))
{
return private_detail_te_self.allocate(s);
}
template <class T>
static argument
private_detail_te_default_allocate(float, T&& private_detail_te_self, const shape& s)
{
return target_allocate(private_detail_te_self, s);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -289,19 +333,19 @@ struct target
argument copy_to(const argument& input) const override
{
return copy_to_target(private_detail_te_value, input);
return private_detail_te_default_copy_to(char(0), private_detail_te_value, input);
}
argument copy_from(const argument& input) const override
{
return copy_from_target(private_detail_te_value, input);
return private_detail_te_default_copy_from(char(0), private_detail_te_value, input);
}
argument allocate(const shape& s) const override
{
return target_allocate(private_detail_te_value, s);
return private_detail_te_default_allocate(char(0), private_detail_te_value, s);
}
PrivateDetailTypeErasedT private_detail_te_value;
......
......@@ -62,7 +62,9 @@ bool has_finalize(const operation& x);
#else
namespace operation_stream {
namespace detail {
namespace operation_operators {
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
......@@ -80,10 +82,6 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
return os;
}
} // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
......@@ -95,7 +93,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
} // namespace operation_operators
template <class T>
auto compute_op(rank<2>,
......@@ -177,24 +175,11 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
}
template <class T>
std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{
return -1;
}
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
template <class T>
auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
......@@ -233,22 +218,24 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
return {};
}
} // namespace detail
<%
interface(
'operation',
virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('is_context_free', returns = 'bool', const = True, default = 'detail::is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'detail::has_finalize_op'),
virtual('output_alias',
returns = 'std::ptrdiff_t',
input = 'const std::vector<shape>&',
const = True,
default = 'output_alias_op'),
default = 'detail::output_alias_op'),
virtual('finalize',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<shape>&',
default = 'finalize_op'),
default = 'detail::finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute',
returns = 'argument',
......@@ -256,23 +243,23 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
default = 'detail::compute_op'),
virtual('compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
default = 'detail::compute_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
op = 'const operation &',
using = 'migraphx::operation_stream::operator<<'),
using = 'migraphx::detail::operation_operators::operator<<'),
friend('operator==',
returns = 'bool',
x = 'const operation &',
y = 'const operation &',
using = 'migraphx::operation_equal::operator==')) %>
using = 'migraphx::detail::operation_operators::operator==')) %>
inline bool operator!=(const operation& x, const operation& y)
{
......@@ -284,7 +271,7 @@ inline bool is_context_free(const operation& op) { return op.is_context_free();
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
return detail::is_context_free_op(x);
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
......@@ -292,7 +279,7 @@ inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
return detail::has_finalize_op(x);
}
#endif
......
......@@ -88,6 +88,8 @@ private:
${pure_virtual_members}
};
${default_members}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type :
private_detail_te_handle_base_type
......@@ -205,6 +207,21 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member = string.Template(
'''* ${friend} ${return_type} ${name}(${params}) ${const};''')
default_member = string.Template('''
template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
-> decltype(private_detail_te_self.${name}(${args}))
{
${return_} private_detail_te_self.${name}(${args});
}
template<class T>
static ${return_type} private_detail_te_default_${internal_name}(float, T&& private_detail_te_self ${comma} ${member_params})
{
${return_} ${default}(private_detail_te_self ${comma} ${args});
}
''')
def trim_type_name(name):
n = name.strip()
......@@ -237,12 +254,8 @@ def generate_call(m, friend, indirect):
if friend:
return string.Template('${name}(${args})').substitute(m)
if indirect:
if m['args']:
return string.Template(
'${default}(private_detail_te_value, ${args})').substitute(m)
else:
return string.Template(
'${default}(private_detail_te_value)').substitute(m)
'private_detail_te_default_${internal_name}(char(0), private_detail_te_value ${comma} ${args})').substitute(m)
return string.Template(
'private_detail_te_value.${name}(${args})').substitute(m)
......@@ -314,6 +327,7 @@ def convert_member(d, struct_name):
member['params'] = ','.join(params)
member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params)
member['comma'] = ',' if len(args) > 0 else ''
member['call'] = generate_call(member, friend, indirect)
return member
return None
......@@ -324,15 +338,19 @@ def generate_form(name, members):
pure_virtual_members = []
virtual_members = []
comment_members = []
default_members = []
for member in members:
m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m))
comment_members.append(comment_member.substitute(m))
if 'default' in m:
default_members.append(default_member.substitute(m))
return form.substitute(nonvirtual_members=''.join(nonvirtual_members),
pure_virtual_members=''.join(pure_virtual_members),
virtual_members=''.join(virtual_members),
default_members=''.join(default_members),
comment_members='\n'.join(comment_members),
struct_name=name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment