Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
...@@ -18,7 +18,7 @@ struct test_scatter0 : verify_program<test_scatter0> ...@@ -18,7 +18,7 @@ struct test_scatter0 : verify_program<test_scatter0>
auto pd = mm->add_parameter("data", sd); auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto pu = mm->add_parameter("update", su); auto pu = mm->add_parameter("update", su);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", -1}}), pd, li, pu); auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -1}}), pd, li, pu);
mm->add_return({r}); mm->add_return({r});
return p; return p;
......
...@@ -19,7 +19,7 @@ struct test_scatter1 : verify_program<test_scatter1> ...@@ -19,7 +19,7 @@ struct test_scatter1 : verify_program<test_scatter1>
auto pd = mm->add_parameter("data", sd); auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto pu = mm->add_parameter("update", su); auto pu = mm->add_parameter("update", su);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", -2}}), pd, li, pu); auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -2}}), pd, li, pu);
mm->add_return({r}); mm->add_return({r});
return p; return p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatternd : verify_program<test_scatternd>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {1}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto ld = mm->add_literal(migraphx::literal{ds, {1}});
auto data =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8}}}), ld);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatternd_add : verify_program<test_scatternd_add>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {1, 4}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto t_ind =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), indices);
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_add"), data, t_ind, updates);
mm->add_return({scatternd});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatternd_mul : verify_program<test_scatternd_mul>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto data = mm->add_parameter("data", ds);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_parameter("update", us);
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_mul"), data, indices, updates);
mm->add_return({scatternd});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with odd size tensor can't fit half2 packing
struct test_sqrt_half1 : verify_program<test_sqrt_half1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {5}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with tensor size that's divisible by 2,
// but not divisible by 4
struct test_sqrt_half2 : verify_program<test_sqrt_half2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {6}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with tensor size that fits into half4 packing
struct test_sqrt_half4 : verify_program<test_sqrt_half4>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {8}};
auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_sub_int : verify_program<test_sub_int>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", {migraphx::shape::int16_type, {4, 5}});
auto y = mm->add_parameter("y", {migraphx::shape::int16_type, {2, 3, 4, 5}});
auto xb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), x);
auto diff = mm->add_instruction(migraphx::make_op("sub"), y, xb);
mm->add_return({diff});
return p;
}
};
...@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = [] ...@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = []
cpp_header_preamble: List[str] = [] cpp_header_preamble: List[str] = []
def bad_param_error(msg): def bad_param_error(msg: str):
return 'throw std::runtime_error("{}")'.format(msg) return 'throw std::runtime_error("{}")'.format(msg)
...@@ -89,7 +89,7 @@ class Type: ...@@ -89,7 +89,7 @@ class Type:
else: else:
return t.remove_const() return t.remove_const()
def const_compatible(self, t): def const_compatible(self, t: 'Type'):
if t.is_const(): if t.is_const():
return self.add_const() return self.add_const()
return self return self
...@@ -102,6 +102,10 @@ header_function = Template(''' ...@@ -102,6 +102,10 @@ header_function = Template('''
${error_type} ${name}(${params}); ${error_type} ${name}(${params});
''') ''')
function_pointer_typedef = Template('''
typedef ${error_type} (*${fname})(${params});
''')
c_api_impl = Template(''' c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params}) extern "C" ${error_type} ${name}(${params})
{ {
...@@ -136,18 +140,23 @@ class CFunction: ...@@ -136,18 +140,23 @@ class CFunction:
self.va_end = ['va_end({});'.format(name)] self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '') self.add_param('...', '')
def substitute(self, form: Template) -> str: def substitute(self, form: Template, **kwargs) -> str:
return form.substitute(error_type=error_type, return form.substitute(error_type=error_type,
try_wrap=try_wrap, try_wrap=try_wrap,
name=self.name, name=self.name,
params=', '.join(self.params), params=', '.join(self.params),
body=";\n ".join(self.body), body=";\n ".join(self.body),
va_start="\n ".join(self.va_start), va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end)) va_end="\n ".join(self.va_end),
**kwargs)
def generate_header(self) -> str: def generate_header(self) -> str:
return self.substitute(header_function) return self.substitute(header_function)
def generate_function_pointer(self, name: Optional[str] = None) -> str:
return self.substitute(function_pointer_typedef,
fname=name or self.name)
def generate_body(self) -> str: def generate_body(self) -> str:
return self.substitute(c_api_impl) return self.substitute(c_api_impl)
...@@ -163,7 +172,9 @@ class Parameter: ...@@ -163,7 +172,9 @@ class Parameter:
name: str, name: str,
type: str, type: str,
optional: bool = False, optional: bool = False,
returns: bool = False) -> None: returns: bool = False,
virtual: bool = False,
this: bool = False) -> None:
self.name = name self.name = name
self.type = Type(type) self.type = Type(type)
self.optional = optional self.optional = optional
...@@ -175,7 +186,11 @@ class Parameter: ...@@ -175,7 +186,11 @@ class Parameter:
self.cpp_read = '${name}' self.cpp_read = '${name}'
self.cpp_write = '${name}' self.cpp_write = '${name}'
self.returns = returns self.returns = returns
self.virtual = virtual
self.this = this
self.bad_param_check: Optional[BadParam] = None self.bad_param_check: Optional[BadParam] = None
self.virtual_read: Optional[List[str]] = None
self.virtual_write: Optional[str] = None
def get_name(self, prefix: Optional[str] = None) -> str: def get_name(self, prefix: Optional[str] = None) -> str:
if prefix: if prefix:
...@@ -248,6 +263,48 @@ class Parameter: ...@@ -248,6 +263,48 @@ class Parameter:
raise ValueError("Error for {}: write cannot be a string".format( raise ValueError("Error for {}: write cannot be a string".format(
self.type.str())) self.type.str()))
def virtual_arg(self, prefix: Optional[str] = None) -> List[str]:
read = self.virtual_read
if not read and len(self.write) >= len(self.cparams):
read = [
Template(w.partition('=')[2]).safe_substitute(result='${name}')
for w in self.write
]
if not read:
raise ValueError("No virtual_read parameter provided for: " +
self.type.str())
if isinstance(read, str):
raise ValueError(
"Error for {}: virtual_read cannot be a string".format(
self.type.str()))
return [self.substitute(r, prefix=prefix) for r in read]
def virtual_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix or '', n=n)
for t, n in self.cparams
]
def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]:
return [
'std::remove_pointer_t<{type}> {prefix}{n};'.format(
type=Type(t).str(), prefix=prefix or '', n=n)
for t, n in self.cparams
]
def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write
if not write:
if '*' in self.read or '->' in self.read:
write = Template(self.read).safe_substitute(name='(&${name})')
else:
write = self.read
return self.substitute(write, prefix=prefix)
def cpp_param(self, prefix: Optional[str] = None) -> str: def cpp_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${cpptype} ${name}', prefix=prefix) return self.substitute('${cpptype} ${name}', prefix=prefix)
...@@ -311,6 +368,7 @@ class Function: ...@@ -311,6 +368,7 @@ class Function:
invoke: Optional[str] = None, invoke: Optional[str] = None,
fname: Optional[str] = None, fname: Optional[str] = None,
return_name: Optional[str] = None, return_name: Optional[str] = None,
virtual: bool = False,
**kwargs) -> None: **kwargs) -> None:
self.name = name self.name = name
self.params = params or [] self.params = params or []
...@@ -321,6 +379,10 @@ class Function: ...@@ -321,6 +379,10 @@ class Function:
self.return_name = return_name or 'out' self.return_name = return_name or 'out'
self.returns = Parameter(self.return_name, returns, self.returns = Parameter(self.return_name, returns,
returns=True) if returns else None returns=True) if returns else None
for p in self.params:
p.virtual = virtual
if self.returns:
self.returns.virtual = virtual
def share_params(self) -> None: def share_params(self) -> None:
if self.shared_size == True: if self.shared_size == True:
...@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None, ...@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None,
return result return result
gparams = params
def add_function(name: str, *args, **kwargs) -> Function: def add_function(name: str, *args, **kwargs) -> Function:
f = Function(name, *args, **kwargs) f = Function(name, *args, **kwargs)
functions.append(f) functions.append(f)
...@@ -627,7 +692,7 @@ extern "C" struct ${ctype}; ...@@ -627,7 +692,7 @@ extern "C" struct ${ctype};
struct ${ctype} { struct ${ctype} {
template<class... Ts> template<class... Ts>
${ctype}(Ts&&... xs) ${ctype}(Ts&&... xs)
: object(std::forward<Ts>(xs)...) : object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{} {}
${cpptype} object; ${cpptype} object;
}; };
...@@ -656,6 +721,55 @@ void destroy(T* x) ...@@ -656,6 +721,55 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t)
{
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
''' '''
cpp_handle_preamble = ''' cpp_handle_preamble = '''
...@@ -718,38 +832,53 @@ def add_handle(name: str, ...@@ -718,38 +832,53 @@ def add_handle(name: str,
ctype: str, ctype: str,
cpptype: str, cpptype: str,
destroy: Optional[str] = None, destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None: ref=False,
skip_def=False) -> None:
opaque_type = ctype + '_t' opaque_type = ctype + '_t'
const_opaque_type = 'const_' + opaque_type
def handle_wrap(p): def handle_wrap(p: Parameter):
t = Type(opaque_type) t = Type(opaque_type)
if p.type.is_const(): if p.type.is_const():
t = Type('const_' + opaque_type) t = Type('const_' + opaque_type)
if p.returns: # p.read = 'object_cast<${ctype}>(&(${name}))'
if p.virtual:
p.add_param(t)
elif p.returns:
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
if p.type.is_reference():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
if p.type.is_reference():
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.virtual_read = ['object_cast<${ctype}>(${result})']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
if skip_def:
p.read = '*${name}'
else:
p.read = '${name}->object' p.read = '${name}->object'
p.cpp_read = '${name}.get_handle_ptr()' p.cpp_read = '${name}.get_handle_ptr()'
type_map[cpptype] = handle_wrap type_map[cpptype] = handle_wrap
if not ref: if not ref:
add_function(destroy or ctype + '_' + 'destroy', add_function(destroy or ctype + '_' + 'destroy',
params({name: opaque_type}), params({name: opaque_type}),
fname='destroy') fname='destroy')
add_function(ctype + '_' + 'assign_to',
params(output=opaque_type, input=const_opaque_type),
invoke='*output = *input')
add_handle_preamble() add_handle_preamble()
c_header_preamble.append(handle_typedef.substitute(locals())) c_header_preamble.append(handle_typedef.substitute(locals()))
c_api_body_preamble.append(handle_definition.substitute(locals())) if not skip_def:
c_api_body_preamble.append(handle_definition.substitute(locals()))
@cwrap('std::vector') @cwrap('std::vector')
...@@ -759,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -759,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None:
if not inner: if not inner:
return return
t = inner.add_pointer() t = inner.add_pointer()
if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
if p.returns: if p.returns:
if p.type.is_reference(): if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr', p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer') 'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'std::copy(${result}.begin(), ${result}.end(), ${name})'
]
else: else:
p.add_param(t) p.add_param(t)
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer') p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer')
p.read = '${type}(${name}, ${name}+${size})'
p.read = '${type}(${name}, ${name}+${size})'
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.virtual_read = ['${name}.data()', '${name}.size()']
if p.type.is_reference():
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string') @cwrap('std::string')
...@@ -792,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None: ...@@ -792,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None:
if p.type.is_reference(): if p.type.is_reference():
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = ['*${name} = ${result}.c_str()']
else: else:
p.add_param(t) p.add_param(t)
p.add_param('size_t', p.name + '_size') p.add_param('size_t', p.name + '_size')
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.read = '${type}(${name})'
p.read = '${type}(${name})'
p.cpp_write = '${type}(${name})'
p.virtual_read = ['${name}.c_str()']
if p.type.is_reference():
p.write = ['*${name} = ${result}.c_str()']
else:
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
class Handle: class Handle:
def __init__(self, def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None:
name: str,
ctype: str,
cpptype: str,
ref: Optional[bool] = None) -> None:
self.name = name self.name = name
self.ctype = ctype self.ctype = ctype
self.cpptype = cpptype self.cpptype = cpptype
self.opaque_type = self.ctype + '_t'
self.cpp_class = CPPClass(name, ctype) self.cpp_class = CPPClass(name, ctype)
add_handle(name, ctype, cpptype, ref=ref) add_handle(name, ctype, cpptype, **kwargs)
cpp_type_map[cpptype] = name cpp_type_map[cpptype] = name
def cname(self, name: str) -> str: def cname(self, name: str) -> str:
...@@ -829,6 +960,7 @@ class Handle: ...@@ -829,6 +960,7 @@ class Handle:
return Template(s).safe_substitute(name=self.name, return Template(s).safe_substitute(name=self.name,
ctype=self.ctype, ctype=self.ctype,
cpptype=self.cpptype, cpptype=self.cpptype,
opaque_type=self.opaque_type,
**kwargs) **kwargs)
def constructor(self, def constructor(self,
...@@ -883,6 +1015,137 @@ class Handle: ...@@ -883,6 +1015,137 @@ class Handle:
cpp_classes.append(self.cpp_class) cpp_classes.append(self.cpp_class)
interface_handle_definition = Template('''
extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
${functions}
};
''')
c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
throw std::runtime_error("Error in ${name}.");
return ${output};
}
''')
def generate_virtual_impl(f: Function, fname: str) -> str:
success = success_type
name = f.name
return_type = 'void'
output_decls = ''
output = ''
largs = []
lparams = []
if f.returns:
return_type = f.returns.type.str()
output_decls = '\n'.join(f.returns.virtual_output_declarations())
largs += f.returns.virtual_output_args()
output = f.returns.virtual_output()
largs += [arg for p in f.params for arg in p.virtual_arg()]
lparams += [p.virtual_param() for p in f.params if not p.this]
args = ', '.join(largs)
params = ', '.join(lparams)
return c_api_virtual_impl.substitute(locals())
class Interface(Handle):
def __init__(self, name: str, ctype: str, cpptype: str) -> None:
super().__init__(name, ctype, cpptype, skip_def=True)
self.ifunctions: List[Function] = []
self.members: List[str] = []
def mname(self, name: str) -> str:
return name + "_f"
def constructor( # type: ignore
self,
name: str,
params: Optional[List[Parameter]] = None,
**kwargs) -> 'Interface':
create = self.substitute('allocate<${opaque_type}>($@)')
initial_params = gparams(obj='void*',
c=self.cname('copy'),
d=self.cname('delete'))
add_function(self.cname(name),
params=initial_params + (params or []),
invoke=create,
returns=self.opaque_type,
return_name=self.name,
**kwargs)
return self
def method(self, *args, **kwargs) -> 'Interface':
super().method(*args, **kwargs)
return self
def virtual(self,
name: str,
params: Optional[List[Parameter]] = None,
const: Optional[bool] = None,
**kwargs) -> 'Interface':
# Add this parameter to the function
this = Parameter('obj', 'void*', this=True)
this.virtual_read = ['object_ptr.data']
f = Function(name,
params=[this] + (params or []),
virtual=True,
**kwargs)
self.ifunctions.append(f)
add_function(self.cname('set_' + name),
params=gparams(obj=self.opaque_type,
input=self.cname(name)),
invoke='${{obj}}->{name} = ${{input}}'.format(
name=self.mname(name)))
return self
def generate_function(self, f: Function):
cname = self.cname(f.name)
mname = self.mname(f.name)
function = generate_virtual_impl(f, fname=mname)
return f"{cname} {mname} = nullptr;{function}"
def generate(self):
required_functions = [
Function('copy',
params=gparams(out='void**', input='void*'),
virtual=True),
Function('delete', params=gparams(input='void*'), virtual=True)
]
for f in self.ifunctions + required_functions:
f.update()
c_header_preamble.extend([
f.get_cfunction().generate_function_pointer(self.cname(f.name))
for f in self.ifunctions + required_functions
])
function_list = [self.generate_function(f) for f in self.ifunctions]
ctype = self.ctype
cpptype = self.cpptype
copier = self.cname('copy')
deleter = self.cname('delete')
functions = '\n'.join(function_list)
c_api_body_preamble.append(
interface_handle_definition.substitute(locals()))
def handle(ctype: str, def handle(ctype: str,
cpptype: str, cpptype: str,
name: Optional[str] = None, name: Optional[str] = None,
...@@ -902,6 +1165,23 @@ def handle(ctype: str, ...@@ -902,6 +1165,23 @@ def handle(ctype: str,
return with_handle return with_handle
def interface(ctype: str, cpptype: str,
name: Optional[str] = None) -> Callable:
def with_interface(f):
n = name or f.__name__
h = Interface(n, ctype, cpptype)
f(h)
h.generate()
@wraps(f)
def decorated(*args, **kwargs):
return f(*args, **kwargs)
return decorated
return with_interface
def template_eval(template, **kwargs): def template_eval(template, **kwargs):
start = '<%' start = '<%'
end = '%>' end = '%>'
...@@ -924,7 +1204,7 @@ def run(args: List[str]) -> None: ...@@ -924,7 +1204,7 @@ def run(args: List[str]) -> None:
else: else:
sys.stdout.write(generate_c_header()) sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body()) sys.stdout.write(generate_c_api_body())
sys.stdout.write(generate_cpp_header()) # sys.stdout.write(generate_cpp_header())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
...@@ -72,6 +74,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -72,6 +74,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
} }
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); } target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; } void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
...@@ -194,6 +213,41 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -194,6 +213,41 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
<% generate_c_api_body() %> <% generate_c_api_body() %>
...@@ -25,7 +25,8 @@ extern "C" { ...@@ -25,7 +25,8 @@ extern "C" {
#endif #endif
// return code, more to be added later // return code, more to be added later
typedef enum { typedef enum
{
migraphx_status_success = 0, migraphx_status_success = 0,
migraphx_status_bad_param = 1, migraphx_status_bad_param = 1,
migraphx_status_unknown_target = 3, migraphx_status_unknown_target = 3,
...@@ -35,7 +36,8 @@ typedef enum { ...@@ -35,7 +36,8 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs /// An enum to represent the different data type inputs
typedef enum { typedef enum
{
migraphx_shape_tuple_type, migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
......
...@@ -7,11 +7,13 @@ fi ...@@ -7,11 +7,13 @@ fi
if type -p python3.8 > /dev/null ; then if type -p python3.8 > /dev/null ; then
PYTHON=python3.8 PYTHON=python3.8
fi fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $SRC_DIR/include/migraphx/{}" ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-10 -style=file > $SRC_DIR/include/migraphx/{}"
function api { function api {
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-5.0 -style=file > $2 $PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-10 -style=file > $2
} }
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
echo "Finished generating header migraphx.h"
api $DIR/api/api.cpp $SRC_DIR/api/api.cpp api $DIR/api/api.cpp $SRC_DIR/api/api.cpp
echo "Finished generating source api.cpp "
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -33,12 +34,21 @@ value to_value_context(const T&) ...@@ -33,12 +34,21 @@ value to_value_context(const T&)
} }
template <class T> template <class T>
void from_value_context(T&, const value&){} void from_value_context(T&, const value&)
{
}
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
<% <%
interface('context', interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'), virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('finish', returns = 'void', const = True)) %> virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx) inline void migraphx_to_value(value& v, const context& ctx)
......
...@@ -26,11 +26,11 @@ struct schedule_model ...@@ -26,11 +26,11 @@ struct schedule_model
/// Get the number of concurrent instruction allowed /// Get the number of concurrent instruction allowed
std::size_t concurrency() const; std::size_t concurrency() const;
/// Schedule a concurrent instruction /// Schedule a concurrent instruction
void sched(module& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction // Insert necessary waits before an instruction
void wait(module& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction // Insert necessary records after an instruction
void record(module& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation /// Compute weights for an operation
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
...@@ -40,9 +40,9 @@ struct schedule_model ...@@ -40,9 +40,9 @@ struct schedule_model
<% <%
interface('schedule_model', interface('schedule_model',
virtual('concurrency', returns='std::size_t', const=True), virtual('concurrency', returns='std::size_t', const=True),
virtual('sched', p='module&', ins='instruction_ref', n='std::size_t', const=True), virtual('sched', m='module&', ins='instruction_ref', n='std::size_t', const=True),
virtual('wait', p='module&', ins='instruction_ref', wait_id='std::size_t', const=True), virtual('wait', m='module&', ins='instruction_ref', wait_id='std::size_t', const=True),
virtual('record', p='module&', ins='instruction_ref', wait_id='std::size_t', const=True), virtual('record', m='module&', ins='instruction_ref', wait_id='std::size_t', const=True),
virtual('weight', returns='std::size_t', op='const operation&', const=True) virtual('weight', returns='std::size_t', op='const operation&', const=True)
) )
%> %>
......
...@@ -4,12 +4,20 @@ ...@@ -4,12 +4,20 @@
set -e set -e
#install pip3, rocm-cmake, rocblas and miopen export LC_ALL=C.UTF-8
apt update && apt install -y python3-pip rocm-cmake rocblas miopen-hip openmp-extras export LANG=C.UTF-8
# Need pip3 and Python headers to build dependencies
apt update && apt install -y python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
# Needed for cmake to build various pip packages
pip3 install setuptools wheel
# install rbuild to build dependencies # install rbuild to build dependencies
pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
PREFIX=/usr/local PREFIX=/usr/local
REQ_FILE_DIR="" REQ_FILE_DIR=""
if [ "$#" -ge 2 ]; then if [ "$#" -ge 2 ]; then
...@@ -19,7 +27,7 @@ elif [ "$#" -eq 1 ]; then ...@@ -19,7 +27,7 @@ elif [ "$#" -eq 1 ]; then
PREFIX=$1 PREFIX=$1
fi fi
echo "Dependencies are install at $PREFIX" echo "Dependencies are installed at $PREFIX"
# Install deps with rbuild # Install deps with rbuild
rbuild prepare -d $PREFIX -s develop rbuild prepare -d $PREFIX -s develop
...@@ -27,3 +35,5 @@ rbuild prepare -d $PREFIX -s develop ...@@ -27,3 +35,5 @@ rbuild prepare -d $PREFIX -s develop
# install onnx package for unit tests # install onnx package for unit tests
pip3 install onnx==1.8.1 numpy==1.18.5 typing==3.7.4 pytest==6.0.1 packaging==16.8 pip3 install onnx==1.8.1 numpy==1.18.5 typing==3.7.4 pytest==6.0.1 packaging==16.8
# pin version of protobuf in Python for onnx runtime unit tests
pip3 install protobuf==3.20.0
...@@ -12,16 +12,15 @@ headers = ''' ...@@ -12,16 +12,15 @@ headers = '''
''' '''
form = string.Template(''' form = string.Template('''
#ifdef TYPE_ERASED_DECLARATION
/* // Type-erased interface for:
* Type-erased interface for: struct ${struct_name}
* {
* struct ${struct_name} ${decl_members}
* { };
${comment_members}
* }; #else
*
*/
struct ${struct_name} struct ${struct_name}
{ {
...@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x) ...@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
if (y == nullptr) throw std::bad_cast(); if (y == nullptr) throw std::bad_cast();
return *y; return *y;
} }
#endif
''') ''')
nonvirtual_member = string.Template(''' nonvirtual_member = string.Template('''
...@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override ...@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member = string.Template( comment_member = string.Template(
'''* ${friend} ${return_type} ${name}(${params}) ${const};''') '''* ${friend} ${return_type} ${name}(${params}) ${const};''')
decl_member = string.Template(''' ${comment}
${friend} ${return_type} ${name}(${params}) ${const};
''')
default_member = string.Template(''' default_member = string.Template('''
template<class T> template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params}) static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
...@@ -279,7 +283,8 @@ def convert_member(d, struct_name): ...@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
'this': '(*this)', 'this': '(*this)',
'using': '', 'using': '',
'brief': '', 'brief': '',
'return_': '' 'return_': '',
'comment': '// '
} }
args = [] args = []
params = [] params = []
...@@ -306,6 +311,7 @@ def convert_member(d, struct_name): ...@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
member['friend'] = 'friend' member['friend'] = 'friend'
elif x == 'default': elif x == 'default':
member['default'] = t member['default'] = t
member['comment'] = member['comment'] + '(optional)'
elif x == 'using': elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using']) member['using'] = 'using {};'.format(d[name]['using'])
elif x == '__brief__': elif x == '__brief__':
...@@ -347,18 +353,21 @@ def generate_form(name, members): ...@@ -347,18 +353,21 @@ def generate_form(name, members):
virtual_members = [] virtual_members = []
comment_members = [] comment_members = []
default_members = [] default_members = []
decl_members = []
for member in members: for member in members:
m = convert_member(member, name) m = convert_member(member, name)
nonvirtual_members.append(nonvirtual_member.substitute(m)) nonvirtual_members.append(nonvirtual_member.substitute(m))
pure_virtual_members.append(pure_virtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m))
virtual_members.append(virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m))
comment_members.append(comment_member.substitute(m)) comment_members.append(comment_member.substitute(m))
decl_members.append(decl_member.substitute(m))
if 'default' in m: if 'default' in m:
default_members.append(default_member.substitute(m)) default_members.append(default_member.substitute(m))
return form.substitute(nonvirtual_members=''.join(nonvirtual_members), return form.substitute(nonvirtual_members=''.join(nonvirtual_members),
pure_virtual_members=''.join(pure_virtual_members), pure_virtual_members=''.join(pure_virtual_members),
virtual_members=''.join(virtual_members), virtual_members=''.join(virtual_members),
default_members=''.join(default_members), default_members=''.join(default_members),
decl_members=''.join(decl_members),
comment_members='\n'.join(comment_members), comment_members='\n'.join(comment_members),
struct_name=name) struct_name=name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment