lander.py 6.37 KB
Newer Older
1
2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
3
4
import dgl
import dgl.function as fn
5
6
7
8
9
10
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .focal_loss import FocalLoss
11
12
from .graphconv import GraphConv

13
14

class LANDER(nn.Module):
    """GNN that scores edge connectivity and node density for
    clustering (the LANDER architecture this file is named after)."""

    def __init__(
        self,
        feature_dim,
        nhid,
        num_conv=4,
        dropout=0,
        use_GAT=True,
        K=1,
        balance=False,
        use_cluster_feat=True,
        use_focal_loss=True,
        **kwargs
    ):
        """Build the conv stack, the edge classifier head and the losses.

        feature_dim: width of per-node input features.
        nhid: hidden width of the first conv layers (later layers use nhid/2).
        num_conv: number of GraphConv layers (the dim tables cover up to 4).
        balance: if True, compute_loss subsamples the majority class.
        use_cluster_feat: if True, inputs are [features || cluster_features].
        """
        super(LANDER, self).__init__()
        nhid_half = int(nhid / 2)
        self.use_cluster_feat = use_cluster_feat
        self.use_focal_loss = use_focal_loss

        # Concatenating cluster features doubles the effective input width.
        if self.use_cluster_feat:
            self.feature_dim = feature_dim * 2
        else:
            self.feature_dim = feature_dim

        # (in, out) widths for layers 1..num_conv-1; layer 0 is special-cased
        # below because its input width depends on use_cluster_feat.
        in_dims = (feature_dim, nhid, nhid, nhid_half)
        out_dims = (nhid, nhid, nhid_half, nhid_half)
        self.conv = nn.ModuleList(
            [GraphConv(self.feature_dim, nhid, dropout, use_GAT, K)]
        )
        for layer in range(1, num_conv):
            self.conv.append(
                GraphConv(in_dims[layer], out_dims[layer], dropout, use_GAT, K)
            )

        # Separate projections for the two endpoints of an edge.
        last_width = out_dims[num_conv - 1]
        self.src_mlp = nn.Linear(last_width, nhid_half)
        self.dst_mlp = nn.Linear(last_width, nhid_half)

        # Two logits per edge: (no-link, link).
        self.classifier_conn = nn.Sequential(
            nn.PReLU(nhid_half),
            nn.Linear(nhid_half, nhid_half),
            nn.PReLU(nhid_half),
            nn.Linear(nhid_half, 2),
        )

        # Focal loss is an option for the (typically imbalanced) edge labels.
        if self.use_focal_loss:
            self.loss_conn = FocalLoss(2)
        else:
            self.loss_conn = nn.CrossEntropyLoss()
        self.loss_den = nn.MSELoss()

        self.balance = balance

    def pred_conn(self, edges):
66
67
        src_feat = self.src_mlp(edges.src["conv_features"])
        dst_feat = self.dst_mlp(edges.dst["conv_features"])
68
        pred_conn = self.classifier_conn(src_feat + dst_feat)
69
        return {"pred_conn": pred_conn}
70
71

    def pred_den_msg(self, edges):
72
73
74
        prob = edges.data["prob_conn"]
        res = edges.data["raw_affine"] * (prob[:, 1] - prob[:, 0])
        return {"pred_den_msg": res}
75
76
77
78
79

    def forward(self, bipartites):
        """Run the conv stack, then predict edge connectivity and node density.

        ``bipartites`` is either a single DGLGraph (full-graph mode) or a
        sequence of bipartite blocks (mini-batch mode). In mini-batch mode the
        loop reads blocks ``i`` and ``i + 1``, so the sequence must hold
        ``len(self.conv) + 1`` blocks — presumably produced by a neighbor
        sampler; TODO confirm against the caller.

        Returns the output graph with ``pred_conn`` (logits) and ``prob_conn``
        (softmax) on edges and ``pred_den`` on (destination) nodes.
        """
        if isinstance(bipartites, dgl.DGLGraph):
            # Full-graph mode: reuse the same graph for every conv layer.
            bipartites = [bipartites] * len(self.conv)
            if self.use_cluster_feat:
                # Input is [features || cluster_features]; __init__ doubled
                # feature_dim for exactly this case.
                neighbor_x = torch.cat(
                    [
                        bipartites[0].ndata["features"],
                        bipartites[0].ndata["cluster_features"],
                    ],
                    axis=1,
                )
            else:
                neighbor_x = bipartites[0].ndata["features"]

            for i in range(len(self.conv)):
                neighbor_x = self.conv[i](bipartites[i], neighbor_x)

            output_bipartite = bipartites[-1]
            output_bipartite.ndata["conv_features"] = neighbor_x
        else:
            # Mini-batch mode: maintain two feature streams, one for the
            # neighbor frontier (block i) and one for the center nodes
            # (block i + 1), so the output graph gets both src and dst
            # embeddings.
            if self.use_cluster_feat:
                neighbor_x_src = torch.cat(
                    [
                        bipartites[0].srcdata["features"],
                        bipartites[0].srcdata["cluster_features"],
                    ],
                    axis=1,
                )
                center_x_src = torch.cat(
                    [
                        bipartites[1].srcdata["features"],
                        bipartites[1].srcdata["cluster_features"],
                    ],
                    axis=1,
                )
            else:
                neighbor_x_src = bipartites[0].srcdata["features"]
                center_x_src = bipartites[1].srcdata["features"]

            for i in range(len(self.conv)):
                # DGL block convention: the first num_dst_nodes() rows of the
                # src features are the destination nodes themselves.
                neighbor_x_dst = neighbor_x_src[: bipartites[i].num_dst_nodes()]
                neighbor_x_src = self.conv[i](
                    bipartites[i], (neighbor_x_src, neighbor_x_dst)
                )
                center_x_dst = center_x_src[: bipartites[i + 1].num_dst_nodes()]
                center_x_src = self.conv[i](
                    bipartites[i + 1], (center_x_src, center_x_dst)
                )

            output_bipartite = bipartites[-1]
            output_bipartite.srcdata["conv_features"] = neighbor_x_src
            output_bipartite.dstdata["conv_features"] = center_x_src

        # Edge logits from the classifier head, then softmax to probabilities.
        output_bipartite.apply_edges(self.pred_conn)
        output_bipartite.edata["prob_conn"] = F.softmax(
            output_bipartite.edata["pred_conn"], dim=1
        )
        # Node density = mean over incoming edges of the signed, affinity-
        # weighted confidence (see pred_den_msg).
        output_bipartite.update_all(
            self.pred_den_msg, fn.mean("pred_den_msg", "pred_den")
        )
        return output_bipartite

    def compute_loss(self, bipartite):
139
140
        pred_den = bipartite.dstdata["pred_den"]
        loss_den = self.loss_den(pred_den, bipartite.dstdata["density"])
141

142
143
        labels_conn = bipartite.edata["labels_conn"]
        mask_conn = bipartite.edata["mask_conn"]
144
145

        if self.balance:
146
147
148
149
            labels_conn = bipartite.edata["labels_conn"]
            neg_check = torch.logical_and(
                bipartite.edata["labels_conn"] == 0, mask_conn
            )
150
151
            num_neg = torch.sum(neg_check).item()
            neg_indices = torch.where(neg_check)[0]
152
153
154
            pos_check = torch.logical_and(
                bipartite.edata["labels_conn"] == 1, mask_conn
            )
155
156
157
            num_pos = torch.sum(pos_check).item()
            pos_indices = torch.where(pos_check)[0]
            if num_pos > num_neg:
158
159
160
161
162
163
164
                mask_conn[
                    pos_indices[
                        np.random.choice(
                            num_pos, num_pos - num_neg, replace=False
                        )
                    ]
                ] = 0
165
            elif num_pos < num_neg:
166
167
168
169
170
171
172
                mask_conn[
                    neg_indices[
                        np.random.choice(
                            num_neg, num_neg - num_pos, replace=False
                        )
                    ]
                ] = 0
173
174
175

        # In subgraph training, it may happen that all edges are masked in a batch
        if mask_conn.sum() > 0:
176
177
178
            loss_conn = self.loss_conn(
                bipartite.edata["pred_conn"][mask_conn], labels_conn[mask_conn]
            )
179
180
181
182
183
184
185
186
187
            loss = loss_den + loss_conn
            loss_den_val = loss_den.item()
            loss_conn_val = loss_conn.item()
        else:
            loss = loss_den
            loss_den_val = loss_den.item()
            loss_conn_val = 0

        return loss, loss_den_val, loss_conn_val