"vscode:/vscode.git/clone" did not exist on "3bde101099697dfd24eb95609b52ec9f9630e389"
capsule_model.py 1.75 KB
Newer Older
Allen Zhou's avatar
capsule  
Allen Zhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import dgl
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from original import CapsuleLayer


class DGLCapsuleLayer(CapsuleLayer):
    """Capsule layer whose routing step is expressed as DGL message passing.

    On top of ``CapsuleLayer`` this class owns a DGL graph built from a
    10x10 all-ones adjacency matrix (``self.g``) and a learnable projection
    tensor ``self.W``.  ``routing`` projects its input with ``W``, loads the
    result onto the graph and runs ``update_all`` three times with
    ``capsule_msg`` / ``capsule_reduce``.

    NOTE(review): the class reads like work in progress; several lines are
    flagged with review notes below and should be confirmed before relying
    on the numerical output.
    """

    def __init__(self, in_units, in_channels, num_units, unit_size, use_routing=True):
        """Initialise the parent layer, the routing graph and ``W``.

        Args:
            in_units: forwarded to ``CapsuleLayer``; also the last dimension of ``W``.
            in_channels: forwarded to ``CapsuleLayer``; also dimension 1 of ``W``.
            num_units: forwarded to ``CapsuleLayer``; also dimension 2 of ``W``.
            unit_size: forwarded to ``CapsuleLayer``; also dimension 3 of ``W``.
            use_routing: forwarded to ``CapsuleLayer``.
        """
        # Forward the caller's choice; it used to be hard-coded to True, which
        # silently ignored an explicit ``use_routing=False``.
        super(DGLCapsuleLayer, self).__init__(
            in_units, in_channels, num_units, unit_size, use_routing=use_routing)
        # Graph built from a 10x10 all-ones adjacency matrix.
        # NOTE(review): the size 10 is hard-coded; presumably it should track
        # num_units -- confirm.
        self.g = dgl.DGLGraph(nx.from_numpy_matrix(np.ones((10, 10))))
        self.W = nn.Parameter(torch.randn(1, in_channels, num_units, unit_size, in_units))

    def routing(self, x):
        """Project ``x`` with ``W`` and run three rounds of message passing.

        Side effects: stores the projection in ``self.u_hat``, a detached and
        axis-permuted copy in ``self.node_feature`` (set as node feature
        ``'ft'``), and the edge features in ``self.edge_features`` (set as
        edge feature ``'b_ij'``).

        Args:
            x: input tensor whose first dimension is the batch size.
                Assumes a 3-D (batch, in_units, in_channels) layout so that
                the matmul with ``W`` lines up -- TODO confirm against caller.

        Returns:
            None; results are kept on ``self`` and on the graph.
        """
        batch_size = x.size(0)
        x = x.transpose(1, 2)
        # Replicate the input once per output unit (new dim 2) and append a
        # trailing singleton dim so it can be right-multiplied by W.
        x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)
        # Repeat W along the batch dimension as a broadcast view rather than
        # concatenating batch_size copies.
        W = self.W.expand(batch_size, -1, -1, -1, -1)
        u_hat = torch.matmul(W, x)
        self.u_hat = u_hat
        # Detached copy with the original dim 2 (the per-unit axis created by
        # the stack above) moved to the front, followed by the old dims 0 and 1.
        self.node_feature = u_hat.clone().detach().transpose(0, 2).transpose(1, 2)
        self.g.set_n_repr({'ft': self.node_feature})

        # One zero-initialised scalar per edge.
        # NOTE(review): 100 is presumably 10 * 10 edges of the fully connected
        # graph built in __init__ -- confirm.
        self.edge_features = torch.zeros(100, 1)
        self.g.set_e_repr({'b_ij': self.edge_features})

        # Three identical message-passing rounds; the apply function is the
        # identity.
        for _ in range(3):
            self.g.update_all(self.capsule_msg, self.capsule_reduce, lambda x: x)

        # NOTE(review): torch.dot only accepts 1-D tensors, while u_hat is the
        # 5-D matmul result, so this line is expected to raise as written.
        # The intended update needs to be confirmed before it is rewritten.
        self.edge_features = self.edge_features + torch.dot(self.u_hat, self.node_feature)

    @staticmethod
    def capsule_msg(src, edge):
        """Build the message sent along an edge.

        Args:
            src: source-node representation; its ``'ft'`` entry is read.
            edge: edge representation; its ``'weight'`` entry is read.

        Returns:
            dict with keys ``'b_ij'`` (the edge weight) and ``'h'`` (the
            source feature).
        """
        # NOTE(review): routing() stores the edge feature under 'b_ij', not
        # 'weight'; verify that 'weight' is really available on the edges.
        return {'b_ij': edge['weight'], 'h': src['ft']}

    def capsule_reduce(self, node, msg):
        """Combine incoming messages for a node.

        Args:
            node: destination-node representation (not used here).
            msg: incoming messages, unpacked into ``b_ij`` and ``h``.

        Returns:
            None.
        """
        # NOTE(review): capsule_msg returns a dict; unpacking a dict yields its
        # keys, so this presumably relies on the framework re-packing the
        # messages -- confirm the expected structure of msg.
        b_ij, h = msg
        # NOTE(review): h_c is computed but never used below.
        b_ij_c, h_c = torch.cat(b_ij, dim=1), torch.cat(h, dim=1)
        c_i = F.softmax(b_ij_c, dim=1)
        # NOTE(review): torch.dot only accepts 1-D tensors; see routing().
        s_j = torch.dot(c_i, self.u_hat)
        # NOTE(review): the squashed result is not returned, so the reduce step
        # currently yields None.
        v_j = self.squash(s_j)