"testing/vscode:/vscode.git/clone" did not exist on "5f202fe5c1d63a5e3a1598690877eccff2ad4640"
Unverified Commit 0f88b86b authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Retiarii graph and code generation test (#3231)

parent 4fae3ed9
......@@ -30,6 +30,10 @@ BasicOpsPT = {
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::Bool': 'Bool',
'aten::empty': 'Empty',
'aten::zeros': 'Zeros',
'aten::chunk': 'Chunk',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
......
......@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
return f'{output} = {value}'
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'prim::GetAttr':
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
......@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
assert len(inputs) == 2
return f'{output} = {inputs[0]} + {inputs[1]}'
return f'{output} = ' + ' + '.join(inputs)
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
......@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
return f'{output} = bool({inputs[0]})'
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
......
......@@ -27,6 +27,11 @@ def get_records():
return _records
def clear_records():
global _records
_records = {}
def add_record(key, value):
"""
"""
......@@ -56,7 +61,7 @@ def _blackbox_cls(cls, module_name, register_format=None):
# eject un-serializable arguments
for k in list(full_args.keys()):
# The list is not complete and does not support nested cases.
if not isinstance(full_args[k], (int, float, str, dict, list)):
if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
if not (register_format == 'full' and k == 'model'):
# no warning if it is base model in trainer
warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment