debug_mnist_pytorch.py 1.21 KB
Newer Older
1
2
3
4
5
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

6
7
import nni.retiarii.nn.pytorch

8
9
import torch

10
11
12
13
14

class _model(nn.Module):
    """Small image classifier: a ``stem`` conv feature extractor followed by two
    linear layers and a softmax over 10 outputs.

    ``fc1`` expects 1024 input features, i.e. ``64 * 4 * 4``. That is the
    flattened ``stem`` output for a batched ``(N, 1, 28, 28)`` input (two 5x5
    unpadded convs and two 2x2 max-pools). Presumably MNIST images; confirm
    against the data pipeline.

    NOTE(review): there is no activation between ``fc1`` and ``fc2`` (nor inside
    ``stem``), so the two linear layers compose into a single affine map. This
    looks like generated/debug code, so the architecture is left unchanged.
    NOTE(review): ``forward`` returns probabilities, not logits. If this is
    trained with ``nn.CrossEntropyLoss`` (which applies log-softmax itself),
    the loss should see ``fc2``'s output instead. Verify against the training code.
    """

    def __init__(self):
        super().__init__()
        self.stem = stem()
        # Flatten keeps the batch axis (start_dim=1): (N, 64, 4, 4) -> (N, 1024).
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
        # Explicit dim=1 (the class axis of the (N, 10) output). Softmax() with
        # no dim is deprecated and warns on every call. For 2-D input the
        # implicit choice is also dim=1, so results are unchanged.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, image):
        """Return per-class softmax probabilities for a batch of images.

        Args:
            image: batched image tensor; assumed ``(N, 1, 28, 28)`` given the
                layer sizes (TODO confirm).

        Returns:
            Tensor of shape ``(N, 10)`` whose rows sum to 1.
        """
        stem = self.stem(image)
        flatten = self.flatten(stem)
        fc1 = self.fc1(flatten)
        fc2 = self.fc2(fc1)
        softmax = self.softmax(fc2)
        return softmax



class stem(nn.Module):
    """Convolutional feature extractor made of two (5x5 conv -> 2x2 max-pool) stages.

    Channel progression is 1 -> 32 -> 64. The convolutions use PyTorch's
    defaults (stride 1, no padding), so each one trims 4 pixels from height
    and width. Each pool (stride defaults to kernel_size) halves them,
    rounding down.
    """

    def __init__(self):
        super().__init__()
        # The assignment order below fixes both the parameter-initialisation RNG
        # order and the child/state_dict ordering, so keep it as is.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        """Feed the first positional argument through conv1, pool1, conv2, pool2.

        Only ``_inputs[0]`` is used. Any additional positional arguments are
        accepted and ignored.
        """
        features = _inputs[0]
        for layer in (self.conv1, self.pool1, self.conv2, self.pool2):
            features = layer(features)
        return features