import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.loss import EntropyLoss
from ..model_utils import masked_softmax
from .aggregator import LSTMAggregator, MaxPoolAggregator, MeanAggregator
from .bundler import Bundler


class GraphSageLayer(nn.Module):
    """
    GraphSage layer in Inductive learning paper by hamilton
    Here, graphsage layer is a reduced function in DGL framework
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        activation,
        dropout,
        aggregator_type,
        bn=False,
        bias=True,
    ):
        super(GraphSageLayer, self).__init__()
        self.use_bn = bn
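        # Bundler concatenates a node's own features with the aggregated
        # neighbor features and applies a linear layer plus activation.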
        self.bundler = Bundler(
            in_feats, out_feats, activation, dropout, bias=bias
        )
        self.dropout = nn.Dropout(p=dropout)

        if aggregator_type == "maxpool":
            self.aggregator = MaxPoolAggregator(
                in_feats, in_feats, activation, bias
            )
        elif aggregator_type == "lstm":
            self.aggregator = LSTMAggregator(in_feats, in_feats)
        else:
            self.aggregator = MeanAggregator()

    def forward(self, g, h):
        h = self.dropout(h)
        if self.use_bn and not hasattr(self, "bn"):
            # Lazily create BatchNorm on the first forward pass, once the
            # input feature width and device are known.
            self.bn = nn.BatchNorm1d(h.size(1)).to(h.device)
        if self.use_bn:
            h = self.bn(h)
        g.ndata["h"] = h
        # One round of message passing: copy each node's features onto its
        # out-edges, aggregate them at the destination, then bundle the
        # result with the destination's own features.
        g.update_all(fn.copy_u(u="h", out="m"), self.aggregator, self.bundler)
        h = g.ndata.pop("h")
        return h
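
# A minimal usage sketch (hypothetical shapes; assumes `dgl` is importable):
#
#     import dgl
#     g = dgl.rand_graph(10, 40)                # 10 nodes, 40 random edges
#     layer = GraphSageLayer(16, 32, F.relu, 0.1, "mean")
#     h = layer(g, torch.randn(10, 16))         # -> shape (10, 32)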


class GraphSage(nn.Module):
    """
    Grahpsage network that concatenate several graphsage layer
    """

    def __init__(
        self,
        in_feats,
        n_hidden,
        n_classes,
        n_layers,
        activation,
        dropout,
        aggregator_type,
    ):
        super(GraphSage, self).__init__()
        self.layers = nn.ModuleList()

        # input layer
        self.layers.append(
            GraphSageLayer(
                in_feats, n_hidden, activation, dropout, aggregator_type
            )
        )
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(
                GraphSageLayer(
                    n_hidden, n_hidden, activation, dropout, aggregator_type
                )
            )
        # output layer
        self.layers.append(
            GraphSageLayer(n_hidden, n_classes, None, dropout, aggregator_type)
        )

    def forward(self, g, features):
        h = features
        for layer in self.layers:
            h = layer(g, h)
        return h
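
# Usage sketch (hypothetical dimensions): a small GraphSAGE classifier.
#
#     net = GraphSage(in_feats=16, n_hidden=32, n_classes=7, n_layers=2,
#                     activation=F.relu, dropout=0.1, aggregator_type="mean")
#     logits = net(g, feats)                    # -> (num_nodes, 7)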


class DiffPoolBatchedGraphLayer(nn.Module):
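    """
    DiffPool layer for a batched DGL graph, after "Hierarchical Graph
    Representation Learning with Differentiable Pooling" (Ying et al., 2018).
    One GraphSAGE layer embeds the node features while a second predicts a
    soft cluster assignment that pools each graph to `assign_dim` clusters.
    """
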
    def __init__(
        self,
        input_dim,
        assign_dim,
        output_feat_dim,
        activation,
        dropout,
        aggregator_type,
        link_pred,
    ):
        super(DiffPoolBatchedGraphLayer, self).__init__()
        self.embedding_dim = input_dim
        self.assign_dim = assign_dim
        self.hidden_dim = output_feat_dim
        self.link_pred = link_pred
        self.feat_gc = GraphSageLayer(
            input_dim, output_feat_dim, activation, dropout, aggregator_type
        )
        self.pool_gc = GraphSageLayer(
            input_dim, assign_dim, activation, dropout, aggregator_type
        )
        self.reg_loss = nn.ModuleList([])
        self.loss_log = {}
        self.reg_loss.append(EntropyLoss())

    def forward(self, g, h):
        feat = self.feat_gc(
            g, h
        )  # size = (sum_N, F_out), sum_N is num of nodes in this batch
        device = feat.device
        assign_tensor = self.pool_gc(
            g, h
        )  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
        assign_tensor = F.softmax(assign_tensor, dim=1)
        assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
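        # Stitch the per-graph assignment matrices into one block-diagonal
        # matrix so nodes are only assigned to clusters of their own graph.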
        assign_tensor = torch.block_diag(
            *assign_tensor
        )  # size = (sum_N, batch_size * N_a)

        h = torch.matmul(torch.t(assign_tensor), feat)
        adj = g.adjacency_matrix(transpose=True, ctx=device)
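        # Coarsen the graph: adj_new = S^T @ A @ S, where S is the soft
        # assignment matrix built above.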
        adj_new = torch.sparse.mm(adj, assign_tensor)
        adj_new = torch.mm(torch.t(assign_tensor), adj_new)

        if self.link_pred:
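            # Auxiliary link-prediction loss ||A - S S^T||_F / N^2, which
            # rewards assignments that keep connected nodes together.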
            current_lp_loss = torch.norm(
                adj.to_dense() - torch.mm(assign_tensor, torch.t(assign_tensor))
            ) / np.power(g.number_of_nodes(), 2)
            self.loss_log["LinkPredLoss"] = current_lp_loss

        for loss_layer in self.reg_loss:
            loss_name = type(loss_layer).__name__
            self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)

        return adj_new, h
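
# Usage sketch (hypothetical sizes; `bg` is a batched DGL graph with node
# features `feat` of width 32):
#
#     pool = DiffPoolBatchedGraphLayer(
#         input_dim=32, assign_dim=8, output_feat_dim=64, activation=F.relu,
#         dropout=0.1, aggregator_type="mean", link_pred=True,
#     )
#     adj_pooled, h_pooled = pool(bg, feat)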