capsule_dgl.py 1.13 KB
Newer Older
Allen Zhou's avatar
capsule  
Allen Zhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import dgl
import networkx as nx
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F

def capsule_message(src, edge):
    """Build the message sent along one edge.

    The message carries the source's ``'ft'`` entry unchanged and the
    edge's ``'b'`` entry re-keyed as ``'bij'``.

    Args:
        src: Mapping of source-node data; must contain ``'ft'``.
        edge: Mapping of edge data; must contain ``'b'``.

    Returns:
        A dict with keys ``'ft'`` and ``'bij'``.
    """
    message = {}
    message['ft'] = src['ft']
    message['bij'] = edge['b']
    return message

class GATReduce(nn.Module):
    """Reduce the messages arriving at a node into a single feature vector.

    The node's ``'a'`` entry is turned into a weight via softmax (with
    optional dropout) and used to scale the stacked message features, which
    are then summed over the message axis.

    Args:
        attn_drop: Dropout probability applied to the attention weight.
            ``0.0`` disables dropout.
    """

    def __init__(self, attn_drop):
        super(GATReduce, self).__init__()
        self.attn_drop = attn_drop

    def forward(self, node, msgs):
        """Aggregate ``msgs`` for ``node``.

        Args:
            node: Mapping of node data; must contain the tensor ``'a'``.
            msgs: Non-empty sequence of mappings, each containing a tensor
                ``'ft'``. All ``'ft'`` tensors must share the same shape.

        Returns:
            The weighted sum of the message features over the message axis.
        """
        a = torch.unsqueeze(node['a'], 0)  # shape (1, 1)
        # Stack along a new leading axis; same result as cat-ing the
        # individually unsqueezed tensors.
        ft = torch.stack([m['ft'] for m in msgs], dim=0)  # shape (deg, D)
        # attention
        # NOTE(review): if `a` really is (1, 1), softmax over dim 0 is always
        # 1 -- confirm whether per-message scores were intended here.
        e = F.softmax(a, dim=0)
        if self.attn_drop != 0.0:
            # F.dropout defaults to training=True; bind it to the module's
            # mode so that eval()/inference is deterministic.
            e = F.dropout(e, self.attn_drop, training=self.training)
        return torch.sum(e * ft, dim=0)  # shape (D,)

class Capsule(nn.Module):
    """Module that owns a DGL graph built from a 10x10 all-ones matrix.

    ``forward`` currently only gathers its inputs into tensors and does not
    return a value.
    """

    def __init__(self):
        super(Capsule, self).__init__()
        # The all-ones matrix is handed to networkx as an adjacency matrix,
        # and the resulting graph is wrapped in a DGLGraph.
        adjacency = np.ones((10, 10))
        nx_graph = nx.from_numpy_matrix(adjacency)
        self.g = dgl.DGLGraph(nx_graph)

    def forward(self, node, msgs):
        """Gather the node's ``'a1'`` and the messages' ``'a2'`` / ``'ft'``.

        Args:
            node: Mapping of node data; must contain the tensor ``'a1'``.
            msgs: Non-empty, re-iterable sequence of mappings, each holding
                tensors under ``'a2'`` and ``'ft'``.

        Returns:
            None. The gathered tensors are not used further yet.
        """
        a1 = node['a1'].unsqueeze(0)  # shape (1, 1)

        # One leading axis per message so they concatenate along dim 0.
        a2_rows = [m['a2'].unsqueeze(0) for m in msgs]
        a2 = torch.cat(a2_rows, dim=0)  # shape (deg, 1)

        ft_rows = [m['ft'].unsqueeze(0) for m in msgs]
        ft = torch.cat(ft_rows, dim=0)  # shape (deg, D)