"vscode:/vscode.git/clone" did not exist on "0c6162e0d2be8fa5544cb7cf1b17df34de08b5a5"
Unverified Commit 5fe24500 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] support torch 1.8 and 1.9 (#3937)

parent 542a660d
......@@ -12,7 +12,7 @@ In this quick start tutorial, we use multi-trial NAS as an example to show how t
One-shot NAS tutorial can be found `here <./OneshotTrainer.rst>`__.
.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 and 1.7**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan.
.. note:: Currently, PyTorch is the only supported framework by Retiarii, and we have only tested with **PyTorch 1.6 to 1.9**. This documentation assumes PyTorch context but it should also apply to other frameworks, that is in our future plan.
Define your Model Space
-----------------------
......
......@@ -59,7 +59,7 @@ class PrimConstant(PyTorchOperation):
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] == 'None':
if self.parameters['type'] in ['None', 'NoneType']:
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
return f'{output} = {self.parameters["value"]}'
......@@ -238,7 +238,13 @@ class AtenIndex(PyTorchOperation):
ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')]
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
# in v1.9 dtype is supported as input argument for view, but torch script does not support it
'aten::view': [('size', 'List[int]', 'None')],
# NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
# torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor
# torch.std(input, unbiased) Tensor
'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}
TensorOpExceptions = {
......@@ -426,4 +432,11 @@ class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unkown reason
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
\ No newline at end of file
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
class AtenDet(PyTorchOperation):
# for torch 1.9
# NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = torch.det({inputs[0]})'
\ No newline at end of file
......@@ -375,7 +375,7 @@ class TestPytorch(unittest.TestCase):
# NOTE: torch script gets an incorrect graph...
def test_optional_inputs_with_mixed_optionals(self):
class MixedModel(nn.Module):
def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'):
def forward(self, x, y, z):
if y is not None:
return x + y
if z is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment