lander.py 5.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

from .graphconv import GraphConv
from .focal_loss import FocalLoss

class LANDER(nn.Module):
    """Graph network that scores edge connectivity and node density.

    A stack of ``GraphConv`` layers embeds the nodes; two linear heads project
    the source and destination embeddings of every edge, and a small MLP turns
    their sum into 2-class connectivity logits.  A per-node density estimate
    is obtained by mean-aggregating ``raw_affine * (p1 - p0)`` over the edges
    handled by ``update_all`` in :meth:`forward`.

    Args:
        feature_dim: Width of the ``'features'`` node tensor.  When
            ``use_cluster_feat`` is true the first layer is built for twice
            this width because ``'cluster_features'`` is concatenated to it.
        nhid: Hidden width; later layers use ``int(nhid / 2)``.
        num_conv: Number of ``GraphConv`` layers (must be >= 1).  Sizes are
            ``in -> nhid -> nhid -> nhid/2 -> nhid/2 -> ...``.
        dropout, use_GAT, K: Forwarded unchanged to every ``GraphConv``.
        balance: If true, :meth:`compute_loss` randomly drops supervised
            edges of the majority class so both classes are equally
            represented in the connectivity loss.
        use_cluster_feat: Concatenate ``'cluster_features'`` to
            ``'features'`` as the network input.
        use_focal_loss: Use ``FocalLoss(2)`` instead of
            ``nn.CrossEntropyLoss`` for the connectivity loss.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``num_conv`` is smaller than 1.
    """

    def __init__(self, feature_dim, nhid, num_conv=4, dropout=0,
                 use_GAT=True, K=1, balance=False,
                 use_cluster_feat = True, use_focal_loss = True, **kwargs):
        super(LANDER, self).__init__()
        if num_conv < 1:
            # The first layer is always created, so fewer than one layer would
            # leave the linear heads sized for an output that never exists.
            raise ValueError("num_conv must be >= 1, got {}".format(num_conv))

        nhid_half = int(nhid / 2)
        self.use_cluster_feat = use_cluster_feat
        self.use_focal_loss = use_focal_loss

        if self.use_cluster_feat:
            self.feature_dim = feature_dim * 2
        else:
            self.feature_dim = feature_dim

        # (in, out) width of every layer.  The first three entries are fixed;
        # every further layer maps nhid_half -> nhid_half, so any depth works.
        layer_dims = [(self.feature_dim, nhid), (nhid, nhid), (nhid, nhid_half)]
        layer_dims.extend([(nhid_half, nhid_half)] * max(num_conv - 3, 0))
        layer_dims = layer_dims[:num_conv]

        self.conv = nn.ModuleList()
        for in_dim, out_dim in layer_dims:
            self.conv.append(GraphConv(in_dim, out_dim, dropout, use_GAT, K))

        last_dim = layer_dims[-1][1]
        self.src_mlp = nn.Linear(last_dim, nhid_half)
        self.dst_mlp = nn.Linear(last_dim, nhid_half)

        self.classifier_conn = nn.Sequential(nn.PReLU(nhid_half),
                                             nn.Linear(nhid_half, nhid_half),
                                             nn.PReLU(nhid_half),
                                             nn.Linear(nhid_half, 2))

        if self.use_focal_loss:
            self.loss_conn = FocalLoss(2)
        else:
            self.loss_conn = nn.CrossEntropyLoss()
        self.loss_den = nn.MSELoss()

        self.balance = balance

    def pred_conn(self, edges):
        """Edge function: compute 2-class connectivity logits for each edge.

        Reads ``'conv_features'`` from the source and destination nodes and
        returns ``{'pred_conn': logits}``.
        """
        src_feat = self.src_mlp(edges.src['conv_features'])
        dst_feat = self.dst_mlp(edges.dst['conv_features'])
        pred_conn = self.classifier_conn(src_feat + dst_feat)
        return {'pred_conn': pred_conn}

    def pred_den_msg(self, edges):
        """Message function: weight ``'raw_affine'`` by ``p[:, 1] - p[:, 0]``.

        ``p`` is the edge tensor ``'prob_conn'``; the product is returned as
        ``{'pred_den_msg': ...}``.
        """
        prob = edges.data['prob_conn']
        res = edges.data['raw_affine'] * (prob[:, 1] - prob[:, 0])
        return {'pred_den_msg': res}

    def forward(self, bipartites):
        """Run the network and write predictions onto the output graph.

        Args:
            bipartites: Either a single ``dgl.DGLGraph`` (every layer runs on
                that same graph), or an indexable sequence of graphs.  In the
                sequence case layer ``i`` is applied to both
                ``bipartites[i]`` and ``bipartites[i + 1]``, so the sequence
                needs one more entry than there are layers.

        Returns:
            The last graph, with edge data ``'pred_conn'`` (logits) and
            ``'prob_conn'`` (softmax over dim 1) and the mean-aggregated
            ``'pred_den'`` added.  Node data ``'conv_features'`` is written
            as a side effect.
        """
        if isinstance(bipartites, dgl.DGLGraph):
            graph = bipartites
            if self.use_cluster_feat:
                neighbor_x = torch.cat([graph.ndata['features'],
                                        graph.ndata['cluster_features']], dim=1)
            else:
                neighbor_x = graph.ndata['features']

            for conv in self.conv:
                neighbor_x = conv(graph, neighbor_x)

            output_bipartite = graph
            output_bipartite.ndata['conv_features'] = neighbor_x
        else:
            if self.use_cluster_feat:
                neighbor_x_src = torch.cat([bipartites[0].srcdata['features'],
                                            bipartites[0].srcdata['cluster_features']], dim=1)
                center_x_src = torch.cat([bipartites[1].srcdata['features'],
                                          bipartites[1].srcdata['cluster_features']], dim=1)
            else:
                neighbor_x_src = bipartites[0].srcdata['features']
                center_x_src = bipartites[1].srcdata['features']

            for i, conv in enumerate(self.conv):
                # The leading num_dst_nodes() rows of the source features are
                # passed as the destination features of the same block.
                neighbor_x_dst = neighbor_x_src[:bipartites[i].num_dst_nodes()]
                neighbor_x_src = conv(bipartites[i], (neighbor_x_src, neighbor_x_dst))
                center_x_dst = center_x_src[:bipartites[i+1].num_dst_nodes()]
                center_x_src = conv(bipartites[i+1], (center_x_src, center_x_dst))

            output_bipartite = bipartites[-1]
            output_bipartite.srcdata['conv_features'] = neighbor_x_src
            output_bipartite.dstdata['conv_features'] = center_x_src

        output_bipartite.apply_edges(self.pred_conn)
        output_bipartite.edata['prob_conn'] = F.softmax(output_bipartite.edata['pred_conn'], dim=1)
        output_bipartite.update_all(self.pred_den_msg, fn.mean('pred_den_msg', 'pred_den'))
        return output_bipartite

    def compute_loss(self, bipartite):
        """Compute the density + connectivity loss for a forwarded graph.

        Args:
            bipartite: Graph returned by :meth:`forward`.  Must carry
                ``dstdata['pred_den']``, ``dstdata['density']``,
                ``edata['pred_conn']``, ``edata['labels_conn']`` and
                ``edata['mask_conn']`` (selects the supervised edges).

        Returns:
            ``(loss, loss_den_val, loss_conn_val)``: the loss tensor to
            back-propagate, followed by the density and connectivity parts as
            Python numbers.  When no edge is selected the connectivity part
            is ``0`` and ``loss`` is the density loss alone.
        """
        pred_den = bipartite.dstdata['pred_den']
        loss_den = self.loss_den(pred_den, bipartite.dstdata['density'])

        labels_conn = bipartite.edata['labels_conn']
        mask_conn = bipartite.edata['mask_conn']

        if self.balance:
            # Balance on a private copy: writing into the tensor stored in
            # edata would permanently shrink the graph's mask, so every later
            # call on the same graph would supervise fewer and fewer edges.
            mask_conn = mask_conn.clone()
            neg_check = torch.logical_and(labels_conn == 0, mask_conn)
            num_neg = torch.sum(neg_check).item()
            neg_indices = torch.where(neg_check)[0]
            pos_check = torch.logical_and(labels_conn == 1, mask_conn)
            num_pos = torch.sum(pos_check).item()
            pos_indices = torch.where(pos_check)[0]
            # Randomly unselect the surplus of whichever class is larger.
            if num_pos > num_neg:
                mask_conn[pos_indices[np.random.choice(num_pos, num_pos - num_neg, replace = False)]] = 0
            elif num_pos < num_neg:
                mask_conn[neg_indices[np.random.choice(num_neg, num_neg - num_pos, replace = False)]] = 0

        # In subgraph training, it may happen that all edges are masked in a batch
        if mask_conn.sum() > 0:
            loss_conn = self.loss_conn(bipartite.edata['pred_conn'][mask_conn], labels_conn[mask_conn])
            loss = loss_den + loss_conn
            loss_den_val = loss_den.item()
            loss_conn_val = loss_conn.item()
        else:
            loss = loss_den
            loss_den_val = loss_den.item()
            loss_conn_val = 0

        return loss, loss_den_val, loss_conn_val