Commit 031ccf5f authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent baac1dab
...@@ -46,6 +46,7 @@ def shape_type_wrap(p): ...@@ -46,6 +46,7 @@ def shape_type_wrap(p):
def auto_handle(*args, **kwargs): def auto_handle(*args, **kwargs):
def with_handle(f): def with_handle(f):
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__, return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
*args, **kwargs)(f) *args, **kwargs)(f)
......
...@@ -38,6 +38,7 @@ pytest_plugins = 'onnx.backend.test.report', ...@@ -38,6 +38,7 @@ pytest_plugins = 'onnx.backend.test.report',
class MIGraphXBackendTest(onnx.backend.test.BackendTest): class MIGraphXBackendTest(onnx.backend.test.BackendTest):
def __init__(self, backend, parent_module=None): def __init__(self, backend, parent_module=None):
super(MIGraphXBackendTest, self).__init__(backend, parent_module) super(MIGraphXBackendTest, self).__init__(backend, parent_module)
......
...@@ -50,6 +50,7 @@ class Template(string.Template): ...@@ -50,6 +50,7 @@ class Template(string.Template):
class Type: class Type:
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
self.name = name.strip() self.name = name.strip()
...@@ -144,6 +145,7 @@ extern "C" ${error_type} ${name}(${params}) ...@@ -144,6 +145,7 @@ extern "C" ${error_type} ${name}(${params})
class CFunction: class CFunction:
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
self.name = name self.name = name
self.params: List[str] = [] self.params: List[str] = []
...@@ -188,12 +190,14 @@ class CFunction: ...@@ -188,12 +190,14 @@ class CFunction:
class BadParam: class BadParam:
def __init__(self, cond: str, msg: str) -> None: def __init__(self, cond: str, msg: str) -> None:
self.cond = cond self.cond = cond
self.msg = msg self.msg = msg
class Parameter: class Parameter:
def __init__(self, def __init__(self,
name: str, name: str,
type: str, type: str,
...@@ -250,7 +254,8 @@ class Parameter: ...@@ -250,7 +254,8 @@ class Parameter:
size=self.size_name, size=self.size_name,
result=result or '') result=result or '')
def add_param(self, t: Union[str, Type], def add_param(self,
t: Union[str, Type],
name: Optional[str] = None) -> None: name: Optional[str] = None) -> None:
if not isinstance(t, str): if not isinstance(t, str):
t = t.str() t = t.str()
...@@ -409,6 +414,7 @@ def to_template_vars(params: List[Union[Any, Parameter]]) -> str: ...@@ -409,6 +414,7 @@ def to_template_vars(params: List[Union[Any, Parameter]]) -> str:
class Function: class Function:
def __init__(self, def __init__(self,
name: str, name: str,
params: Optional[List[Parameter]] = None, params: Optional[List[Parameter]] = None,
...@@ -545,6 +551,7 @@ cpp_class_constructor_template = Template(''' ...@@ -545,6 +551,7 @@ cpp_class_constructor_template = Template('''
class CPPMember: class CPPMember:
def __init__(self, def __init__(self,
name: str, name: str,
function: Function, function: Function,
...@@ -621,6 +628,7 @@ class CPPMember: ...@@ -621,6 +628,7 @@ class CPPMember:
class CPPClass: class CPPClass:
def __init__(self, name: str, ctype: str) -> None: def __init__(self, name: str, ctype: str) -> None:
self.name = name self.name = name
self.ctype = ctype self.ctype = ctype
...@@ -677,6 +685,7 @@ def add_function(name: str, *args, **kwargs) -> Function: ...@@ -677,6 +685,7 @@ def add_function(name: str, *args, **kwargs) -> Function:
def once(f: Callable) -> Any: def once(f: Callable) -> Any:
@wraps(f) @wraps(f)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not decorated.has_run: if not decorated.has_run:
...@@ -722,6 +731,7 @@ c_type_map: Dict[str, Type] = {} ...@@ -722,6 +731,7 @@ c_type_map: Dict[str, Type] = {}
def cwrap(name: str, c_type: Optional[str] = None) -> Callable: def cwrap(name: str, c_type: Optional[str] = None) -> Callable:
def with_cwrap(f): def with_cwrap(f):
type_map[name] = f type_map[name] = f
if c_type: if c_type:
...@@ -1015,6 +1025,7 @@ def string_c_wrap(p: Parameter) -> None: ...@@ -1015,6 +1025,7 @@ def string_c_wrap(p: Parameter) -> None:
class Handle: class Handle:
def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None: def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None:
self.name = name self.name = name
self.ctype = ctype self.ctype = ctype
...@@ -1140,6 +1151,7 @@ def generate_virtual_impl(f: Function, fname: str) -> str: ...@@ -1140,6 +1151,7 @@ def generate_virtual_impl(f: Function, fname: str) -> str:
class Interface(Handle): class Interface(Handle):
def __init__(self, name: str, ctype: str, cpptype: str) -> None: def __init__(self, name: str, ctype: str, cpptype: str) -> None:
super().__init__(name, ctype, cpptype, skip_def=True) super().__init__(name, ctype, cpptype, skip_def=True)
self.ifunctions: List[Function] = [] self.ifunctions: List[Function] = []
...@@ -1234,6 +1246,7 @@ def handle(ctype: str, ...@@ -1234,6 +1246,7 @@ def handle(ctype: str,
cpptype: str, cpptype: str,
name: Optional[str] = None, name: Optional[str] = None,
ref: Optional[bool] = None) -> Callable: ref: Optional[bool] = None) -> Callable:
def with_handle(f): def with_handle(f):
n = name or f.__name__ n = name or f.__name__
h = Handle(n, ctype, cpptype, ref=ref) h = Handle(n, ctype, cpptype, ref=ref)
...@@ -1249,8 +1262,10 @@ def handle(ctype: str, ...@@ -1249,8 +1262,10 @@ def handle(ctype: str,
return with_handle return with_handle
def interface(ctype: str, cpptype: str, def interface(ctype: str,
cpptype: str,
name: Optional[str] = None) -> Callable: name: Optional[str] = None) -> Callable:
def with_interface(f): def with_interface(f):
n = name or f.__name__ n = name or f.__name__
h = Interface(n, ctype, cpptype) h = Interface(n, ctype, cpptype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment