Commit 3c95b34d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'layernorm_half2' into branch_for_ort2

parents af110526 789f86fb
......@@ -12,16 +12,15 @@ headers = '''
'''
form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION
/*
* Type-erased interface for:
*
* struct ${struct_name}
* {
${comment_members}
* };
*
*/
// Type-erased interface for:
struct ${struct_name}
{
${decl_members}
};
#else
struct ${struct_name}
{
......@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
if (y == nullptr) throw std::bad_cast();
return *y;
}
#endif
''')
nonvirtual_member = string.Template('''
......@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member = string.Template(
'''* ${friend} ${return_type} ${name}(${params}) ${const};''')
decl_member = string.Template(''' ${comment}
${friend} ${return_type} ${name}(${params}) ${const};
''')
default_member = string.Template('''
template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
......@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
'this': '(*this)',
'using': '',
'brief': '',
'return_': ''
'return_': '',
'comment': '// '
}
args = []
params = []
......@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
member['friend'] = 'friend'
elif x == 'default':
member['default'] = t
member['comment'] = member['comment'] + '(optional)'
elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using'])
elif x == '__brief__':
......@@ -347,18 +353,21 @@ def generate_form(name, members):
virtual_members = []
comment_members = []
default_members = []
decl_members = []
for member in members:
m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m))
comment_members.append(comment_member.substitute(m))
decl_members.append(decl_member.substitute(m))
if 'default' in m:
default_members.append(default_member.substitute(m))
return form.substitute(nonvirtual_members=''.join(nonvirtual_members),
pure_virtual_members=''.join(pure_virtual_members),
virtual_members=''.join(virtual_members),
default_members=''.join(default_members),
decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members),
struct_name=name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment