Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_topk_1 : verify_program<test_topk_1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 5}};
auto data = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("topk", {{"axis", -2}, {"k", 3}, {"largest", 1}}), data);
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r);
mm->add_return({r0, r1});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_topk_2 : verify_program<test_topk_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 5}};
auto data = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 0}}), data);
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r);
mm->add_return({r0});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_topk_3 : verify_program<test_topk_3>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 5}};
auto data = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("topk", {{"axis", -2}, {"k", 3}, {"largest", 0}}), data);
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r);
mm->add_return({r0, r1});
return p;
}
};
......@@ -11,7 +11,8 @@ struct test_trans_abs : verify_program<test_trans_abs>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x);
auto absx = mm->add_instruction(migraphx::make_op("abs"), tx);
auto r = mm->add_instruction(migraphx::make_op("add"), absx, absx);
mm->add_instruction(migraphx::make_op("contiguous"), r);
......
......@@ -11,7 +11,8 @@ struct test_trans_ret : verify_program<test_trans_ret>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x);
mm->add_return({tx});
return p;
......
......@@ -11,7 +11,8 @@ struct test_trans_tanh : verify_program<test_trans_tanh>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x);
auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx);
auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx);
mm->add_instruction(migraphx::make_op("contiguous"), r);
......
......@@ -11,7 +11,8 @@ struct test_trans_tanh1 : verify_program<test_trans_tanh1>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), x);
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x);
auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx);
auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx);
mm->add_return({tx, r});
......
......@@ -13,7 +13,7 @@ struct test_transpose : verify_program<test_transpose>
migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}};
auto x = mm->add_parameter("x", s);
std::vector<int64_t> perm = {0, 2, 3, 1};
auto l = mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), x);
auto l = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), x);
mm->add_instruction(migraphx::make_op("contiguous"), l);
return p;
}
......
......@@ -16,7 +16,7 @@ struct test_triadd2 : verify_program<test_triadd2>
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", b);
auto zb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), z);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), z);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), sum, zb);
return p;
......
......@@ -17,7 +17,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = mm->add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", x->get_shape().lens()}}), y);
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", x->get_shape().lens()}}), y);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, by);
mm->add_instruction(migraphx::make_op("add"), sum, z);
return p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {1, 3, 4, 5}};
migraphx::shape sx{migraphx::shape::float_type, {1, 3, 4, 5}};
auto b = mm->add_parameter("b", sb);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sx);
auto r = mm->add_instruction(migraphx::make_op("where"), b, x, y);
mm->add_return({r});
return p;
};
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_where2 : verify_program<test_where2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {1, 3, 4, 5}};
migraphx::shape sx{migraphx::shape::float_type, {1}};
auto b = mm->add_parameter("b", sb);
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sx);
auto mbx = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), x);
auto mby = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), y);
auto r = mm->add_instruction(migraphx::make_op("where"), b, mbx, mby);
mm->add_return({r});
return p;
};
};
import string, sys, re, runpy
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
type_map = {}
cpp_type_map = {}
functions = []
cpp_classes = []
type_map: Dict[str, Callable[['Parameter'], None]] = {}
cpp_type_map: Dict[str, str] = {}
functions: List['Function'] = []
cpp_classes: List['CPPClass'] = []
error_type = ''
success_type = ''
try_wrap = ''
c_header_preamble = []
c_api_body_preamble = []
cpp_header_preamble = []
c_header_preamble: List[str] = []
c_api_body_preamble: List[str] = []
cpp_header_preamble: List[str] = []
def bad_param_error(msg):
def bad_param_error(msg: str):
return 'throw std::runtime_error("{}")'.format(msg)
......@@ -23,28 +24,31 @@ class Template(string.Template):
class Type:
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name.strip()
def is_pointer(self):
def is_pointer(self) -> bool:
return self.name.endswith('*')
def is_reference(self):
def is_reference(self) -> bool:
return self.name.endswith('&')
def is_const(self):
def is_const(self) -> bool:
return self.name.startswith('const ')
def add_pointer(self):
def is_variadic(self):
return self.name.startswith('...')
def add_pointer(self) -> 'Type':
return Type(self.name + '*')
def add_reference(self):
return Type(self.name + '&')
def add_const(self):
def add_const(self) -> 'Type':
return Type('const ' + self.name)
def inner_type(self):
def inner_type(self) -> Optional['Type']:
i = self.name.find('<')
j = self.name.rfind('>')
if i > 0 and j > 0:
......@@ -52,7 +56,7 @@ class Type:
else:
return None
def remove_generic(self):
def remove_generic(self) -> 'Type':
i = self.name.find('<')
j = self.name.rfind('>')
if i > 0 and j > 0:
......@@ -60,37 +64,37 @@ class Type:
else:
return self
def remove_pointer(self):
def remove_pointer(self) -> 'Type':
if self.is_pointer():
return Type(self.name[0:-1])
return self
def remove_reference(self):
def remove_reference(self) -> 'Type':
if self.is_reference():
return Type(self.name[0:-1])
return self
def remove_const(self):
def remove_const(self) -> 'Type':
if self.is_const():
return Type(self.name[6:])
return self
def basic(self):
def basic(self) -> 'Type':
return self.remove_pointer().remove_const().remove_reference()
def decay(self):
def decay(self) -> 'Type':
t = self.remove_reference()
if t.is_pointer():
return t
else:
return t.remove_const()
def const_compatible(self, t):
def const_compatible(self, t: 'Type'):
if t.is_const():
return self.add_const()
return self
def str(self):
def str(self) -> str:
return self.name
......@@ -98,54 +102,83 @@ header_function = Template('''
${error_type} ${name}(${params});
''')
function_pointer_typedef = Template('''
typedef ${error_type} (*${fname})(${params});
''')
c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params})
{
return ${try_wrap}([&] {
${va_start}auto api_error_result = ${try_wrap}([&] {
${body};
});
${va_end}return api_error_result;
}
''')
class CFunction:
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
self.params = []
self.body = []
self.params: List[str] = []
self.body: List[str] = []
self.va_start: List[str] = []
self.va_end: List[str] = []
def add_param(self, type, pname):
def add_param(self, type: str, pname: str) -> None:
self.params.append('{} {}'.format(type, pname))
def add_statement(self, stmt):
def add_statement(self, stmt: str) -> None:
self.body.append(stmt)
def substitute(self, form):
def add_vlist(self, name: str) -> None:
last_param = self.params[-1].split()[-1]
self.va_start = [
'va_list {};'.format(name),
'va_start({}, {});'.format(name, last_param)
]
self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '')
def substitute(self, form: Template, **kwargs) -> str:
return form.substitute(error_type=error_type,
try_wrap=try_wrap,
name=self.name,
params=', '.join(self.params),
body=";\n ".join(self.body))
body=";\n ".join(self.body),
va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end),
**kwargs)
def generate_header(self):
def generate_header(self) -> str:
return self.substitute(header_function)
def generate_body(self):
def generate_function_pointer(self, name: Optional[str] = None) -> str:
return self.substitute(function_pointer_typedef,
fname=name or self.name)
def generate_body(self) -> str:
return self.substitute(c_api_impl)
class BadParam:
def __init__(self, cond, msg):
def __init__(self, cond: str, msg: str) -> None:
self.cond = cond
self.msg = msg
class Parameter:
def __init__(self, name, type, optional=False, returns=False):
def __init__(self,
name: str,
type: str,
optional: bool = False,
returns: bool = False,
virtual: bool = False,
this: bool = False) -> None:
self.name = name
self.type = Type(type)
self.optional = optional
self.cparams = []
self.cparams: List[Tuple[str, str]] = []
self.size_cparam = -1
self.size_name = ''
self.read = '${name}'
......@@ -153,15 +186,19 @@ class Parameter:
self.cpp_read = '${name}'
self.cpp_write = '${name}'
self.returns = returns
self.bad_param_check = None
self.virtual = virtual
self.this = this
self.bad_param_check: Optional[BadParam] = None
self.virtual_read: Optional[List[str]] = None
self.virtual_write: Optional[str] = None
def get_name(self, prefix=None):
def get_name(self, prefix: Optional[str] = None) -> str:
if prefix:
return prefix + self.name
else:
return self.name
def get_cpp_type(self):
def get_cpp_type(self) -> str:
if self.type.str() in cpp_type_map:
return cpp_type_map[self.type.basic().str()]
elif self.type.basic().str() in cpp_type_map:
......@@ -171,7 +208,10 @@ class Parameter:
else:
return self.type.str()
def substitute(self, s, prefix=None, result=None):
def substitute(self,
s: str,
prefix: Optional[str] = None,
result: Optional[str] = None) -> str:
ctype = None
if len(self.cparams) > 0:
ctype = Type(self.cparams[0][0]).basic().str()
......@@ -182,12 +222,13 @@ class Parameter:
size=self.size_name,
result=result or '')
def add_param(self, t, name=None):
def add_param(self, t: Union[str, Type],
name: Optional[str] = None) -> None:
if not isinstance(t, str):
t = t.str()
self.cparams.append((t, name or self.name))
def add_size_param(self, name=None):
def add_size_param(self, name: Optional[str] = None) -> None:
self.size_cparam = len(self.cparams)
self.size_name = name or self.name + '_size'
if self.returns:
......@@ -195,7 +236,7 @@ class Parameter:
else:
self.add_param('size_t', self.size_name)
def bad_param(self, cond, msg):
def bad_param(self, cond: str, msg: str) -> None:
self.bad_param_check = BadParam(cond, msg)
def remove_size_param(self, name):
......@@ -206,7 +247,7 @@ class Parameter:
self.size_name = name
return p
def update(self):
def update(self) -> None:
t = self.type.basic().str()
g = self.type.remove_generic().basic().str()
if t in type_map:
......@@ -222,18 +263,60 @@ class Parameter:
raise ValueError("Error for {}: write cannot be a string".format(
self.type.str()))
def cpp_param(self, prefix=None):
def virtual_arg(self, prefix: Optional[str] = None) -> List[str]:
read = self.virtual_read
if not read and len(self.write) >= len(self.cparams):
read = [
Template(w.partition('=')[2]).safe_substitute(result='${name}')
for w in self.write
]
if not read:
raise ValueError("No virtual_read parameter provided for: " +
self.type.str())
if isinstance(read, str):
raise ValueError(
"Error for {}: virtual_read cannot be a string".format(
self.type.str()))
return [self.substitute(r, prefix=prefix) for r in read]
def virtual_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix or '', n=n)
for t, n in self.cparams
]
def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]:
return [
'std::remove_pointer_t<{type}> {prefix}{n};'.format(
type=Type(t).str(), prefix=prefix or '', n=n)
for t, n in self.cparams
]
def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write
if not write:
if '*' in self.read or '->' in self.read:
write = Template(self.read).safe_substitute(name='(&${name})')
else:
write = self.read
return self.substitute(write, prefix=prefix)
def cpp_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${cpptype} ${name}', prefix=prefix)
def cpp_arg(self, prefix=None):
def cpp_arg(self, prefix: Optional[str] = None) -> str:
return self.substitute(self.cpp_read, prefix=prefix)
def cpp_output_args(self, prefix=None):
def cpp_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix, n=n) for t, n in self.cparams
]
def output_declarations(self, prefix=None):
def output_declarations(self, prefix: Optional[str] = None) -> List[str]:
return [
'{type} {prefix}{n};'.format(type=Type(t).remove_pointer().str(),
prefix=prefix,
......@@ -245,18 +328,21 @@ class Parameter:
'&{prefix}{n};'.format(prefix=prefix, n=n) for t, n in self.cparams
]
def cpp_output(self, prefix=None):
def cpp_output(self, prefix: Optional[str] = None) -> str:
return self.substitute(self.cpp_write, prefix=prefix)
def input(self, prefix=None):
def input(self, prefix: Optional[str] = None) -> str:
return '(' + self.substitute(self.read, prefix=prefix) + ')'
def outputs(self, result=None):
def outputs(self, result: Optional[str] = None) -> List[str]:
return [self.substitute(w, result=result) for w in self.write]
def add_to_cfunction(self, cfunction):
def add_to_cfunction(self, cfunction: CFunction) -> None:
for t, name in self.cparams:
cfunction.add_param(self.substitute(t), self.substitute(name))
if t.startswith('...'):
cfunction.add_vlist(name)
else:
cfunction.add_param(self.substitute(t), self.substitute(name))
if self.bad_param_check:
msg = 'Bad parameter {name}: {msg}'.format(
name=self.name, msg=self.bad_param_check.msg)
......@@ -265,35 +351,40 @@ class Parameter:
body=bad_param_error(msg)))
def template_var(s):
def template_var(s: str) -> str:
return '${' + s + '}'
def to_template_vars(params):
def to_template_vars(params: List[Union[Any, Parameter]]) -> str:
return ', '.join([template_var(p.name) for p in params])
class Function:
def __init__(self,
name,
params=None,
shared_size=False,
returns=None,
invoke=None,
fname=None,
return_name=None,
**kwargs):
name: str,
params: Optional[List[Parameter]] = None,
shared_size: bool = False,
returns: Optional[str] = None,
invoke: Optional[str] = None,
fname: Optional[str] = None,
return_name: Optional[str] = None,
virtual: bool = False,
**kwargs) -> None:
self.name = name
self.params = params or []
self.shared_size = False
self.cfunction = None
self.cfunction: Optional[CFunction] = None
self.fname = fname
self.invoke = invoke or '${__fname__}($@)'
self.return_name = return_name or 'out'
self.returns = Parameter(self.return_name, returns,
returns=True) if returns else None
for p in self.params:
p.virtual = virtual
if self.returns:
self.returns.virtual = virtual
def share_params(self):
def share_params(self) -> None:
if self.shared_size == True:
size_param_name = 'size'
size_type = Type('size_t')
......@@ -303,7 +394,7 @@ class Function:
size_type = Type(p[0])
self.params.append(Parameter(size_param_name, size_type.str()))
def update(self):
def update(self) -> None:
self.share_params()
for param in self.params:
param.update()
......@@ -311,11 +402,12 @@ class Function:
self.returns.update()
self.create_cfunction()
def inputs(self):
def inputs(self) -> str:
return ', '.join([p.input() for p in self.params])
def input_map(self):
m = {}
# TODO: Shoule we remove Optional?
def input_map(self) -> Dict[str, Optional[str]]:
m: Dict[str, Optional[str]] = {}
for p in self.params:
m[p.name] = p.input()
m['return'] = self.return_name
......@@ -323,14 +415,22 @@ class Function:
m['__fname__'] = self.fname
return m
def get_invoke(self):
def get_invoke(self) -> str:
return Template(self.invoke).safe_substitute(self.input_map())
def write_to_tmp_var(self):
def write_to_tmp_var(self) -> bool:
if not self.returns:
return False
return len(self.returns.write) > 1 or self.returns.write[0].count(
'${result}') > 1
def create_cfunction(self):
def get_cfunction(self) -> CFunction:
if self.cfunction:
return self.cfunction
raise Exception(
"self.cfunction is None: self.update() needs to be called.")
def create_cfunction(self) -> None:
self.cfunction = CFunction(self.name)
# Add the return as a parameter
if self.returns:
......@@ -338,12 +438,12 @@ class Function:
# Add the input parameters
for param in self.params:
param.add_to_cfunction(self.cfunction)
f = self.get_invoke()
f: Optional[str] = self.get_invoke()
# Write the assignments
assigns = []
if self.returns:
result = f
if self.write_to_tmp_var():
if self.write_to_tmp_var() and f:
f = 'auto&& api_result = ' + f
result = 'api_result'
else:
......@@ -396,31 +496,37 @@ cpp_class_constructor_template = Template('''
class CPPMember:
def __init__(self, name, function, prefix, method=True):
def __init__(self,
name: str,
function: Function,
prefix: str,
method: bool = True) -> None:
self.name = name
self.function = function
self.prefix = prefix
self.method = method
def get_function_params(self):
def get_function_params(self) -> List[Union[Any, Parameter]]:
if self.method:
return self.function.params[1:]
else:
return self.function.params
def get_args(self):
def get_args(self) -> str:
output_args = []
if self.function.returns:
output_args = self.function.returns.cpp_output_args(self.prefix)
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
return ', '.join(
['&{}'.format(self.function.cfunction.name)] + output_args +
[p.cpp_arg(self.prefix) for p in self.get_function_params()])
def get_params(self):
def get_params(self) -> str:
return ', '.join(
[p.cpp_param(self.prefix) for p in self.get_function_params()])
def get_return_declarations(self):
def get_return_declarations(self) -> str:
if self.function.returns:
return '\n '.join([
d
......@@ -432,7 +538,9 @@ class CPPMember:
def get_result(self):
return self.function.returns.input(self.prefix)
def generate_method(self):
def generate_method(self) -> str:
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
if self.function.returns:
return_type = self.function.returns.get_cpp_type()
return cpp_class_method_template.safe_substitute(
......@@ -452,7 +560,9 @@ class CPPMember:
args=self.get_args(),
success=success_type)
def generate_constructor(self, name):
def generate_constructor(self, name: str) -> str:
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
return cpp_class_constructor_template.safe_substitute(
name=name,
cfunction=self.function.cfunction.name,
......@@ -462,98 +572,104 @@ class CPPMember:
class CPPClass:
def __init__(self, name, ctype):
def __init__(self, name: str, ctype: str) -> None:
self.name = name
self.ctype = ctype
self.constructors = []
self.methods = []
self.constructors: List[CPPMember] = []
self.methods: List[CPPMember] = []
self.prefix = 'p'
def add_method(self, name, f):
def add_method(self, name: str, f: Function) -> None:
self.methods.append(CPPMember(name, f, self.prefix, method=True))
def add_constructor(self, name, f):
def add_constructor(self, name: str, f: Function) -> None:
self.constructors.append(CPPMember(name, f, self.prefix, method=True))
def generate_methods(self):
def generate_methods(self) -> str:
return '\n '.join([m.generate_method() for m in self.methods])
def generate_constructors(self):
def generate_constructors(self) -> str:
return '\n '.join(
[m.generate_constructor(self.name) for m in self.constructors])
def substitute(self, s, **kwargs):
t = s
if isinstance(s, str):
t = string.Template(s)
def substitute(self, s: Union[string.Template, str], **kwargs) -> str:
t = string.Template(s) if isinstance(s, str) else s
destroy = self.ctype + '_destroy'
return t.safe_substitute(name=self.name,
ctype=self.ctype,
destroy=destroy,
**kwargs)
def generate(self):
def generate(self) -> str:
return self.substitute(
cpp_class_template,
constructors=self.substitute(self.generate_constructors()),
methods=self.substitute(self.generate_methods()))
def params(virtual=None, **kwargs):
def params(virtual: Optional[Dict[str, str]] = None,
**kwargs) -> List[Parameter]:
result = []
for name in virtual or {}:
result.append(Parameter(name, virtual[name]))
v: Dict[str, str] = virtual or {}
for name in v:
result.append(Parameter(name, v[name]))
for name in kwargs:
result.append(Parameter(name, kwargs[name]))
return result
def add_function(name, *args, **kwargs):
gparams = params
def add_function(name: str, *args, **kwargs) -> Function:
f = Function(name, *args, **kwargs)
functions.append(f)
return f
def once(f):
def once(f: Callable) -> Any:
@wraps(f)
def decorated(*args, **kwargs):
if not decorated.has_run:
decorated.has_run = True
return f(*args, **kwargs)
decorated.has_run = False
return decorated
d: Any = decorated
d.has_run = False
return d
@once
def process_functions():
def process_functions() -> None:
for f in functions:
f.update()
def generate_lines(p):
def generate_lines(p: List[str]) -> str:
return '\n'.join(p)
def generate_c_header():
def generate_c_header() -> str:
process_functions()
return generate_lines(c_header_preamble +
[f.cfunction.generate_header() for f in functions])
return generate_lines(
c_header_preamble +
[f.get_cfunction().generate_header() for f in functions])
def generate_c_api_body():
def generate_c_api_body() -> str:
process_functions()
return generate_lines(c_api_body_preamble +
[f.cfunction.generate_body() for f in functions])
return generate_lines(
c_api_body_preamble +
[f.get_cfunction().generate_body() for f in functions])
def generate_cpp_header():
def generate_cpp_header() -> str:
process_functions()
return generate_lines(cpp_header_preamble +
[c.generate() for c in cpp_classes])
def cwrap(name):
def cwrap(name: str) -> Callable:
def with_cwrap(f):
type_map[name] = f
......@@ -576,7 +692,7 @@ extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(Ts&&... xs)
: object(std::forward<Ts>(xs)...)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{}
${cpptype} object;
};
......@@ -605,6 +721,55 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t)
{
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
'''
cpp_handle_preamble = '''
......@@ -657,119 +822,153 @@ protected:
@once
def add_handle_preamble():
def add_handle_preamble() -> None:
c_api_body_preamble.append(handle_preamble)
cpp_header_preamble.append(
string.Template(cpp_handle_preamble).substitute(success=success_type))
def add_handle(name, ctype, cpptype, destroy=None, ref=None):
def add_handle(name: str,
ctype: str,
cpptype: str,
destroy: Optional[str] = None,
ref=False,
skip_def=False) -> None:
opaque_type = ctype + '_t'
const_opaque_type = 'const_' + opaque_type
def handle_wrap(p):
def handle_wrap(p: Parameter):
t = Type(opaque_type)
if p.type.is_const():
t = Type('const_' + opaque_type)
if p.returns:
# p.read = 'object_cast<${ctype}>(&(${name}))'
if p.virtual:
p.add_param(t)
elif p.returns:
p.add_param(t.add_pointer())
if p.type.is_reference():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
if p.type.is_reference():
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.virtual_read = ['object_cast<${ctype}>(${result})']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
if skip_def:
p.read = '*${name}'
else:
p.read = '${name}->object'
p.cpp_read = '${name}.get_handle_ptr()'
p.cpp_read = '${name}.get_handle_ptr()'
type_map[cpptype] = handle_wrap
if not ref:
add_function(destroy or ctype + '_' + 'destroy',
params({name: opaque_type}),
fname='destroy')
add_function(ctype + '_' + 'assign_to',
params(output=opaque_type, input=const_opaque_type),
invoke='*output = *input')
add_handle_preamble()
c_header_preamble.append(handle_typedef.substitute(locals()))
c_api_body_preamble.append(handle_definition.substitute(locals()))
if not skip_def:
c_api_body_preamble.append(handle_definition.substitute(locals()))
@cwrap('std::vector')
def vector_c_wrap(p):
t = p.type.inner_type().add_pointer()
def vector_c_wrap(p: Parameter) -> None:
inner = p.type.inner_type()
# Not a generic type
if not inner:
return
t = inner.add_pointer()
if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
if p.returns:
if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
p.add_param(t.add_pointer())
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'std::copy(${result}.begin(), ${result}.end(), ${name})'
]
else:
p.add_param(t)
p.add_size_param()
p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer')
p.read = '${type}(${name}, ${name}+${size})'
p.read = '${type}(${name}, ${name}+${size})'
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.virtual_read = ['${name}.data()', '${name}.size()']
if p.type.is_reference():
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string')
def string_c_wrap(p):
def string_c_wrap(p: Parameter) -> None:
t = Type('char*')
if p.returns:
if p.type.is_reference():
p.add_param(t.add_pointer())
p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = ['*${name} = ${result}.c_str()']
else:
p.add_param(t)
p.add_param('size_t', p.name + '_size')
p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
p.read = '${type}(${name})'
p.read = '${type}(${name})'
p.cpp_write = '${type}(${name})'
p.virtual_read = ['${name}.c_str()']
if p.type.is_reference():
p.write = ['*${name} = ${result}.c_str()']
else:
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
class Handle:
def __init__(self, name, ctype, cpptype, ref=None):
def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None:
self.name = name
self.ctype = ctype
self.cpptype = cpptype
self.opaque_type = self.ctype + '_t'
self.cpp_class = CPPClass(name, ctype)
add_handle(name, ctype, cpptype, ref=ref)
add_handle(name, ctype, cpptype, **kwargs)
cpp_type_map[cpptype] = name
def cname(self, name):
def cname(self, name: str) -> str:
return self.ctype + '_' + name
def substitute(self, s, **kwargs):
def substitute(self, s: str, **kwargs) -> str:
return Template(s).safe_substitute(name=self.name,
ctype=self.ctype,
cpptype=self.cpptype,
opaque_type=self.opaque_type,
**kwargs)
def constructor(self, name, params=None, fname=None, invoke=None,
**kwargs):
def constructor(self,
name: str,
params: Optional[List[Parameter]] = None,
fname: Optional[str] = None,
invoke: Optional[str] = None,
**kwargs) -> 'Handle':
create = self.substitute('allocate<${cpptype}>($@)')
if fname:
create = self.substitute('allocate<${cpptype}>(${fname}($@))',
......@@ -785,13 +984,13 @@ class Handle:
return self
def method(self,
name,
params=None,
fname=None,
invoke=None,
cpp_name=None,
const=None,
**kwargs):
name: str,
params: Optional[List[Parameter]] = None,
fname: Optional[str] = None,
invoke: Optional[str] = None,
cpp_name: Optional[str] = None,
const: Optional[bool] = None,
**kwargs) -> 'Handle':
cpptype = self.cpptype
if const:
cpptype = Type(cpptype).add_const().str()
......@@ -812,11 +1011,145 @@ class Handle:
add_function(self.cname(name), params=params, **kwargs)
return self
def add_cpp_class(self):
def add_cpp_class(self) -> None:
cpp_classes.append(self.cpp_class)
def handle(ctype, cpptype, name=None, ref=None):
interface_handle_definition = Template('''
extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
${functions}
};
''')
c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
throw std::runtime_error("Error in ${name}.");
return ${output};
}
''')
def generate_virtual_impl(f: Function, fname: str) -> str:
success = success_type
name = f.name
return_type = 'void'
output_decls = ''
output = ''
largs = []
lparams = []
if f.returns:
return_type = f.returns.type.str()
output_decls = '\n'.join(f.returns.virtual_output_declarations())
largs += f.returns.virtual_output_args()
output = f.returns.virtual_output()
largs += [arg for p in f.params for arg in p.virtual_arg()]
lparams += [p.virtual_param() for p in f.params if not p.this]
args = ', '.join(largs)
params = ', '.join(lparams)
return c_api_virtual_impl.substitute(locals())
class Interface(Handle):
def __init__(self, name: str, ctype: str, cpptype: str) -> None:
super().__init__(name, ctype, cpptype, skip_def=True)
self.ifunctions: List[Function] = []
self.members: List[str] = []
def mname(self, name: str) -> str:
return name + "_f"
def constructor( # type: ignore
self,
name: str,
params: Optional[List[Parameter]] = None,
**kwargs) -> 'Interface':
create = self.substitute('allocate<${opaque_type}>($@)')
initial_params = gparams(obj='void*',
c=self.cname('copy'),
d=self.cname('delete'))
add_function(self.cname(name),
params=initial_params + (params or []),
invoke=create,
returns=self.opaque_type,
return_name=self.name,
**kwargs)
return self
def method(self, *args, **kwargs) -> 'Interface':
super().method(*args, **kwargs)
return self
def virtual(self,
name: str,
params: Optional[List[Parameter]] = None,
const: Optional[bool] = None,
**kwargs) -> 'Interface':
# Add this parameter to the function
this = Parameter('obj', 'void*', this=True)
this.virtual_read = ['object_ptr.data']
f = Function(name,
params=[this] + (params or []),
virtual=True,
**kwargs)
self.ifunctions.append(f)
add_function(self.cname('set_' + name),
params=gparams(obj=self.opaque_type,
input=self.cname(name)),
invoke='${{obj}}->{name} = ${{input}}'.format(
name=self.mname(name)))
return self
def generate_function(self, f: Function):
cname = self.cname(f.name)
mname = self.mname(f.name)
function = generate_virtual_impl(f, fname=mname)
return f"{cname} {mname} = nullptr;{function}"
def generate(self):
required_functions = [
Function('copy',
params=gparams(out='void**', input='void*'),
virtual=True),
Function('delete', params=gparams(input='void*'), virtual=True)
]
for f in self.ifunctions + required_functions:
f.update()
c_header_preamble.extend([
f.get_cfunction().generate_function_pointer(self.cname(f.name))
for f in self.ifunctions + required_functions
])
function_list = [self.generate_function(f) for f in self.ifunctions]
ctype = self.ctype
cpptype = self.cpptype
copier = self.cname('copy')
deleter = self.cname('delete')
functions = '\n'.join(function_list)
c_api_body_preamble.append(
interface_handle_definition.substitute(locals()))
def handle(ctype: str,
cpptype: str,
name: Optional[str] = None,
ref: Optional[bool] = None) -> Callable:
def with_handle(f):
n = name or f.__name__
h = Handle(n, ctype, cpptype, ref=ref)
......@@ -832,6 +1165,23 @@ def handle(ctype, cpptype, name=None, ref=None):
return with_handle
def interface(ctype: str, cpptype: str,
name: Optional[str] = None) -> Callable:
def with_interface(f):
n = name or f.__name__
h = Interface(n, ctype, cpptype)
f(h)
h.generate()
@wraps(f)
def decorated(*args, **kwargs):
return f(*args, **kwargs)
return decorated
return with_interface
def template_eval(template, **kwargs):
start = '<%'
end = '%>'
......@@ -845,18 +1195,18 @@ def template_eval(template, **kwargs):
return template
def run():
runpy.run_path(sys.argv[1])
if len(sys.argv) > 2:
f = open(sys.argv[2]).read()
def run(args: List[str]) -> None:
runpy.run_path(args[0])
if len(args) > 1:
f = open(args[1]).read()
r = template_eval(f)
sys.stdout.write(r)
else:
sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body())
sys.stdout.write(generate_cpp_header())
# sys.stdout.write(generate_cpp_header())
if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
run()
run(sys.argv[1:])
......@@ -4,15 +4,18 @@
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
......@@ -71,28 +74,41 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
target get_target(const std::string& name) { return make_target(name); }
migraphx::compile_options to_compile_options(const migraphx_compile_options& options)
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
migraphx::compile_options result{};
result.offload_copy = options.offload_copy;
result.fast_math = options.fast_math;
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
migraphx::file_options to_file_options(const migraphx_file_options& options)
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
migraphx::file_options result{};
result.format = options.format;
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
void set_file_format(file_options& options, const char* format) { options.format = format; }
void set_default_dim_value(onnx_options& options, size_t value)
{
options.default_dim_value = value;
}
void set_default_loop_iterations(onnx_options& options, int64_t value)
{
options.max_loop_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......@@ -159,18 +175,30 @@ void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& o
migraphx::quantize_int8(prog, t, options.calibration, options.op_names);
}
operation create_op(const char* name, const char* attributes)
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
operation create_op(const char* name, const char* attributes, va_list vlist)
{
std::string sattributes = attributes == nullptr ? "" : attributes;
std::vector<char> buffer(sattributes.size() * 2);
std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist);
value v = value::object{};
if(attributes != nullptr)
{
v = from_json_string(convert_to_json(std::string(attributes)));
v = from_json_string(convert_to_json(std::string(buffer.data())));
}
auto op = make_op(name, v);
return op;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
template <class T>
bool equal(const T& x, const T& y)
{
......@@ -185,6 +213,41 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx
<% generate_c_api_body() %>
......@@ -25,7 +25,8 @@ extern "C" {
#endif
// return code, more to be added later
typedef enum {
typedef enum
{
migraphx_status_success = 0,
migraphx_status_bad_param = 1,
migraphx_status_unknown_target = 3,
......@@ -35,32 +36,13 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum {
typedef enum
{
migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
/// Options to be passed when compiling
typedef struct
{
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
bool offload_copy;
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
bool fast_math;
} migraphx_compile_options;
/// Options for saving and loading files
typedef struct
{
/// Format to be used for file. It can either be json or msgpack
const char* format;
} migraphx_file_options;
<% generate_c_header() %>
#ifdef __cplusplus
......
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
SRC_DIR=$DIR/../src
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $SRC_DIR/include/migraphx/{}"
PYTHON=python3
if type -p python3.6 > /dev/null ; then
PYTHON=python3.6
fi
if type -p python3.8 > /dev/null ; then
PYTHON=python3.8
fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-10 -style=file > $SRC_DIR/include/migraphx/{}"
function api {
python3.6 $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-5.0 -style=file > $2
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-10 -style=file > $2
}
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
echo "Finished generating header migraphx.h"
api $DIR/api/api.cpp $SRC_DIR/api/api.cpp
echo "Finished generating source api.cpp "
......@@ -9,6 +9,7 @@
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -33,12 +34,21 @@ value to_value_context(const T&)
}
template <class T>
void from_value_context(T&, const value&){}
void from_value_context(T&, const value&)
{
}
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
<%
interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx)
......
#ifndef MIGRAPHX_GUARD_MARKER_HPP
#define MIGRAPHX_GUARD_MARKER_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
/// Marker is an interface to general marking functions, such as rocTX markers.
#else
<%
interface('marker',
virtual('mark_start', ins_ref = 'instruction_ref', returns = 'void'),
virtual('mark_start', prog = 'const program&', returns = 'void'),
virtual('mark_stop', ins = 'instruction_ref', returns = 'void'),
virtual('mark_stop', prog = 'const program&', returns = 'void')
) %>
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -103,79 +103,69 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
return x.compute_shape(inputs);
}
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
auto compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {}))
{
return x.compute_shape(inputs, mod_args);
return x.compute_shape(inputs, {});
}
template <class T>
shape
compute_shape_op(rank<0>, const T& x, const std::vector<shape>&, const std::vector<module_ref>&)
shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return compute_shape_op(rank<1>{}, x, inputs, mod_args);
return compute_shape_op(rank<3>{}, x, inputs);
}
template <class T>
auto normalize_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
-> decltype(x.normalize_compute_shape(inputs, mod_args))
auto mod_compute_shape_op(rank<1>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
-> decltype(x.compute_shape(inputs, mod_args))
{
return x.normalize_compute_shape(inputs, mod_args);
return x.compute_shape(inputs, mod_args);
}
template <class T>
shape normalize_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>&,
const std::vector<module_ref>&)
shape mod_compute_shape_op(rank<0>,
const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
if(mod_args.empty())
return compute_shape_op(x, inputs);
std::string name = x.name();
MIGRAPHX_THROW("Shape not computable: " + name);
}
template <class T>
shape normalize_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
std::vector<module_ref>& mod_args)
shape mod_compute_shape_op(const T& x,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}
template <class T>
......@@ -256,6 +246,18 @@ argument compute_op(const T& x,
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<4>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<3>,
const T& x,
......@@ -313,7 +315,7 @@ argument compute_op(const T& x,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
......@@ -476,13 +478,13 @@ lifetime get_lifetime_op(const T&)
returns = 'shape',
input = 'const std::vector<shape>&',
const = True,
default = 'detail::normalize_compute_shape_op'),
default = 'detail::compute_shape_op'),
virtual('compute_shape',
returns = 'shape',
inputs = 'const std::vector<shape>&',
mod_args = 'const std::vector<module_ref>&',
const = True,
default = 'detail::compute_shape_op'),
default = 'detail::mod_compute_shape_op'),
virtual('compute',
returns = 'argument',
ctx = 'context&',
......@@ -570,7 +572,7 @@ template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
-> decltype(op.normalize_compute_shape(inputs))
{
return detail::normalize_compute_shape_op(op, inputs);
return detail::compute_shape_op(op, inputs);
}
inline shape compute_shape(const operation& op,
......@@ -595,7 +597,7 @@ inline auto compute_shape(const T& op,
const std::vector<module_ref>& mod_args)
-> decltype(op.normalize_compute_shape(inputs, mod_args))
{
return detail::normalize_compute_shape_op(op, inputs, mod_args);
return detail::compute_shape_op(op, inputs, mod_args);
}
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
......
......@@ -8,12 +8,14 @@
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct module_pass_manager;
#ifdef DOXYGEN
......@@ -24,6 +26,7 @@ struct pass
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the module
void apply(module_pass_manager& mpm) const;
void apply(module& m) const;
/// Run the pass on the program
void apply(program& p) const;
......@@ -31,10 +34,34 @@ struct pass
#else
module& get_module(module_pass_manager& mpm);
namespace detail {
template <class T>
auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm)
-> decltype(x.apply(get_module(mpm)))
{
return x.apply(get_module(mpm));
}
template <class T>
void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&)
{
}
template <class T>
void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
{
module_pass_manager_apply(rank<1>{}, x, mpm);
}
} // namespace detail
<%
interface('pass',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', m='module &', const=True, default='migraphx::nop'),
virtual('apply', returns='void', mpm='module_pass_manager &', const=True, default='migraphx::detail::module_pass_manager_apply'),
virtual('apply', returns='void', p='program &', const=True, default='migraphx::nop')
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment