Commit 08ac24cf authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rocblas_api_opt

parents 96c82f21 b20e3d4d
......@@ -142,10 +142,12 @@ jobs:
with:
python-version: 3.6
- name: Install pyflakes
run: pip install pyflakes==2.3.1
run: pip install pyflakes==2.3.1 mypy==0.931
- name: Run pyflakes
run: pyflakes examples/ tools/ src/ test/ doc/
run: |
pyflakes examples/ tools/ src/ test/ doc/
mypy tools/api.py
linux:
......
......@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
return x.index - y.index;
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x -= y;
}
template <class F, class Iterator>
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
......
......@@ -163,9 +163,9 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}}; }
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
......@@ -183,7 +183,7 @@ struct parse_resize : op_parser<parse_resize>
if(contains(info.attributes, "exclude_outside") and
info.attributes.at("exclude_outside").i() == 1)
{
MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!");
}
// input data shape info
......@@ -215,12 +215,14 @@ struct parse_resize : op_parser<parse_resize>
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
}
// compute the scale
......@@ -239,12 +241,14 @@ struct parse_resize : op_parser<parse_resize>
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_RESIZE: dynamic input scale is not supported!");
"PARSE_" + opd.op_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
}
std::transform(in_lens.begin(),
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_upsample : op_parser<parse_upsample>
{
std::vector<op_desc> operators() const { return {{"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode != "nearest")
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
}
}
auto arg_scale = args[1]->eval();
check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
std::vector<float> vec_scale;
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
}
std::vector<std::size_t> out_lens(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
std::vector<float> idx_scale(in_lens.size());
std::transform(
out_lens.begin(),
out_lens.end(),
in_lens.begin(),
idx_scale.begin(),
[](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });
shape out_s{in_s.type(), out_lens};
std::vector<int> ind(out_s.elements());
// map out_idx to in_idx
shape_for_each(out_s, [&](auto idx) {
auto in_idx = idx;
std::transform(idx.begin(),
idx.end(),
idx_scale.begin(),
in_idx.begin(),
// nearest mode
[](auto index, auto scale) {
return static_cast<std::size_t>(std::round(index * scale));
});
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
});
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = info.add_literal(literal(ind_s, ind));
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -180,6 +180,63 @@ void program::finalize()
mm->finalize(this->impl->ctx);
}
template <class T>
std::string classify(T x)
{
switch(std::fpclassify(x))
{
case FP_INFINITE: return "inf";
case FP_NAN: return "nan";
case FP_NORMAL: return "normal";
case FP_SUBNORMAL: return "subnormal";
case FP_ZERO: return "zero";
default: return "unknown";
}
}
std::unordered_set<std::string> classify_argument(const argument& a)
{
std::unordered_set<std::string> result;
a.visit(
[&](auto t) {
for(const auto& x : t)
result.insert(classify(x));
},
[&](const auto& xs) {
for(const auto& x : xs)
{
auto r = classify_argument(x);
result.insert(r.begin(), r.end());
}
});
return result;
}
void preview_argument(std::ostream& os, const argument& a)
{
a.visit(
[&](auto t) {
if(t.size() <= 10)
{
os << t;
}
else
{
os << to_string_range(t.begin(), t.begin() + 5);
os << ", ..., ";
os << to_string_range(t.end() - 5, t.end());
}
},
[&](const auto& xs) {
for(const auto& x : xs)
{
os << '{';
preview_argument(os, x);
os << '}';
}
});
}
template <class F>
std::vector<argument> generic_eval(const module* mod,
context& ctx,
......@@ -312,8 +369,21 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
{
target tgt = make_target(this->impl->target_name);
std::cout << "Output: " << tgt.copy_from(result) << std::endl;
target tgt = make_target(this->impl->target_name);
auto buffer = tgt.copy_from(result);
if(trace_level == 2)
{
std::cout << "Output has "
<< to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
}
return result;
}));
......
......@@ -5074,6 +5074,25 @@ def unknown_aten_test():
return ([node], [x, y], [a])
@onnx_test
def upsample_linear_test():
scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)
scales_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(
np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Upsample',
inputs=['X', '', 'scales'],
outputs=['Y'],
mode='linear')
return ([node], [X], [Y], [scales_tensor])
@onnx_test
def upsample_test():
scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
......
......@@ -3643,7 +3643,7 @@ TEST_CASE(resize_nonstd_input_test)
EXPECT(p == prog);
}
TEST_CASE(resize_upsample_linear_ac_test)
static auto create_upsample_linear_prog()
{
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -3734,6 +3734,12 @@ TEST_CASE(resize_upsample_linear_ac_test)
auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10);
mm->add_return({add1});
return p;
}
TEST_CASE(resize_upsample_linear_ac_test)
{
auto p = create_upsample_linear_prog();
auto prog = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx");
EXPECT(p == prog);
}
......@@ -4753,6 +4759,13 @@ TEST_CASE(unknown_test_throw)
EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); }));
}
TEST_CASE(upsample_linear_test)
{
auto p = create_upsample_linear_prog();
auto prog = migraphx::parse_onnx("upsample_linear_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(upsample_test)
{
migraphx::program p;
......
import string, sys, re, runpy
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
type_map = {}
cpp_type_map = {}
functions = []
cpp_classes = []
type_map: Dict[str, Callable[['Parameter'], None]] = {}
cpp_type_map: Dict[str, str] = {}
functions: List['Function'] = []
cpp_classes: List['CPPClass'] = []
error_type = ''
success_type = ''
try_wrap = ''
c_header_preamble = []
c_api_body_preamble = []
cpp_header_preamble = []
c_header_preamble: List[str] = []
c_api_body_preamble: List[str] = []
cpp_header_preamble: List[str] = []
def bad_param_error(msg):
......@@ -23,31 +24,31 @@ class Template(string.Template):
class Type:
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name.strip()
def is_pointer(self):
def is_pointer(self) -> bool:
return self.name.endswith('*')
def is_reference(self):
def is_reference(self) -> bool:
return self.name.endswith('&')
def is_const(self):
def is_const(self) -> bool:
return self.name.startswith('const ')
def is_variadic(self):
return self.name.startswith('...')
def add_pointer(self):
def add_pointer(self) -> 'Type':
return Type(self.name + '*')
def add_reference(self):
return Type(self.name + '&')
def add_const(self):
def add_const(self) -> 'Type':
return Type('const ' + self.name)
def inner_type(self):
def inner_type(self) -> Optional['Type']:
i = self.name.find('<')
j = self.name.rfind('>')
if i > 0 and j > 0:
......@@ -55,7 +56,7 @@ class Type:
else:
return None
def remove_generic(self):
def remove_generic(self) -> 'Type':
i = self.name.find('<')
j = self.name.rfind('>')
if i > 0 and j > 0:
......@@ -63,25 +64,25 @@ class Type:
else:
return self
def remove_pointer(self):
def remove_pointer(self) -> 'Type':
if self.is_pointer():
return Type(self.name[0:-1])
return self
def remove_reference(self):
def remove_reference(self) -> 'Type':
if self.is_reference():
return Type(self.name[0:-1])
return self
def remove_const(self):
def remove_const(self) -> 'Type':
if self.is_const():
return Type(self.name[6:])
return self
def basic(self):
def basic(self) -> 'Type':
return self.remove_pointer().remove_const().remove_reference()
def decay(self):
def decay(self) -> 'Type':
t = self.remove_reference()
if t.is_pointer():
return t
......@@ -93,7 +94,7 @@ class Type:
return self.add_const()
return self
def str(self):
def str(self) -> str:
return self.name
......@@ -113,20 +114,20 @@ extern "C" ${error_type} ${name}(${params})
class CFunction:
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
self.params = []
self.body = []
self.va_start = []
self.va_end = []
self.params: List[str] = []
self.body: List[str] = []
self.va_start: List[str] = []
self.va_end: List[str] = []
def add_param(self, type, pname):
def add_param(self, type: str, pname: str) -> None:
self.params.append('{} {}'.format(type, pname))
def add_statement(self, stmt):
def add_statement(self, stmt: str) -> None:
self.body.append(stmt)
def add_vlist(self, name):
def add_vlist(self, name: str) -> None:
last_param = self.params[-1].split()[-1]
self.va_start = [
'va_list {};'.format(name),
......@@ -135,7 +136,7 @@ class CFunction:
self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '')
def substitute(self, form):
def substitute(self, form: Template) -> str:
return form.substitute(error_type=error_type,
try_wrap=try_wrap,
name=self.name,
......@@ -144,25 +145,29 @@ class CFunction:
va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end))
def generate_header(self):
def generate_header(self) -> str:
return self.substitute(header_function)
def generate_body(self):
def generate_body(self) -> str:
return self.substitute(c_api_impl)
class BadParam:
def __init__(self, cond, msg):
def __init__(self, cond: str, msg: str) -> None:
self.cond = cond
self.msg = msg
class Parameter:
def __init__(self, name, type, optional=False, returns=False):
def __init__(self,
name: str,
type: str,
optional: bool = False,
returns: bool = False) -> None:
self.name = name
self.type = Type(type)
self.optional = optional
self.cparams = []
self.cparams: List[Tuple[str, str]] = []
self.size_cparam = -1
self.size_name = ''
self.read = '${name}'
......@@ -170,15 +175,15 @@ class Parameter:
self.cpp_read = '${name}'
self.cpp_write = '${name}'
self.returns = returns
self.bad_param_check = None
self.bad_param_check: Optional[BadParam] = None
def get_name(self, prefix=None):
def get_name(self, prefix: Optional[str] = None) -> str:
if prefix:
return prefix + self.name
else:
return self.name
def get_cpp_type(self):
def get_cpp_type(self) -> str:
if self.type.str() in cpp_type_map:
return cpp_type_map[self.type.basic().str()]
elif self.type.basic().str() in cpp_type_map:
......@@ -188,7 +193,10 @@ class Parameter:
else:
return self.type.str()
def substitute(self, s, prefix=None, result=None):
def substitute(self,
s: str,
prefix: Optional[str] = None,
result: Optional[str] = None) -> str:
ctype = None
if len(self.cparams) > 0:
ctype = Type(self.cparams[0][0]).basic().str()
......@@ -199,12 +207,13 @@ class Parameter:
size=self.size_name,
result=result or '')
def add_param(self, t, name=None):
def add_param(self, t: Union[str, Type],
name: Optional[str] = None) -> None:
if not isinstance(t, str):
t = t.str()
self.cparams.append((t, name or self.name))
def add_size_param(self, name=None):
def add_size_param(self, name: Optional[str] = None) -> None:
self.size_cparam = len(self.cparams)
self.size_name = name or self.name + '_size'
if self.returns:
......@@ -212,7 +221,7 @@ class Parameter:
else:
self.add_param('size_t', self.size_name)
def bad_param(self, cond, msg):
def bad_param(self, cond: str, msg: str) -> None:
self.bad_param_check = BadParam(cond, msg)
def remove_size_param(self, name):
......@@ -223,7 +232,7 @@ class Parameter:
self.size_name = name
return p
def update(self):
def update(self) -> None:
t = self.type.basic().str()
g = self.type.remove_generic().basic().str()
if t in type_map:
......@@ -239,18 +248,18 @@ class Parameter:
raise ValueError("Error for {}: write cannot be a string".format(
self.type.str()))
def cpp_param(self, prefix=None):
def cpp_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${cpptype} ${name}', prefix=prefix)
def cpp_arg(self, prefix=None):
def cpp_arg(self, prefix: Optional[str] = None) -> str:
return self.substitute(self.cpp_read, prefix=prefix)
def cpp_output_args(self, prefix=None):
def cpp_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix, n=n) for t, n in self.cparams
]
def output_declarations(self, prefix=None):
def output_declarations(self, prefix: Optional[str] = None) -> List[str]:
return [
'{type} {prefix}{n};'.format(type=Type(t).remove_pointer().str(),
prefix=prefix,
......@@ -262,16 +271,16 @@ class Parameter:
'&{prefix}{n};'.format(prefix=prefix, n=n) for t, n in self.cparams
]
def cpp_output(self, prefix=None):
def cpp_output(self, prefix: Optional[str] = None) -> str:
return self.substitute(self.cpp_write, prefix=prefix)
def input(self, prefix=None):
def input(self, prefix: Optional[str] = None) -> str:
return '(' + self.substitute(self.read, prefix=prefix) + ')'
def outputs(self, result=None):
def outputs(self, result: Optional[str] = None) -> List[str]:
return [self.substitute(w, result=result) for w in self.write]
def add_to_cfunction(self, cfunction):
def add_to_cfunction(self, cfunction: CFunction) -> None:
for t, name in self.cparams:
if t.startswith('...'):
cfunction.add_vlist(name)
......@@ -285,35 +294,35 @@ class Parameter:
body=bad_param_error(msg)))
def template_var(s):
def template_var(s: str) -> str:
return '${' + s + '}'
def to_template_vars(params):
def to_template_vars(params: List[Union[Any, Parameter]]) -> str:
return ', '.join([template_var(p.name) for p in params])
class Function:
def __init__(self,
name,
params=None,
shared_size=False,
returns=None,
invoke=None,
fname=None,
return_name=None,
**kwargs):
name: str,
params: Optional[List[Parameter]] = None,
shared_size: bool = False,
returns: Optional[str] = None,
invoke: Optional[str] = None,
fname: Optional[str] = None,
return_name: Optional[str] = None,
**kwargs) -> None:
self.name = name
self.params = params or []
self.shared_size = False
self.cfunction = None
self.cfunction: Optional[CFunction] = None
self.fname = fname
self.invoke = invoke or '${__fname__}($@)'
self.return_name = return_name or 'out'
self.returns = Parameter(self.return_name, returns,
returns=True) if returns else None
def share_params(self):
def share_params(self) -> None:
if self.shared_size == True:
size_param_name = 'size'
size_type = Type('size_t')
......@@ -323,7 +332,7 @@ class Function:
size_type = Type(p[0])
self.params.append(Parameter(size_param_name, size_type.str()))
def update(self):
def update(self) -> None:
self.share_params()
for param in self.params:
param.update()
......@@ -331,11 +340,12 @@ class Function:
self.returns.update()
self.create_cfunction()
def inputs(self):
def inputs(self) -> str:
return ', '.join([p.input() for p in self.params])
def input_map(self):
m = {}
# TODO: Shoule we remove Optional?
def input_map(self) -> Dict[str, Optional[str]]:
m: Dict[str, Optional[str]] = {}
for p in self.params:
m[p.name] = p.input()
m['return'] = self.return_name
......@@ -343,14 +353,22 @@ class Function:
m['__fname__'] = self.fname
return m
def get_invoke(self):
def get_invoke(self) -> str:
return Template(self.invoke).safe_substitute(self.input_map())
def write_to_tmp_var(self):
def write_to_tmp_var(self) -> bool:
if not self.returns:
return False
return len(self.returns.write) > 1 or self.returns.write[0].count(
'${result}') > 1
def create_cfunction(self):
def get_cfunction(self) -> CFunction:
if self.cfunction:
return self.cfunction
raise Exception(
"self.cfunction is None: self.update() needs to be called.")
def create_cfunction(self) -> None:
self.cfunction = CFunction(self.name)
# Add the return as a parameter
if self.returns:
......@@ -358,12 +376,12 @@ class Function:
# Add the input parameters
for param in self.params:
param.add_to_cfunction(self.cfunction)
f = self.get_invoke()
f: Optional[str] = self.get_invoke()
# Write the assignments
assigns = []
if self.returns:
result = f
if self.write_to_tmp_var():
if self.write_to_tmp_var() and f:
f = 'auto&& api_result = ' + f
result = 'api_result'
else:
......@@ -416,31 +434,37 @@ cpp_class_constructor_template = Template('''
class CPPMember:
def __init__(self, name, function, prefix, method=True):
def __init__(self,
name: str,
function: Function,
prefix: str,
method: bool = True) -> None:
self.name = name
self.function = function
self.prefix = prefix
self.method = method
def get_function_params(self):
def get_function_params(self) -> List[Union[Any, Parameter]]:
if self.method:
return self.function.params[1:]
else:
return self.function.params
def get_args(self):
def get_args(self) -> str:
output_args = []
if self.function.returns:
output_args = self.function.returns.cpp_output_args(self.prefix)
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
return ', '.join(
['&{}'.format(self.function.cfunction.name)] + output_args +
[p.cpp_arg(self.prefix) for p in self.get_function_params()])
def get_params(self):
def get_params(self) -> str:
return ', '.join(
[p.cpp_param(self.prefix) for p in self.get_function_params()])
def get_return_declarations(self):
def get_return_declarations(self) -> str:
if self.function.returns:
return '\n '.join([
d
......@@ -452,7 +476,9 @@ class CPPMember:
def get_result(self):
return self.function.returns.input(self.prefix)
def generate_method(self):
def generate_method(self) -> str:
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
if self.function.returns:
return_type = self.function.returns.get_cpp_type()
return cpp_class_method_template.safe_substitute(
......@@ -472,7 +498,9 @@ class CPPMember:
args=self.get_args(),
success=success_type)
def generate_constructor(self, name):
def generate_constructor(self, name: str) -> str:
if not self.function.cfunction:
raise Exception('self.function.update() must be called')
return cpp_class_constructor_template.safe_substitute(
name=name,
cfunction=self.function.cfunction.name,
......@@ -482,98 +510,101 @@ class CPPMember:
class CPPClass:
def __init__(self, name, ctype):
def __init__(self, name: str, ctype: str) -> None:
self.name = name
self.ctype = ctype
self.constructors = []
self.methods = []
self.constructors: List[CPPMember] = []
self.methods: List[CPPMember] = []
self.prefix = 'p'
def add_method(self, name, f):
def add_method(self, name: str, f: Function) -> None:
self.methods.append(CPPMember(name, f, self.prefix, method=True))
def add_constructor(self, name, f):
def add_constructor(self, name: str, f: Function) -> None:
self.constructors.append(CPPMember(name, f, self.prefix, method=True))
def generate_methods(self):
def generate_methods(self) -> str:
return '\n '.join([m.generate_method() for m in self.methods])
def generate_constructors(self):
def generate_constructors(self) -> str:
return '\n '.join(
[m.generate_constructor(self.name) for m in self.constructors])
def substitute(self, s, **kwargs):
t = s
if isinstance(s, str):
t = string.Template(s)
def substitute(self, s: Union[string.Template, str], **kwargs) -> str:
t = string.Template(s) if isinstance(s, str) else s
destroy = self.ctype + '_destroy'
return t.safe_substitute(name=self.name,
ctype=self.ctype,
destroy=destroy,
**kwargs)
def generate(self):
def generate(self) -> str:
return self.substitute(
cpp_class_template,
constructors=self.substitute(self.generate_constructors()),
methods=self.substitute(self.generate_methods()))
def params(virtual=None, **kwargs):
def params(virtual: Optional[Dict[str, str]] = None,
**kwargs) -> List[Parameter]:
result = []
for name in virtual or {}:
result.append(Parameter(name, virtual[name]))
v: Dict[str, str] = virtual or {}
for name in v:
result.append(Parameter(name, v[name]))
for name in kwargs:
result.append(Parameter(name, kwargs[name]))
return result
def add_function(name, *args, **kwargs):
def add_function(name: str, *args, **kwargs) -> Function:
f = Function(name, *args, **kwargs)
functions.append(f)
return f
def once(f):
def once(f: Callable) -> Any:
@wraps(f)
def decorated(*args, **kwargs):
if not decorated.has_run:
decorated.has_run = True
return f(*args, **kwargs)
decorated.has_run = False
return decorated
d: Any = decorated
d.has_run = False
return d
@once
def process_functions():
def process_functions() -> None:
for f in functions:
f.update()
def generate_lines(p):
def generate_lines(p: List[str]) -> str:
return '\n'.join(p)
def generate_c_header():
def generate_c_header() -> str:
process_functions()
return generate_lines(c_header_preamble +
[f.cfunction.generate_header() for f in functions])
return generate_lines(
c_header_preamble +
[f.get_cfunction().generate_header() for f in functions])
def generate_c_api_body():
def generate_c_api_body() -> str:
process_functions()
return generate_lines(c_api_body_preamble +
[f.cfunction.generate_body() for f in functions])
return generate_lines(
c_api_body_preamble +
[f.get_cfunction().generate_body() for f in functions])
def generate_cpp_header():
def generate_cpp_header() -> str:
process_functions()
return generate_lines(cpp_header_preamble +
[c.generate() for c in cpp_classes])
def cwrap(name):
def cwrap(name: str) -> Callable:
def with_cwrap(f):
type_map[name] = f
......@@ -677,13 +708,17 @@ protected:
@once
def add_handle_preamble():
def add_handle_preamble() -> None:
c_api_body_preamble.append(handle_preamble)
cpp_header_preamble.append(
string.Template(cpp_handle_preamble).substitute(success=success_type))
def add_handle(name, ctype, cpptype, destroy=None, ref=None):
def add_handle(name: str,
ctype: str,
cpptype: str,
destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None:
opaque_type = ctype + '_t'
def handle_wrap(p):
......@@ -718,8 +753,12 @@ def add_handle(name, ctype, cpptype, destroy=None, ref=None):
@cwrap('std::vector')
def vector_c_wrap(p):
t = p.type.inner_type().add_pointer()
def vector_c_wrap(p: Parameter) -> None:
inner = p.type.inner_type()
# Not a generic type
if not inner:
return
t = inner.add_pointer()
if p.returns:
if p.type.is_reference():
if p.type.is_const():
......@@ -747,7 +786,7 @@ def vector_c_wrap(p):
@cwrap('std::string')
def string_c_wrap(p):
def string_c_wrap(p: Parameter) -> None:
t = Type('char*')
if p.returns:
if p.type.is_reference():
......@@ -771,7 +810,11 @@ def string_c_wrap(p):
class Handle:
def __init__(self, name, ctype, cpptype, ref=None):
def __init__(self,
name: str,
ctype: str,
cpptype: str,
ref: Optional[bool] = None) -> None:
self.name = name
self.ctype = ctype
self.cpptype = cpptype
......@@ -779,17 +822,21 @@ class Handle:
add_handle(name, ctype, cpptype, ref=ref)
cpp_type_map[cpptype] = name
def cname(self, name):
def cname(self, name: str) -> str:
return self.ctype + '_' + name
def substitute(self, s, **kwargs):
def substitute(self, s: str, **kwargs) -> str:
return Template(s).safe_substitute(name=self.name,
ctype=self.ctype,
cpptype=self.cpptype,
**kwargs)
def constructor(self, name, params=None, fname=None, invoke=None,
**kwargs):
def constructor(self,
name: str,
params: Optional[List[Parameter]] = None,
fname: Optional[str] = None,
invoke: Optional[str] = None,
**kwargs) -> 'Handle':
create = self.substitute('allocate<${cpptype}>($@)')
if fname:
create = self.substitute('allocate<${cpptype}>(${fname}($@))',
......@@ -805,13 +852,13 @@ class Handle:
return self
def method(self,
name,
params=None,
fname=None,
invoke=None,
cpp_name=None,
const=None,
**kwargs):
name: str,
params: Optional[List[Parameter]] = None,
fname: Optional[str] = None,
invoke: Optional[str] = None,
cpp_name: Optional[str] = None,
const: Optional[bool] = None,
**kwargs) -> 'Handle':
cpptype = self.cpptype
if const:
cpptype = Type(cpptype).add_const().str()
......@@ -832,11 +879,14 @@ class Handle:
add_function(self.cname(name), params=params, **kwargs)
return self
def add_cpp_class(self):
def add_cpp_class(self) -> None:
cpp_classes.append(self.cpp_class)
def handle(ctype, cpptype, name=None, ref=None):
def handle(ctype: str,
cpptype: str,
name: Optional[str] = None,
ref: Optional[bool] = None) -> Callable:
def with_handle(f):
n = name or f.__name__
h = Handle(n, ctype, cpptype, ref=ref)
......@@ -865,10 +915,10 @@ def template_eval(template, **kwargs):
return template
def run():
runpy.run_path(sys.argv[1])
if len(sys.argv) > 2:
f = open(sys.argv[2]).read()
def run(args: List[str]) -> None:
runpy.run_path(args[0])
if len(args) > 1:
f = open(args[1]).read()
r = template_eval(f)
sys.stdout.write(r)
else:
......@@ -879,4 +929,4 @@ def run():
if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
run()
run(sys.argv[1:])
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
SRC_DIR=$DIR/../src
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $SRC_DIR/include/migraphx/{}"
PYTHON=python3
if type -p python3.6 > /dev/null ; then
PYTHON=python3.6
fi
if type -p python3.8 > /dev/null ; then
PYTHON=python3.8
fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $SRC_DIR/include/migraphx/{}"
function api {
python3.6 $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-5.0 -style=file > $2
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-5.0 -style=file > $2
}
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment