debug_mnist_pytorch.py 1.39 KB
Newer Older
1
2
3
4
5
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

6
import nni.nas.nn.pytorch
7

8
9
import torch

10
11
12
13
14

class _model(nn.Module):
    """MNIST-style classifier: ``stem`` conv feature extractor, then two linear layers.

    Pipeline: stem -> flatten -> fc1 (1024 -> 256) -> fc2 (256 -> 10) -> softmax.
    ``fc1`` expects 1024 flattened features. ``stem`` produces 64x4x4 = 1024
    features for a 1x28x28 input (assumed MNIST images — confirm with caller).
    No nonlinearity is applied between ``fc1`` and ``fc2``.
    """

    def __init__(self):
        super().__init__()
        self.stem = stem()
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
        # Normalise over the class axis of the (batch, 10) output of fc2.
        # A bare Softmax() uses PyTorch's deprecated implicit-dim choice (which
        # picks dim=1 for 2-D input) and warns on every call; be explicit.
        self.softmax = torch.nn.Softmax(dim=1)
        # NOTE(review): name -> None table, presumably consumed by NNI's
        # graph tooling; keep keys in sync with the submodules above.
        self._mapping_ = {'stem': None, 'flatten': None, 'fc1': None, 'fc2': None, 'softmax': None}

    def forward(self, image):
        """Return class probabilities of shape (batch, 10) for a batch of images.

        NOTE(review): the output is already softmax-normalised, not logits. If
        the training loop uses ``CrossEntropyLoss``, softmax is applied twice —
        verify against the caller.
        """
        stem = self.stem(image)
        flatten = self.flatten(stem)
        fc1 = self.fc1(flatten)
        fc2 = self.fc2(fc1)
        softmax = self.softmax(fc2)
        return softmax



class stem(nn.Module):
    """Convolutional feature extractor: two (5x5 conv -> 2x2 max-pool) stages.

    No padding and no nonlinearity are used. For a 28x28 single-channel input,
    the spatial size goes 28 -> 24 -> 12 -> 8 -> 4, giving a 64x4x4 feature map.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, then halve the spatial size.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        # Stage 2: 32 -> 64 channels, then halve the spatial size again.
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
        # NOTE(review): name -> None table, presumably read by NNI tooling.
        self._mapping_ = dict.fromkeys(('conv1', 'pool1', 'conv2', 'pool2'))

    def forward(self, *_inputs):
        """Feed the first positional input through conv1, pool1, conv2, pool2.

        Any additional positional inputs are ignored.
        """
        features = _inputs[0]
        for layer in (self.conv1, self.pool1, self.conv2, self.pool2):
            features = layer(features)
        return features