# bundler.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bundler(nn.Module):
    """Node-update ("bundling") module, used as the node_apply function in DGL.

    Concatenates a node's current representation ``h`` with its aggregated
    representation ``c``, projects the result with a linear layer,
    L2-normalizes each row, and optionally applies an activation.

    Args:
        in_feats: Feature size of each of ``h`` and ``c``; the linear layer
            therefore takes ``2 * in_feats`` inputs.
        out_feats: Output feature size.
        activation: Callable applied after normalization, or a falsy value
            (e.g. ``None``) to skip it.
        dropout: Dropout probability used to build ``self.dropout``.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(self, in_feats, out_feats, activation, dropout, bias=True):
        super().__init__()
        # NOTE(review): self.dropout is constructed but never applied inside
        # this class; presumably the caller applies it — confirm.
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(in_feats * 2, out_feats, bias)
        self.activation = activation

        # Xavier init scaled for ReLU, independent of the activation passed in.
        nn.init.xavier_uniform_(
            self.linear.weight, gain=nn.init.calculate_gain("relu")
        )

    def concat(self, h, aggre_result):
        """Concatenate ``h`` and ``aggre_result`` along dim 1, then project.

        The concatenated feature size along dim 1 must equal ``2 * in_feats``
        to match the linear layer.
        """
        bundle = torch.cat((h, aggre_result), 1)
        return self.linear(bundle)

    def forward(self, node):
        """Compute the updated node representation.

        Args:
            node: Node batch supplied by DGL; ``node.data["h"]`` and
                ``node.data["c"]`` are read.

        Returns:
            dict with the single key ``"h"`` holding the new representation.
        """
        h = node.data["h"]
        c = node.data["c"]
        bundle = self.concat(h, c)
        # Row-wise L2 normalization (dim=1) happens before the activation.
        bundle = F.normalize(bundle, p=2, dim=1)
        if self.activation:
            bundle = self.activation(bundle)
        return {"h": bundle}