#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

from .focal_loss import FocalLoss
from .graphconv import GraphConv


class LANDER(nn.Module):
    """Graph network that predicts edge connectivity and node density.

    A stack of ``GraphConv`` layers produces ``conv_features`` for the nodes of
    a DGL graph (or of a list of sampled blocks). From those features:

    * every edge of the output graph gets two-class connectivity logits
      (``pred_conn``) and their softmax (``prob_conn``);
    * every destination node gets ``pred_den``, the mean over its incoming
      edges of ``raw_affine * (prob_conn[:, 1] - prob_conn[:, 0])``.

    Args:
        feature_dim: Width of the ``"features"`` node field.
        nhid: Hidden width of the first conv layers. Later layers and the edge
            head use ``int(nhid / 2)``.
        num_conv: Number of conv layers, between 1 and 4 inclusive.
        dropout, use_GAT, K: Forwarded unchanged to every ``GraphConv``.
        balance: If True, ``compute_loss`` randomly undersamples the majority
            connectivity label among masked edges so that both labels are
            equally represented.
        use_cluster_feat: If True, ``"cluster_features"`` are concatenated to
            ``"features"`` as the first layer's input. They must then have
            width ``feature_dim``, since layer 0 expects ``2 * feature_dim``.
        use_focal_loss: Use ``FocalLoss(2)`` for the connectivity loss instead
            of ``nn.CrossEntropyLoss``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``num_conv`` is outside ``[1, 4]``.
    """

    def __init__(
        self,
        feature_dim,
        nhid,
        num_conv=4,
        dropout=0,
        use_GAT=True,
        K=1,
        balance=False,
        use_cluster_feat=True,
        use_focal_loss=True,
        **kwargs
    ):
        super().__init__()
        nhid_half = int(nhid / 2)
        self.use_cluster_feat = use_cluster_feat
        self.use_focal_loss = use_focal_loss
        self.balance = balance

        # Cluster features are concatenated onto the node features, doubling
        # the width seen by the first layer.
        if self.use_cluster_feat:
            self.feature_dim = feature_dim * 2
        else:
            self.feature_dim = feature_dim

        # Per-layer widths. Entry 0 of input_dim is never read: layer 0 takes
        # self.feature_dim instead.
        input_dim = (feature_dim, nhid, nhid, nhid_half)
        output_dim = (nhid, nhid, nhid_half, nhid_half)
        if not 1 <= num_conv <= len(output_dim):
            raise ValueError(
                f"num_conv must be between 1 and {len(output_dim)}, got {num_conv}"
            )

        # Submodules are created in a fixed order (convs, MLPs, classifier) so
        # parameter initialisation and state_dict layout stay stable.
        self.conv = nn.ModuleList()
        self.conv.append(GraphConv(self.feature_dim, nhid, dropout, use_GAT, K))
        for i in range(1, num_conv):
            self.conv.append(
                GraphConv(input_dim[i], output_dim[i], dropout, use_GAT, K)
            )

        last_dim = output_dim[num_conv - 1]
        self.src_mlp = nn.Linear(last_dim, nhid_half)
        self.dst_mlp = nn.Linear(last_dim, nhid_half)

        self.classifier_conn = nn.Sequential(
            nn.PReLU(nhid_half),
            nn.Linear(nhid_half, nhid_half),
            nn.PReLU(nhid_half),
            nn.Linear(nhid_half, 2),
        )

        if self.use_focal_loss:
            self.loss_conn = FocalLoss(2)
        else:
            self.loss_conn = nn.CrossEntropyLoss()
        self.loss_den = nn.MSELoss()

    def _input_features(self, data):
        """Return the first-layer input taken from a node-data view.

        ``data`` is a DGL ``ndata``/``srcdata`` mapping. When
        ``use_cluster_feat`` is set, ``"cluster_features"`` are appended to
        ``"features"`` along dim 1.
        """
        if self.use_cluster_feat:
            return torch.cat([data["features"], data["cluster_features"]], dim=1)
        return data["features"]

    def pred_conn(self, edges):
        """Edge UDF: two-class connectivity logits from both endpoints.

        Source and destination ``conv_features`` are projected by separate
        linear layers, summed, and fed to ``classifier_conn``.

        Returns:
            ``{"pred_conn": logits}`` with 2 logits per edge.
        """
        src_feat = self.src_mlp(edges.src["conv_features"])
        dst_feat = self.dst_mlp(edges.dst["conv_features"])
        return {"pred_conn": self.classifier_conn(src_feat + dst_feat)}

    def pred_den_msg(self, edges):
        """Message UDF: ``raw_affine`` scaled by the connectivity margin.

        The margin ``prob_conn[:, 1] - prob_conn[:, 0]`` is positive when an
        edge is judged more likely connected than not.
        NOTE(review): assumes ``raw_affine`` is a 1-D per-edge tensor; a
        trailing singleton dimension would broadcast to (E, E). Confirm.
        """
        prob = edges.data["prob_conn"]
        return {"pred_den_msg": edges.data["raw_affine"] * (prob[:, 1] - prob[:, 0])}

    def forward(self, bipartites):
        """Run the conv stack and attach edge/node predictions to the output graph.

        Args:
            bipartites: Either a single ``dgl.DGLGraph`` (full-graph mode, in
                which every conv layer runs on that graph) or a sequence of
                sampled blocks (mini-batch mode). In mini-batch mode, layer
                ``i`` updates the neighbour stream on block ``i`` and the
                centre stream on block ``i + 1``, so ``len(self.conv) + 1``
                blocks are indexed.
                NOTE(review): the output graph is ``bipartites[-1]``, which
                presumably requires exactly ``len(self.conv) + 1`` blocks.
                Confirm with callers.

        Returns:
            The output graph, modified in place. It gains ``conv_features``
            node data, ``pred_conn`` (logits) and ``prob_conn`` (softmax)
            edge data, and ``pred_den`` on destination nodes (the mean of
            incoming ``pred_den_msg`` messages).
        """
        if isinstance(bipartites, dgl.DGLGraph):
            # Full-graph mode: the same graph is used by every layer.
            graph = bipartites
            neighbor_x = self._input_features(graph.ndata)
            for conv in self.conv:
                neighbor_x = conv(graph, neighbor_x)
            output_bipartite = graph
            output_bipartite.ndata["conv_features"] = neighbor_x
        else:
            neighbor_x_src = self._input_features(bipartites[0].srcdata)
            center_x_src = self._input_features(bipartites[1].srcdata)
            for i, conv in enumerate(self.conv):
                # Relies on DGL's block convention that destination nodes are
                # the leading prefix of the source nodes.
                neighbor_block = bipartites[i]
                neighbor_x_src = conv(
                    neighbor_block,
                    (neighbor_x_src, neighbor_x_src[: neighbor_block.num_dst_nodes()]),
                )
                center_block = bipartites[i + 1]
                center_x_src = conv(
                    center_block,
                    (center_x_src, center_x_src[: center_block.num_dst_nodes()]),
                )
            output_bipartite = bipartites[-1]
            output_bipartite.srcdata["conv_features"] = neighbor_x_src
            output_bipartite.dstdata["conv_features"] = center_x_src

        output_bipartite.apply_edges(self.pred_conn)
        output_bipartite.edata["prob_conn"] = F.softmax(
            output_bipartite.edata["pred_conn"], dim=1
        )
        output_bipartite.update_all(
            self.pred_den_msg, fn.mean("pred_den_msg", "pred_den")
        )
        return output_bipartite

    @staticmethod
    def _balance_mask(mask_conn, labels_conn):
        """Return a copy of ``mask_conn`` with the majority label undersampled.

        Among the edges selected by ``mask_conn``, label-1 and label-0 edges
        are counted. Randomly chosen edges of the larger group are deselected
        until both counts match. The input mask is never modified.
        """
        mask_conn = mask_conn.clone()
        pos_indices = torch.where(torch.logical_and(labels_conn == 1, mask_conn))[0]
        neg_indices = torch.where(torch.logical_and(labels_conn == 0, mask_conn))[0]
        num_pos, num_neg = len(pos_indices), len(neg_indices)

        if num_pos > num_neg:
            majority, num_major, num_drop = pos_indices, num_pos, num_pos - num_neg
        elif num_neg > num_pos:
            majority, num_major, num_drop = neg_indices, num_neg, num_neg - num_pos
        else:
            return mask_conn

        chosen = np.random.choice(num_major, num_drop, replace=False)
        mask_conn[majority[chosen]] = False
        return mask_conn

    def compute_loss(self, bipartite):
        """Return the training loss for a graph produced by :meth:`forward`.

        The density loss is the MSE between ``dstdata["pred_den"]`` and
        ``dstdata["density"]``. The connectivity loss compares
        ``edata["pred_conn"]`` with ``edata["labels_conn"]`` on the edges
        selected by ``edata["mask_conn"]``. When ``self.balance`` is set, a
        class-balanced copy of that mask is used, and the graph's stored mask
        is left unchanged.

        Returns:
            ``(loss, loss_den_val, loss_conn_val)``: the differentiable total
            loss, and the two components as Python floats. If no edge is
            selected (possible in subgraph training), the connectivity term
            is omitted and reported as ``0.0``.
        """
        loss_den = self.loss_den(
            bipartite.dstdata["pred_den"], bipartite.dstdata["density"]
        )
        loss_den_val = loss_den.item()

        labels_conn = bipartite.edata["labels_conn"]
        # Boolean mask for indexing. .bool() is a no-op for bool tensors.
        mask_conn = bipartite.edata["mask_conn"].bool()
        if self.balance:
            mask_conn = self._balance_mask(mask_conn, labels_conn)

        # In subgraph training it may happen that all edges are masked in a batch.
        if not mask_conn.any():
            return loss_den, loss_den_val, 0.0

        loss_conn = self.loss_conn(
            bipartite.edata["pred_conn"][mask_conn], labels_conn[mask_conn]
        )
        return loss_den + loss_conn, loss_den_val, loss_conn.item()