"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "ee390c0b552e2af13a52672a667ff75a934aab6d"
Unverified Commit bb397480 authored by kalineid's avatar kalineid Committed by GitHub
Browse files

Add tests for GraphConverterWithShape (#3951)

parent 403195f0
import torch
from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape
class ConvertMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
return model_ir
class ConvertWithShapeMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
return model_ir
...@@ -13,9 +13,10 @@ import torchvision ...@@ -13,9 +13,10 @@ import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class MnistNet(nn.Module): class MnistNet(nn.Module):
def __init__(self): def __init__(self):
super(MnistNet, self).__init__() super(MnistNet, self).__init__()
...@@ -48,7 +49,7 @@ class Linear(nn.Module): ...@@ -48,7 +49,7 @@ class Linear(nn.Module):
out = self.linear(input.view(size[0] * size[1], -1)) out = self.linear(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1) return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase): class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod @staticmethod
def _match_state_dict(current_values, expected_format): def _match_state_dict(current_values, expected_format):
result = {} result = {}
...@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase): ...@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return result return result
def checkExportImport(self, model, input): def checkExportImport(self, model, input):
script_module = torch.jit.script(model) model_ir = self._convert_model(model, input)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
exec_vars = {} exec_vars = {}
...@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase): ...@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(model, (x,)) self.checkExportImport(model, (x,))
finally: finally:
remove_inject_pytorch_nn() remove_inject_pytorch_nn()
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
...@@ -9,12 +9,13 @@ import torchvision ...@@ -9,12 +9,13 @@ import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
# following pytorch v1.7.1 # following pytorch v1.7.1
class TestConvert(unittest.TestCase): class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod @staticmethod
def _match_state_dict(current_values, expected_format): def _match_state_dict(current_values, expected_format):
result = {} result = {}
...@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase): ...@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return result return result
def checkExportImport(self, model, input, check_value=True): def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model) model_ir = self._convert_model(model, input)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
print(model_code) print(model_code)
...@@ -188,7 +188,7 @@ class TestConvert(unittest.TestCase): ...@@ -188,7 +188,7 @@ class TestConvert(unittest.TestCase):
out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2) out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
return out1, out2 return out1, out2
self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), )) self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))
def test_basic_addr(self): def test_basic_addr(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
def forward(self, x, y, z): def forward(self, x, y, z):
...@@ -204,7 +204,7 @@ class TestConvert(unittest.TestCase): ...@@ -204,7 +204,7 @@ class TestConvert(unittest.TestCase):
out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False) out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
return out1, out2 return out1, out2
self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), )) self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))
def test_basic_angle(self): def test_basic_angle(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
def forward(self, x): def forward(self, x):
...@@ -229,7 +229,7 @@ class TestConvert(unittest.TestCase): ...@@ -229,7 +229,7 @@ class TestConvert(unittest.TestCase):
o4 = x.argmin(dim=1, keepdim=True) o4 = x.argmin(dim=1, keepdim=True)
return out1, out2, out3, out4, out5, o1, o2, o3, o4 return out1, out2, out3, out4, out5, o1, o2, o3, o4
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), )) self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
def test_basic_argsort(self): def test_basic_argsort(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
def forward(self, x): def forward(self, x):
...@@ -241,7 +241,7 @@ class TestConvert(unittest.TestCase): ...@@ -241,7 +241,7 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(SimpleOp(), (torch.randn(4, 4), )) self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))
# skip backward(gradient=None, retain_graph=None, create_graph=False) # skip backward(gradient=None, retain_graph=None, create_graph=False)
def test_basic_bernoulli(self): def test_basic_bernoulli(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
def forward(self, x): def forward(self, x):
...@@ -261,7 +261,7 @@ class TestConvert(unittest.TestCase): ...@@ -261,7 +261,7 @@ class TestConvert(unittest.TestCase):
out4 = x.bincount(weights=y, minlength=2) out4 = x.bincount(weights=y, minlength=2)
return out1, out2, out3, out4 return out1, out2, out3, out4
self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), )) self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))
def test_basic_bitwise(self): def test_basic_bitwise(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
def forward(self, x, y): def forward(self, x, y):
...@@ -279,4 +279,8 @@ class TestConvert(unittest.TestCase): ...@@ -279,4 +279,8 @@ class TestConvert(unittest.TestCase):
def forward(self, x): def forward(self, x):
out1 = x.ceil() out1 = x.ceil()
return out1 return out1
self.checkExportImport(SimpleOp(), (torch.randn(4), )) self.checkExportImport(SimpleOp(), (torch.randn(4), ))
\ No newline at end of file
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
...@@ -10,11 +10,12 @@ import torchvision ...@@ -10,11 +10,12 @@ import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestModels(unittest.TestCase):
class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod @staticmethod
def _match_state_dict(current_values, expected_format): def _match_state_dict(current_values, expected_format):
result = {} result = {}
...@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase): ...@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return result return result
def run_test(self, model, input, check_value=True): def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model) model_ir = self._convert_model(model, input)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
print(model_code) print(model_code)
...@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase): ...@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model = Net(4) model = Net(4)
x = torch.rand((1, 16), dtype=torch.float) x = torch.rand((1, 16), dtype=torch.float)
self.run_test(model, ([x], )) self.run_test(model, ([x], ))
class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
pass
...@@ -15,13 +15,14 @@ import torch.nn.functional as F ...@@ -15,13 +15,14 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
# following pytorch v1.7.1 # following pytorch v1.7.1
class TestOperators(unittest.TestCase): class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod @staticmethod
def _match_state_dict(current_values, expected_format): def _match_state_dict(current_values, expected_format):
result = {} result = {}
...@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase): ...@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return result return result
def checkExportImport(self, model, input, check_value=True): def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model) model_ir = self._convert_model(model, input)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
#print(model_code) #print(model_code)
...@@ -1042,7 +1042,7 @@ class TestOperators(unittest.TestCase): ...@@ -1042,7 +1042,7 @@ class TestOperators(unittest.TestCase):
x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True) x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_batchnorm(self): def test_basic_batchnorm(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1056,7 +1056,7 @@ class TestOperators(unittest.TestCase): ...@@ -1056,7 +1056,7 @@ class TestOperators(unittest.TestCase):
x = torch.ones(2, 2, 2, 2, requires_grad=True) x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_batchnorm_1d(self): def test_basic_batchnorm_1d(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1084,7 +1084,7 @@ class TestOperators(unittest.TestCase): ...@@ -1084,7 +1084,7 @@ class TestOperators(unittest.TestCase):
x = torch.ones(20, 16, 50, 40, requires_grad=True) x = torch.ones(20, 16, 50, 40, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_conv_onnx_irv4_opset8(self): def test_conv_onnx_irv4_opset8(self):
# This test point checks that for opset 8 (or lower), even if # This test point checks that for opset 8 (or lower), even if
# keep_initializers_as_inputs is set to False, it is ignored, # keep_initializers_as_inputs is set to False, it is ignored,
...@@ -1129,7 +1129,7 @@ class TestOperators(unittest.TestCase): ...@@ -1129,7 +1129,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(20, 16, 50) x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_maxpool_dilations(self): def test_basic_maxpool_dilations(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1143,7 +1143,7 @@ class TestOperators(unittest.TestCase): ...@@ -1143,7 +1143,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(20, 16, 50) x = torch.randn(20, 16, 50)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_avg_pool2d(self): def test_basic_avg_pool2d(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1157,7 +1157,7 @@ class TestOperators(unittest.TestCase): ...@@ -1157,7 +1157,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(20, 16, 50, 32) x = torch.randn(20, 16, 50, 32)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"') @unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"')
def test_basic_maxpool_indices(self): def test_basic_maxpool_indices(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1200,7 +1200,7 @@ class TestOperators(unittest.TestCase): ...@@ -1200,7 +1200,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(1, 2, 3, 4, requires_grad=True) x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_elu(self): def test_basic_elu(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1214,7 +1214,7 @@ class TestOperators(unittest.TestCase): ...@@ -1214,7 +1214,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(1, 2, 3, 4, requires_grad=True) x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_selu(self): def test_basic_selu(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1261,7 +1261,7 @@ class TestOperators(unittest.TestCase): ...@@ -1261,7 +1261,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(128, 128, 1, 1, requires_grad=True) x = torch.randn(128, 128, 1, 1, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_embedding_bags(self): def test_embedding_bags(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
def __init__(self): def __init__(self):
...@@ -1288,7 +1288,7 @@ class TestOperators(unittest.TestCase): ...@@ -1288,7 +1288,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(1, 2, 3, 4) x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_prelu(self): def test_basic_prelu(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1302,7 +1302,7 @@ class TestOperators(unittest.TestCase): ...@@ -1302,7 +1302,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(1, 2, 3, 4) x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_log_sigmoid(self): def test_basic_log_sigmoid(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1316,7 +1316,7 @@ class TestOperators(unittest.TestCase): ...@@ -1316,7 +1316,7 @@ class TestOperators(unittest.TestCase):
x = torch.randn(1, 2, 3, 4) x = torch.randn(1, 2, 3, 4)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
def test_basic_linear(self): def test_basic_linear(self):
class SimpleOp(nn.Module): class SimpleOp(nn.Module):
...@@ -1385,4 +1385,7 @@ class TestOperators(unittest.TestCase): ...@@ -1385,4 +1385,7 @@ class TestOperators(unittest.TestCase):
return out return out
x = torch.randn(20, 5, 10, 10) x = torch.randn(20, 5, 10, 10)
self.checkExportImport(SimpleOp(), (x, )) self.checkExportImport(SimpleOp(), (x, ))
\ No newline at end of file
class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
pass
...@@ -15,11 +15,12 @@ import torchvision ...@@ -15,11 +15,12 @@ import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestPytorch(unittest.TestCase):
class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod @staticmethod
def _match_state_dict(current_values, expected_format): def _match_state_dict(current_values, expected_format):
result = {} result = {}
...@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase): ...@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return result return result
def run_test(self, model, input, check_value=True): def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model) model_ir = self._convert_model(model, input)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
print(model_code) print(model_code)
...@@ -1230,4 +1230,7 @@ class TestPytorch(unittest.TestCase): ...@@ -1230,4 +1230,7 @@ class TestPytorch(unittest.TestCase):
return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape) return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)
x = torch.randn(5, 3, 2) x = torch.randn(5, 3, 2)
self.run_test(SizeModel(10, 5), (x, )) self.run_test(SizeModel(10, 5), (x, ))
\ No newline at end of file
class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
pass
import unittest
import torch
import nni.retiarii.nn.pytorch as nn
from .convert_mixin import ConvertWithShapeMixin
class TestShape(unittest.TestCase, ConvertWithShapeMixin):
def test_simple_convnet(self):
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 1, 3)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.relu(self.conv(x)))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0]
pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0]
self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]])
def test_nested_module(self):
class ConvRelu(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 1, 3)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = ConvRelu()
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.conv(x))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
# check if shape propagation works
cell_node = model_ir.get_nodes_by_type('_cell')[0]
self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
def test_layerchoice(self):
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.LayerChoice([
nn.Conv2d(3, 1, 3),
nn.Conv2d(3, 1, 5, padding=1),
])
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.conv(x))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
# check shape info of each candidates
conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment