"templates/vscode:/vscode.git/clone" did not exist on "78f81fc0e598fdb066e41608a8384d493308728e"
Commit 330fe429 authored by Paul's avatar Paul
Browse files

Merge branch 'friend-op'

parents f54dcb28 8dc320a6
......@@ -7,7 +7,7 @@ def rocmtestnode(variant, name, body) {
mkdir build
cd build
CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_CXX_FLAGS_DEBUG='-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined' ${flags} ..
CTEST_PARALLEL_LEVEL=32 make -j32 check
CTEST_PARALLEL_LEVEL=32 make -j32 all doc check
"""
echo cmd
sh cmd
......
......@@ -23,6 +23,11 @@ struct unknown
return input.front();
}
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
template <class C, class T>
......
......@@ -13,6 +13,11 @@ struct literal
std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const literal& op)
{
os << op.name();
return os;
}
};
struct param
......@@ -21,6 +26,11 @@ struct param
std::string name() const { return "@param:" + parameter; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); }
argument compute(std::vector<argument>) const { RTG_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op)
{
os << op.name();
return os;
}
};
} // namespace builtin
......
......@@ -11,14 +11,25 @@
namespace rtg {
namespace operation_stream {
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
return os << x.name();
}
} // namespace operation_stream
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
*/
......@@ -74,20 +85,26 @@ struct operation
std::string name() const
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().name();
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
shape compute_shape(std::vector<shape> input) const
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().compute_shape(std::move(input));
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
}
argument compute(std::vector<argument> input) const
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().compute(std::move(input));
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(input));
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
{
assert(op.private_detail_te_handle_mem_var);
return op.private_detail_te_get_handle().operator_shift_left(os);
}
private:
......@@ -97,9 +114,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -134,14 +152,22 @@ struct operation
shape compute_shape(std::vector<shape> input) const override
{
return private_detail_te_value.compute_shape(std::move(input));
}
argument compute(std::vector<argument> input) const override
{
return private_detail_te_value.compute(std::move(input));
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using rtg::operation_stream::operator<<;
return os << private_detail_te_value;
}
PrivateDetailTypeErasedT private_detail_te_value;
};
......
......@@ -56,6 +56,12 @@ struct convolution
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
os << op.name();
return os;
}
};
struct pooling
......@@ -96,6 +102,12 @@ struct pooling
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
os << op.name();
return os;
}
};
struct activation
......@@ -110,6 +122,11 @@ struct activation
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name();
return os;
}
};
struct reshape
......@@ -136,6 +153,12 @@ struct reshape
}
argument compute(std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
os << op.name();
return os;
}
};
} // namespace rtg
......
......@@ -16,8 +16,8 @@ struct program;
*
* struct target
* {
* std::string name() const;
* void apply(program & p) const;
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
......@@ -73,14 +73,14 @@ struct target
std::string name() const
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().name();
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().apply(p);
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p);
}
private:
......
......@@ -10,6 +10,11 @@ struct simple_operation
std::string name() const { return "simple"; }
rtg::shape compute_shape(std::vector<rtg::shape>) const { RTG_THROW("not computable"); }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{
os << op.name();
return os;
}
};
void operation_copy_test()
......
......@@ -11,11 +11,22 @@
namespace rtg {
namespace operation_stream {
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
return os << x.name();
}
} // namespace operation_stream
<%
interface('operation',
virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', input='std::vector<argument>', const=True)
virtual('compute', returns='argument', input='std::vector<argument>', const=True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='rtg::operation_stream::operator<<')
)
%>
......
......@@ -161,43 +161,109 @@ inline const ValueType & any_cast(const ${struct_name} & x)
''')
nonvirtual_member = string.Template('''
${return_type} ${name}(${params}) ${const}
${friend} ${return_type} ${name}(${params}) ${const}
{
assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${name}(${args});
assert(${this}.private_detail_te_handle_mem_var);
return ${this}.private_detail_te_get_handle().${internal_name}(${member_args});
}
''')
pure_virtual_member = string.Template("virtual ${return_type} ${name}(${params}) ${const} = 0;\n")
pure_virtual_member = string.Template("virtual ${return_type} ${internal_name}(${member_params}) ${member_const} = 0;\n")
virtual_member = string.Template('''
${return_type} ${name}(${params}) ${const} override
${return_type} ${internal_name}(${member_params}) ${member_const} override
{
return private_detail_te_value.${name}(${args});
${using}
return ${call};
}
''')
comment_member = string.Template('''* ${return_type} ${name}(${params}) ${const};''')
comment_member = string.Template('''* ${friend} ${return_type} ${name}(${params}) ${const};''')
def convert_member(d):
def trim_type_name(name):
n = name.strip()
if n.startswith('const'):
return trim_type_name(n[5:])
if n.endswith(('&', '*')):
return trim_type_name(n[0:-1])
return n
def internal_name(name):
internal_names = {
'operator<<': 'operator_shift_left',
'operator>>': 'operator_shift_right',
}
if name in internal_names:
return internal_names[name]
else:
return name
def generate_call(m, friend):
if m['name'].startswith('operator'):
op = m['name'][8:]
args = m['args']
if ',' in args:
return args.replace(',', op)
else:
return string.Template('${op}${arga}').substitute(op=op, args=args)
if friend:
return string.Template('${name}(${args})').substitute(m)
return string.Template('private_detail_te_value.${name}(${args})').substitute(m)
def convert_member(d, struct_name):
for name in d:
member = { 'name': name, 'const': ''}
member = {
'name': name,
'internal_name': internal_name(name),
'const': '',
'member_const': '',
'friend': '',
'this': '(*this)',
'using': ''
}
args = []
params = []
member_args = []
member_params = []
skip = False
friend = False
if 'friend' in d[name]:
friend = True
skip = True
for x in d[name]:
t = d[name][x]
if x == 'return':
member['return_type'] = t
elif x == 'const':
member['const'] = 'const'
member['member_const'] = 'const'
elif x == 'friend':
member['friend'] = 'friend'
elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using'])
else:
use_member = not(skip and struct_name == trim_type_name(t))
arg_name = x
if not use_member:
arg_name = 'private_detail_te_value'
member['this'] = x
if 'const' in t:
member['member_const'] = 'const'
if t.endswith(('&', '*')):
args.append(x)
if use_member: member_args.append(x)
args.append(arg_name)
else:
args.append('std::move({})'.format(x))
if use_member: member_args.append('std::move({})'.format(x))
args.append('std::move({})'.format(arg_name))
params.append(t+' '+x)
if use_member: member_params.append(t+' '+x)
else: skip = False
member['args'] = ','.join(args)
member['member_args'] = ','.join(member_args)
member['params'] = ','.join(params)
member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params)
member['call'] = generate_call(member, friend)
return member
return None
......@@ -208,7 +274,7 @@ def generate_form(name, members):
virtual_members = []
comment_members = []
for member in members:
m = convert_member(member)
m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m))
......@@ -226,6 +292,12 @@ def virtual(name, returns=None, **kwargs):
args['return'] = returns
return { name: args }
def friend(name, returns=None, **kwargs):
args = kwargs
args['return'] = returns
args['friend'] = 'friend'
return { name: args }
def interface(name, *members):
return generate_form(name, members)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment