Unverified Commit bb397480 authored by kalineid's avatar kalineid Committed by GitHub
Browse files

Add tests for GraphConverterWithShape (#3951)

parent 403195f0
import torch
from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape
class ConvertMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
return model_ir
class ConvertWithShapeMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
return model_ir
......@@ -13,9 +13,10 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class MnistNet(nn.Module):
def __init__(self):
super(MnistNet, self).__init__()
......@@ -48,7 +49,7 @@ class Linear(nn.Module):
out = self.linear(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase):
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return result
def checkExportImport(self, model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
exec_vars = {}
......@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(model, (x,))
finally:
remove_inject_pytorch_nn()
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
......@@ -9,12 +9,13 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script
# following pytorch v1.7.1
class TestConvert(unittest.TestCase):
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return result
def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
......@@ -280,3 +280,7 @@ class TestConvert(unittest.TestCase):
out1 = x.ceil()
return out1
self.checkExportImport(SimpleOp(), (torch.randn(4), ))
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
......@@ -10,11 +10,12 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestModels(unittest.TestCase):
class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return result
def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
......@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model = Net(4)
x = torch.rand((1, 16), dtype=torch.float)
self.run_test(model, ([x], ))
class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
pass
......@@ -15,13 +15,14 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
# following pytorch v1.7.1
class TestOperators(unittest.TestCase):
class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return result
def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
#print(model_code)
......@@ -1386,3 +1386,6 @@ class TestOperators(unittest.TestCase):
x = torch.randn(20, 5, 10, 10)
self.checkExportImport(SimpleOp(), (x, ))
class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
pass
......@@ -15,11 +15,12 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestPytorch(unittest.TestCase):
class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return result
def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
......@@ -1231,3 +1231,6 @@ class TestPytorch(unittest.TestCase):
x = torch.randn(5, 3, 2)
self.run_test(SizeModel(10, 5), (x, ))
class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
pass
import unittest
import torch
import nni.retiarii.nn.pytorch as nn
from .convert_mixin import ConvertWithShapeMixin
class TestShape(unittest.TestCase, ConvertWithShapeMixin):
def test_simple_convnet(self):
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 1, 3)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.relu(self.conv(x)))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0]
pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0]
self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]])
def test_nested_module(self):
class ConvRelu(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 1, 3)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = ConvRelu()
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.conv(x))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
# check if shape propagation works
cell_node = model_ir.get_nodes_by_type('_cell')[0]
self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
def test_layerchoice(self):
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.LayerChoice([
nn.Conv2d(3, 1, 3),
nn.Conv2d(3, 1, 5, padding=1),
])
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.conv(x))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
# check shape info of each candidates
conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment