Commit 4f07b8f1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into test_branch_for_ort2

parents af110526 1e0bbd78
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); } migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
......
...@@ -26,11 +26,11 @@ struct schedule_model ...@@ -26,11 +26,11 @@ struct schedule_model
/// Get the number of concurrent instruction allowed /// Get the number of concurrent instruction allowed
std::size_t concurrency() const; std::size_t concurrency() const;
/// Schedule a concurrent instruction /// Schedule a concurrent instruction
void sched(module& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction // Insert necessary waits before an instruction
void wait(module& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction // Insert necessary records after an instruction
void record(module& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation /// Compute weights for an operation
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
...@@ -40,9 +40,9 @@ struct schedule_model ...@@ -40,9 +40,9 @@ struct schedule_model
<% <%
interface('schedule_model', interface('schedule_model',
virtual('concurrency', returns='std::size_t', const=True), virtual('concurrency', returns='std::size_t', const=True),
virtual('sched', p='module&', ins='instruction_ref', n='std::size_t', const=True), virtual('sched', m='module&', ins='instruction_ref', n='std::size_t', const=True),
virtual('wait', p='module&', ins='instruction_ref', wait_id='std::size_t', const=True), virtual('wait', m='module&', ins='instruction_ref', wait_id='std::size_t', const=True),
virtual('record', p='module&', ins='instruction_ref', wait_id='std::size_t', const=True), virtual('record', m='module&', ins='instruction_ref', wait_id='std::size_t', const=True),
virtual('weight', returns='std::size_t', op='const operation&', const=True) virtual('weight', returns='std::size_t', op='const operation&', const=True)
) )
%> %>
......
...@@ -12,16 +12,15 @@ headers = ''' ...@@ -12,16 +12,15 @@ headers = '''
''' '''
form = string.Template(''' form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION
/* // Type-erased interface for:
* Type-erased interface for: struct ${struct_name}
* {
* struct ${struct_name} ${decl_members}
* { };
${comment_members}
* }; #else
*
*/
struct ${struct_name} struct ${struct_name}
{ {
...@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x) ...@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
if (y == nullptr) throw std::bad_cast(); if (y == nullptr) throw std::bad_cast();
return *y; return *y;
} }
#endif
''') ''')
nonvirtual_member = string.Template(''' nonvirtual_member = string.Template('''
...@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override ...@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member = string.Template( comment_member = string.Template(
'''* ${friend} ${return_type} ${name}(${params}) ${const};''') '''* ${friend} ${return_type} ${name}(${params}) ${const};''')
decl_member = string.Template(''' ${comment}
${friend} ${return_type} ${name}(${params}) ${const};
''')
default_member = string.Template(''' default_member = string.Template('''
template<class T> template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params}) static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
...@@ -279,7 +283,8 @@ def convert_member(d, struct_name): ...@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
'this': '(*this)', 'this': '(*this)',
'using': '', 'using': '',
'brief': '', 'brief': '',
'return_': '' 'return_': '',
'comment': '// '
} }
args = [] args = []
params = [] params = []
...@@ -306,6 +311,7 @@ def convert_member(d, struct_name): ...@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
member['friend'] = 'friend' member['friend'] = 'friend'
elif x == 'default': elif x == 'default':
member['default'] = t member['default'] = t
member['comment'] = member['comment'] + '(optional)'
elif x == 'using': elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using']) member['using'] = 'using {};'.format(d[name]['using'])
elif x == '__brief__': elif x == '__brief__':
...@@ -347,18 +353,21 @@ def generate_form(name, members): ...@@ -347,18 +353,21 @@ def generate_form(name, members):
virtual_members = [] virtual_members = []
comment_members = [] comment_members = []
default_members = [] default_members = []
decl_members = []
for member in members: for member in members:
m = convert_member(member, name) m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m)) nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m))
comment_members.append(comment_member.substitute(m)) comment_members.append(comment_member.substitute(m))
decl_members.append(decl_member.substitute(m))
if 'default' in m: if 'default' in m:
default_members.append(default_member.substitute(m)) default_members.append(default_member.substitute(m))
return form.substitute(nonvirtual_members=''.join(nonvirtual_members), return form.substitute(nonvirtual_members=''.join(nonvirtual_members),
pure_virtual_members=''.join(pure_virtual_members), pure_virtual_members=''.join(pure_virtual_members),
virtual_members=''.join(virtual_members), virtual_members=''.join(virtual_members),
default_members=''.join(default_members), default_members=''.join(default_members),
decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members), comment_members='\n'.join(comment_members),
struct_name=name) struct_name=name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment