"driver/driver.cpp" did not exist on "aa0199a31ca262f1a62746dc08e54ee6dc71fd5c"
Unverified Commit 08fe2924 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge pull request #3938 from microsoft/nn-meter

[DO NOT SQUASH] Support nn-Meter in Retiarii framework
parents 3bce6926 5e04d56c
......@@ -111,6 +111,33 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, i, 3, 3]))
def test_nested_layer_choice(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.module = nn.LayerChoice([
nn.LayerChoice([nn.Conv2d(3, 3, kernel_size=1),
nn.Conv2d(3, 4, kernel_size=1),
nn.Conv2d(3, 5, kernel_size=1)]),
nn.Conv2d(3, 1, kernel_size=1)
])
def forward(self, x):
return self.module(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2)
mutators[0].bind_sampler(EnumerateSampler())
mutators[1].bind_sampler(EnumerateSampler())
input = torch.randn(1, 3, 5, 5)
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 3, 5, 5]))
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 1, 5, 5]))
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 5, 5, 5]))
def test_input_choice(self):
@self.get_serializer()
class Net(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment