# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class NaiveSearchSpace(nn.Module):
    """Small convolutional search space that checks mutable masks while running.

    The network mixes fixed layers with NNI mutables: two ``LayerChoice``
    convolutions (3x3 vs. 5x5 kernels) and two ``InputChoice`` skip
    connections. The mutables built with ``return_mask=True`` have their
    returned mask checked inside ``forward`` through ``test_case``.

    Args:
        test_case: Object exposing ``assertEqual(first, second)``; used to
            verify mask sizes during the forward pass (presumably a
            ``unittest.TestCase`` -- confirm against the caller).
    """

    def __init__(self, test_case):
        super().__init__()
        self.test_case = test_case

        # First stage: choose between a 3x3 and a 5x5 convolution (3 -> 6 channels).
        self.conv1 = LayerChoice([
            nn.Conv2d(3, 6, 3, padding=1),
            nn.Conv2d(3, 6, 5, padding=2),
        ])
        self.pool = nn.MaxPool2d(2, 2)

        # Second stage: same kernel choice (6 -> 16 channels); the call result
        # is unpacked as an (output, mask) pair in forward().
        self.conv2 = LayerChoice(
            [
                nn.Conv2d(6, 16, 3, padding=1),
                nn.Conv2d(6, 16, 5, padding=2),
            ],
            return_mask=True,
        )
        self.conv3 = nn.Conv2d(16, 16, 1)

        # Skip connections: a single-candidate choice and a two-candidate
        # choice whose mask is inspected in forward().
        self.skipconnect = InputChoice(n_candidates=1)
        self.skipconnect2 = InputChoice(n_candidates=2, return_mask=True)
        self.bn = nn.BatchNorm2d(16)

        # Classification head: global average pooling followed by a linear layer.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        """Run the network and assert the size of the returned masks.

        Args:
            x: Input batch with 3 channels (required by ``conv1``); the first
                dimension is treated as the batch size.

        Returns:
            Tensor of shape ``(batch_size, 10)`` produced by the final
            linear layer.
        """
        bs = x.size(0)

        x = self.pool(F.relu(self.conv1(x)))
        x0, mask = self.conv2(x)
        # conv2 has two candidates; its mask is expected to have two entries.
        self.test_case.assertEqual(mask.size(), torch.Size([2]))
        x1 = F.relu(self.conv3(x0))

        # Only the mask of the two-candidate choice is needed; its output is
        # discarded and the mask is checked further below.
        _, mask = self.skipconnect2([x0, x1])
        x0 = self.skipconnect([x0])
        if x0 is not None:
            # Out-of-place add on purpose: x1 is the output of F.relu and was
            # also handed to skipconnect2 above. Mutating it in place can
            # invalidate the value autograd saved for ReLU's backward pass.
            x1 = x1 + x0
        x = self.pool(self.bn(x1))
        self.test_case.assertEqual(mask.size(), torch.Size([2]))

        # Collapse the 1x1 spatial map to (batch, channels) before the classifier.
        x = self.gap(x).view(bs, -1)
        x = self.fc(x)
        return x