Commit 3c95b34d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'layernorm_half2' into branch_for_ort2

parents af110526 789f86fb
...@@ -12,16 +12,15 @@ headers = ''' ...@@ -12,16 +12,15 @@ headers = '''
''' '''
form = string.Template(''' form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION
/* // Type-erased interface for:
* Type-erased interface for: struct ${struct_name}
* {
* struct ${struct_name} ${decl_members}
* { };
${comment_members}
* }; #else
*
*/
struct ${struct_name} struct ${struct_name}
{ {
...@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x) ...@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
if (y == nullptr) throw std::bad_cast(); if (y == nullptr) throw std::bad_cast();
return *y; return *y;
} }
#endif
''') ''')
nonvirtual_member = string.Template(''' nonvirtual_member = string.Template('''
...@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override ...@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member = string.Template( comment_member = string.Template(
'''* ${friend} ${return_type} ${name}(${params}) ${const};''') '''* ${friend} ${return_type} ${name}(${params}) ${const};''')
decl_member = string.Template(''' ${comment}
${friend} ${return_type} ${name}(${params}) ${const};
''')
default_member = string.Template(''' default_member = string.Template('''
template<class T> template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params}) static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
...@@ -279,7 +283,8 @@ def convert_member(d, struct_name): ...@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
'this': '(*this)', 'this': '(*this)',
'using': '', 'using': '',
'brief': '', 'brief': '',
'return_': '' 'return_': '',
'comment': '// '
} }
args = [] args = []
params = [] params = []
...@@ -306,6 +311,7 @@ def convert_member(d, struct_name): ...@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
member['friend'] = 'friend' member['friend'] = 'friend'
elif x == 'default': elif x == 'default':
member['default'] = t member['default'] = t
member['comment'] = member['comment'] + '(optional)'
elif x == 'using': elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using']) member['using'] = 'using {};'.format(d[name]['using'])
elif x == '__brief__': elif x == '__brief__':
...@@ -347,18 +353,21 @@ def generate_form(name, members): ...@@ -347,18 +353,21 @@ def generate_form(name, members):
virtual_members = [] virtual_members = []
comment_members = [] comment_members = []
default_members = [] default_members = []
decl_members = []
for member in members: for member in members:
m = convert_member(member, name) m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m)) nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m))
comment_members.append(comment_member.substitute(m)) comment_members.append(comment_member.substitute(m))
decl_members.append(decl_member.substitute(m))
if 'default' in m: if 'default' in m:
default_members.append(default_member.substitute(m)) default_members.append(default_member.substitute(m))
return form.substitute(nonvirtual_members=''.join(nonvirtual_members), return form.substitute(nonvirtual_members=''.join(nonvirtual_members),
pure_virtual_members=''.join(pure_virtual_members), pure_virtual_members=''.join(pure_virtual_members),
virtual_members=''.join(virtual_members), virtual_members=''.join(virtual_members),
default_members=''.join(default_members), default_members=''.join(default_members),
decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members), comment_members='\n'.join(comment_members),
struct_name=name) struct_name=name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment