Commit ab0ea297 authored by Paul's avatar Paul
Browse files

Add script to generate friend functions

parent 13251b94
...@@ -16,9 +16,9 @@ namespace rtg { ...@@ -16,9 +16,9 @@ namespace rtg {
* *
* struct operation * struct operation
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const; * argument compute(std::vector<argument> input) const;
* }; * };
* *
*/ */
......
...@@ -16,8 +16,8 @@ struct program; ...@@ -16,8 +16,8 @@ struct program;
* *
* struct target * struct target
* { * {
* std::string name() const; * std::string name() const;
* void apply(program & p) const; * void apply(program & p) const;
* }; * };
* *
*/ */
......
...@@ -161,43 +161,93 @@ inline const ValueType & any_cast(const ${struct_name} & x) ...@@ -161,43 +161,93 @@ inline const ValueType & any_cast(const ${struct_name} & x)
''') ''')
nonvirtual_member = string.Template(''' nonvirtual_member = string.Template('''
${return_type} ${name}(${params}) ${const} ${friend} ${return_type} ${name}(${params}) ${const}
{ {
assert(private_detail_te_handle_mem_var); assert(private_detail_te_handle_mem_var);
return private_detail_te_get_handle().${name}(${args}); return private_detail_te_get_handle().${internal_name}(${member_args});
} }
''') ''')
pure_virtual_member = string.Template("virtual ${return_type} ${name}(${params}) ${const} = 0;\n") pure_virtual_member = string.Template("virtual ${return_type} ${internal_name}(${member_params}) ${const} = 0;\n")
virtual_member = string.Template(''' virtual_member = string.Template('''
${return_type} ${name}(${params}) ${const} override ${return_type} ${internal_name}(${member_params}) ${const} override
{ {
return private_detail_te_value.${name}(${args}); return ${call};
} }
''') ''')
comment_member = string.Template('''* ${return_type} ${name}(${params}) ${const};''') comment_member = string.Template('''* ${friend} ${return_type} ${name}(${params}) ${const};''')
def convert_member(d): def trim_type_name(name):
n = name.strip()
if n.startswith('const'):
return trim_type_name(n[5:])
if n.endswith(('&', '*')):
return trim_type_name(n[0:-1])
return n
def internal_name(name):
internal_names = {
'operator<<': 'operator_shift_left',
'operator>>': 'operator_shift_right',
}
if name in internal_names:
return internal_names[name]
else:
return name
def generate_call(m, friend):
if m['name'].startswith('operator'):
op = m['name'][8:]
args = m['args']
if len(m[args]) == 2:
return string.Template('${arg1} ${op} ${arg2}').substitute(op=op, arg1=args[0], arg2=args[1])
else:
return string.Template('${op}${arg1}').substitute(op=op, arg1=args[0])
if friend:
return string.Template('${name}(${args})').substitute(m)
return string.Template('private_detail_te_value.${name}(${args})').substitute(m)
def convert_member(d, struct_name):
for name in d: for name in d:
member = { 'name': name, 'const': ''} member = { 'name': name, 'internal_name': internal_name(name), 'const': '', 'friend': ''}
args = [] args = []
params = [] params = []
member_args = []
member_params = []
skip = False
friend = False
if 'friend' in d[name]:
friend = True
skip = True
for x in d[name]: for x in d[name]:
t = d[name][x] t = d[name][x]
if x == 'return': if x == 'return':
member['return_type'] = t member['return_type'] = t
elif x == 'const': elif x == 'const':
member['const'] = 'const' member['const'] = 'const'
elif x == 'friend':
member['friend'] = 'friend'
else: else:
use_member = not(skip and struct_name == trim_type_name(t))
arg_name = x
if not use_member: arg_name = 'private_detail_te_value'
if t.endswith(('&', '*')): if t.endswith(('&', '*')):
args.append(x) if use_member: member_args.append(x)
args.append(arg_name)
else: else:
args.append('std::move({})'.format(x)) if use_member: member_args.append('std::move({})'.format(x))
args.append('std::move({})'.format(arg_name))
params.append(t+' '+x) params.append(t+' '+x)
if use_member: member_params.append(t+' '+x)
else: skip = False
member['args'] = ','.join(args) member['args'] = ','.join(args)
member['member_args'] = ','.join(member_args)
member['params'] = ','.join(params) member['params'] = ','.join(params)
member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params)
member['call'] = generate_call(member, friend)
return member return member
return None return None
...@@ -208,7 +258,7 @@ def generate_form(name, members): ...@@ -208,7 +258,7 @@ def generate_form(name, members):
virtual_members = [] virtual_members = []
comment_members = [] comment_members = []
for member in members: for member in members:
m = convert_member(member) m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m)) nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m))
...@@ -226,6 +276,12 @@ def virtual(name, returns=None, **kwargs): ...@@ -226,6 +276,12 @@ def virtual(name, returns=None, **kwargs):
args['return'] = returns args['return'] = returns
return { name: args } return { name: args }
def friend(name, returns=None, **kwargs):
args = kwargs
args['return'] = returns
args['friend'] = 'friend'
return { name: args }
def interface(name, *members): def interface(name, *members):
return generate_form(name, members) return generate_form(name, members)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment