"ts/webui/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "5df75c33cb543cb55e7a1616bfa2a4a3416243b8"
Unverified Commit 5fe24500 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] support torch 1.8 and 1.9 (#3937)

parent 542a660d
...@@ -12,7 +12,7 @@ In this quick start tutorial, we use multi-trial NAS as an example to show how t ...@@ -12,7 +12,7 @@ In this quick start tutorial, we use multi-trial NAS as an example to show how t
One-shot NAS tutorial can be found `here <./OneshotTrainer.rst>`__. One-shot NAS tutorial can be found `here <./OneshotTrainer.rst>`__.
.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 and 1.7**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan. .. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 to 1.9**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan.
Define your Model Space Define your Model Space
----------------------- -----------------------
......
...@@ -59,7 +59,7 @@ class PrimConstant(PyTorchOperation): ...@@ -59,7 +59,7 @@ class PrimConstant(PyTorchOperation):
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types # TODO: deal with all the types
if self.parameters['type'] == 'None': if self.parameters['type'] in ['None', 'NoneType']:
return f'{output} = None' return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
return f'{output} = {self.parameters["value"]}' return f'{output} = {self.parameters["value"]}'
...@@ -238,7 +238,13 @@ class AtenIndex(PyTorchOperation): ...@@ -238,7 +238,13 @@ class AtenIndex(PyTorchOperation):
ManuallyChooseDef = { ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')], 'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')] 'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
# in v1.9 dtype is supported as input argument for view, but torch script does not support it
'aten::view': [('size', 'List[int]', 'None')],
# NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
# torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor
# torch.std(input, unbiased) Tensor
'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
} }
TensorOpExceptions = { TensorOpExceptions = {
...@@ -426,4 +432,11 @@ class AtenAvgpool2d(PyTorchOperation): ...@@ -426,4 +432,11 @@ class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unkown reason # NOTE: it is not included in the above aten ops for unkown reason
_ori_type_name = ['aten::avg_pool2d'] _ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})' return f'{output} = F.avg_pool2d({", ".join(inputs)})'
\ No newline at end of file
class AtenDet(PyTorchOperation):
# for torch 1.9
# NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = torch.det({inputs[0]})'
\ No newline at end of file
...@@ -375,7 +375,7 @@ class TestPytorch(unittest.TestCase): ...@@ -375,7 +375,7 @@ class TestPytorch(unittest.TestCase):
# NOTE: torch script gets an incorrect graph... # NOTE: torch script gets an incorrect graph...
def test_optional_inputs_with_mixed_optionals(self): def test_optional_inputs_with_mixed_optionals(self):
class MixedModel(nn.Module): class MixedModel(nn.Module):
def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'): def forward(self, x, y, z):
if y is not None: if y is not None:
return x + y return x + y
if z is not None: if z is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment