model.py 2.21 KB
Newer Older
VoVAllen's avatar
VoVAllen committed
1
2
3
import torch
from torch import nn

4
5
from DGLDigitCapsule import DGLDigitCapsuleLayer
from DGLRoutingLayer import squash
VoVAllen's avatar
VoVAllen committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61


class Net(nn.Module):
    """Capsule network: conv stem -> primary capsule layer -> DGL digit capsule layer.

    ``forward`` returns the output of ``DGLDigitCapsuleLayer`` unchanged;
    ``margin_loss`` scores such an output against integer class targets.
    """

    def __init__(self, device='cpu'):
        super(Net, self).__init__()
        # Forwarded to both capsule layers; kept as an attribute for callers.
        self.device = device
        # Single-input-channel stem (in_channels=1), 256 feature maps, 9x9 kernel.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,
                                             out_channels=256,
                                             kernel_size=9,
                                             stride=1), nn.ReLU(inplace=True))

        self.primary = PrimaryCapsuleLayer(device=device)
        self.digits = DGLDigitCapsuleLayer(device=device)

    def forward(self, x):
        """Run ``x`` through conv1, the primary capsules and the digit capsules."""
        out_conv1 = self.conv1(x)
        out_primary_caps = self.primary(out_conv1)
        out_digit_caps = self.digits(out_primary_caps)
        return out_digit_caps

    def margin_loss(self, input, target):
        """Compute the CapsNet margin loss averaged over the batch.

        For each class c with capsule length v_c and one-hot target t_c::

            L_c = t_c * max(0, m_plus - v_c)**2
                  + loss_lambda * (1 - t_c) * max(0, v_c - m_minus)**2

        with m_plus = 0.9, m_minus = 0.1 and loss_lambda = 0.5.

        Args:
            input: capsule outputs. Dim 0 is the batch; dim 2 is summed over
                to obtain each capsule's length, and whatever remains must
                flatten to (batch, num_classes).
            target: integer class index per sample, shape (batch,).

        Returns:
            Scalar tensor: mean over the batch of ``sum_c L_c``.
        """
        m_plus = 0.9
        m_minus = 0.1
        loss_lambda = 0.5

        batch_size = input.size(0)
        # Capsule lengths ||v_c||, flattened to (batch, num_classes).
        # NOTE(review): sqrt has an infinite derivative at 0, so an exactly
        # zero capsule vector would produce NaN gradients here.
        lengths = torch.sqrt((input ** 2).sum(dim=2, keepdim=True))
        lengths = lengths.view(batch_size, -1)

        # One-hot targets in one vectorised scatter, on the same device/dtype
        # as the lengths; the number of classes comes from the input itself.
        index = target.reshape(-1, 1).to(device=lengths.device, dtype=torch.long)
        t_c = torch.zeros_like(lengths).scatter_(1, index, 1.0)

        # clamp(min=0) is the elementwise max(., 0) from the formula above.
        max_left = torch.clamp(m_plus - lengths, min=0) ** 2
        max_right = torch.clamp(lengths - m_minus, min=0) ** 2
        l_c = t_c * max_left + loss_lambda * (1.0 - t_c) * max_right
        return l_c.sum(dim=1).mean()


class PrimaryCapsuleLayer(nn.Module):
    """Primary capsule layer: ``num_unit`` parallel convolutions, stacked and squashed.

    Args:
        in_channel: number of channels in the incoming feature map.
        num_unit: number of parallel Conv2d units.
        device: stored on the instance only; this layer does not use it.
        out_channels: output channels of each unit's Conv2d (default 32).
        kernel_size: kernel size of each unit's Conv2d (default 9).
        stride: stride of each unit's Conv2d (default 2).
    """

    def __init__(self, in_channel=256, num_unit=8, device='cpu',
                 out_channels=32, kernel_size=9, stride=2):
        super(PrimaryCapsuleLayer, self).__init__()
        self.in_channel = in_channel
        self.num_unit = num_unit
        self.device = device
        # Legacy misspelled attribute, kept so existing readers of it keep working.
        self.deivce = device
        self.conv_units = nn.ModuleList([
            nn.Conv2d(self.in_channel, out_channels, kernel_size, stride)
            for _ in range(self.num_unit)
        ])

    def forward(self, x):
        """Apply every conv unit to ``x``, stack the results and squash them.

        The stacked outputs are reshaped to (batch, num_unit, -1) and passed
        to ``squash`` along dim 2.
        """
        unit = torch.stack([conv(x) for conv in self.conv_units], dim=1)
        batch_size = x.size(0)
        # Use the configured unit count (was hard-coded to 8).
        unit = unit.view(batch_size, self.num_unit, -1)
        # NOTE(review): dim 2 is the flattened out_channels*H*W axis. If squash
        # normalises along `dim`, each unit's whole feature map is squashed as a
        # single vector rather than each num_unit-dimensional capsule — confirm
        # this is intended.
        return squash(unit, dim=2)