# bundler.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bundler(nn.Module):
    """
    Bundler, which will be the node_apply function in DGL paradigm.

    For each node it concatenates the node's own feature ``h`` with the
    feature ``c`` stored next to it, projects the result through a linear
    layer, L2-normalizes every row and optionally applies an activation.
    """

    def __init__(self, in_feats, out_feats, activation, dropout, bias=True):
        """
        Build the projection layer.

        Args:
            in_feats: Width of each of the two inputs that get concatenated;
                the linear layer therefore takes ``in_feats * 2`` inputs.
            out_feats: Width of the projected output.
            activation: Callable applied to the normalized output, or
                ``None`` to skip the activation step.
            dropout: Probability handed to ``nn.Dropout``.
            bias: Whether the linear layer learns an additive bias.
        """
        super().__init__()
        # NOTE(review): this dropout module is constructed but never applied
        # by any method of this class -- confirm whether that is intended.
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(in_feats * 2, out_feats, bias=bias)
        self.activation = activation

        # Xavier-uniform initialisation scaled with the gain for ReLU.
        nn.init.xavier_uniform_(self.linear.weight,
                                gain=nn.init.calculate_gain('relu'))

    def concat(self, h, aggre_result):
        """
        Concatenate ``h`` and ``aggre_result`` along dim 1 and project them.

        Args:
            h: Tensor of node features.
            aggre_result: Tensor to be joined with ``h``; must match ``h``
                on every dimension except dim 1.

        Returns:
            The output of the linear layer applied to the concatenation.
        """
        bundle = torch.cat((h, aggre_result), 1)
        return self.linear(bundle)

    def forward(self, node):
        """
        Compute the updated ``h`` feature for a batch of nodes.

        Args:
            node: Object exposing a ``data`` mapping with the keys ``'h'``
                and ``'c'``.

        Returns:
            A dict ``{"h": tensor}`` holding the projected, row-wise
            L2-normalized (and optionally activated) features.
        """
        h = node.data['h']
        c = node.data['c']
        bundle = self.concat(h, c)
        bundle = F.normalize(bundle, p=2, dim=1)
        # Compare against None explicitly: a truthiness test would silently
        # skip callables that happen to be falsy (e.g. an empty container
        # module whose __len__ returns 0).
        if self.activation is not None:
            bundle = self.activation(bundle)
        return {"h": bundle}