# debug_mnist_pytorch.py
# Debug CNN classifier for single-channel 28x28 (MNIST) images.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

6
7
import torch

8
9
10
11
12

class _model(nn.Module):
    """CNN classifier: a convolutional ``stem`` followed by two linear layers.

    The stem (two conv + max-pool stages) emits 64 feature maps. For a
    28x28 single-channel input these are 4x4, i.e. 64 * 4 * 4 = 1024
    features after flattening, which is exactly what ``fc1`` expects.
    Other input sizes only work if the stem output still flattens to 1024.
    """

    def __init__(self):
        super().__init__()
        self.stem = stem()
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
        self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
        # Normalise over the class dimension explicitly. A bare ``Softmax()``
        # relies on deprecated implicit-dim inference (and warns). The output
        # here is always (N, 10), for which the implicit choice was also dim=1.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, image):
        """Return class probabilities for a batch of images.

        Args:
            image: batched input tensor, presumably (N, 1, 28, 28) — TODO
                confirm. The stem's first conv takes 1 input channel and
                ``fc1`` requires 1024 flattened features.

        Returns:
            Tensor of shape (N, 10) whose rows sum to 1 (softmax over dim 1).
        """
        features = self.stem(image)
        flat = self.flatten(features)
        # NOTE(review): there is no activation between fc1 and fc2, so the
        # two layers compose to a single affine map. Kept as-is to preserve
        # the model's outputs; confirm whether a ReLU was intended.
        hidden = self.fc1(flat)
        logits = self.fc2(hidden)
        # NOTE(review): probabilities (not logits) are returned; if this is
        # trained with nn.CrossEntropyLoss the softmax is applied twice.
        return self.softmax(logits)



class stem(nn.Module):
    """Feature extractor: two (5x5 conv -> 2x2 max-pool) stages, 1 -> 32 -> 64 channels.

    No padding or stride is configured, so each conv shrinks H and W by 4 and
    each pool halves them (rounding down): a 28x28 input becomes 64 x 4 x 4.
    No activation is applied between the stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        """Run both conv/pool stages on the first positional input.

        Only ``_inputs[0]`` is used; any additional positional arguments are
        ignored. The variadic signature is left unchanged for caller
        compatibility.
        """
        x = _inputs[0]
        x = self.pool1(self.conv1(x))
        return self.pool2(self.conv2(x))