import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn

def drop_node(feats, drop_rate, training):
    """Randomly zero out entire node feature rows (DropNode augmentation).

    During training each row is kept with probability 1 - drop_rate; at
    inference every row is scaled by 1 - drop_rate so the output matches
    the training-time expectation.
    """
    n = feats.shape[0]
    drop_rates = th.FloatTensor(np.ones(n) * drop_rate)

    if training:
        # Sample a per-node keep mask and zero out the dropped rows.
        masks = th.bernoulli(1. - drop_rates).unsqueeze(1)
        feats = masks.to(feats.device) * feats
    else:
        # Deterministic scaling by the expected keep probability.
        feats = feats * (1. - drop_rate)

    return feats

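# A minimal sketch (not part of the original file) of how drop_node behaves,
# assuming 4 nodes with 2-dimensional all-ones features:
#
#   >>> feats = th.ones(4, 2)
#   >>> drop_node(feats, 0.5, training=False)   # every entry scaled to 0.5
#   >>> drop_node(feats, 0.5, training=True)    # rows randomly zeroed instead
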
class MLP(nn.Module):
    """Two-layer MLP predictor with optional batch normalization."""
    def __init__(self, nfeat, nhid, nclass, input_droprate, hidden_droprate, use_bn=False):
        super(MLP, self).__init__()

        self.layer1 = nn.Linear(nfeat, nhid, bias=True)
        self.layer2 = nn.Linear(nhid, nclass, bias=True)

        self.input_dropout = nn.Dropout(input_droprate)
        self.hidden_dropout = nn.Dropout(hidden_droprate)
        self.bn1 = nn.BatchNorm1d(nfeat)
        self.bn2 = nn.BatchNorm1d(nhid)
        self.use_bn = use_bn

    def reset_parameters(self):
        self.layer1.reset_parameters()
        self.layer2.reset_parameters()

    def forward(self, x):

        if self.use_bn:
            x = self.bn1(x)
        x = self.input_dropout(x)
        x = F.relu(self.layer1(x))

        if self.use_bn:
            x = self.bn2(x)
        x = self.hidden_dropout(x)
        x = self.layer2(x)

        return x
        

def GRANDConv(graph, feats, order):
    '''
    Mixed-order propagation: average the features propagated 0..order steps
    over the symmetrically normalized adjacency matrix.

    Parameters
    -----------
    graph: dgl.Graph
        The input graph
    feats: Tensor (n_nodes * feat_dim)
        Node features
    order: int
        Number of propagation steps

    Returns
    -------
    Tensor (n_nodes * feat_dim)
        Averaged propagated node features
    '''
    with graph.local_scope():
        # Symmetrically normalized adjacency \hat{A} = D^{-1/2} A D^{-1/2},
        # stored as an edge weight norm_u * norm_v.
        degs = graph.in_degrees().float().clamp(min=1)
        norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)

        graph.ndata['norm'] = norm
        graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight'))

        # Accumulate \hat{A}^k X for k = 0..order, then average.
        x = feats
        y = feats.clone()

        for i in range(order):
            graph.ndata['h'] = x
            graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h'))
            x = graph.ndata.pop('h')
            y.add_(x)

    return y / (order + 1)
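# A minimal usage sketch of GRANDConv (the toy graph below is an illustrative
# assumption, not part of the original example):
#
#   >>> import dgl
#   >>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))   # small bidirected path
#   >>> feats = th.randn(3, 4)
#   >>> GRANDConv(g, feats, order=2).shape
#   torch.Size([3, 4])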

class GRAND(nn.Module):
    r"""

    Parameters
    -----------
    in_dim: int
        Input feature size. i.e, the number of dimensions of: math: `H^{(i)}`.
    hid_dim: int
        Hidden feature size.
    n_class: int
        Number of classes.
    S: int
        Number of Augmentation samples
    K: int
        Number of Propagation Steps
    node_dropout: float
        Dropout rate on node features.
    input_dropout: float
        Dropout rate of the input layer of a MLP
    hidden_dropout: float
        Dropout rate of the hidden layer of a MLPx
    batchnorm: bool, optional
        If True, use batch normalization.

    """
    def __init__(self,
                 in_dim,
                 hid_dim,
                 n_class,
                 S=1,
                 K=3,
                 node_dropout=0.0,
                 input_droprate=0.0,
                 hidden_droprate=0.0,
                 batchnorm=False):

        super(GRAND, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.S = S
        self.K = K
        self.n_class = n_class
        
        self.mlp = MLP(in_dim, hid_dim, n_class, input_droprate, hidden_droprate, batchnorm)
        
        self.dropout = node_dropout
        self.node_dropout = nn.Dropout(node_dropout)

    def forward(self, graph, feats, training=True):

        X = feats
        S = self.S

        if training:  # Training mode: S DropNode augmentations, one prediction each
            output_list = []
            for s in range(S):
                drop_feat = drop_node(X, self.dropout, True)   # DropNode augmentation
                feat = GRANDConv(graph, drop_feat, self.K)     # Graph convolution
                output_list.append(th.log_softmax(self.mlp(feat), dim=-1))  # Prediction

            return output_list
        else:  # Inference mode: deterministic scaling instead of random dropping
            drop_feat = drop_node(X, self.dropout, False)
            X = GRANDConv(graph, drop_feat, self.K)

            return th.log_softmax(self.mlp(X), dim=-1)
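

# A minimal, self-contained usage sketch of the GRAND module above. The toy
# graph, feature sizes, and hyper-parameters are illustrative assumptions, not
# part of the original training script, and the consistency-regularization loss
# from the GRAND paper is omitted for brevity.
if __name__ == "__main__":
    import dgl

    # Tiny bidirected graph with 4 nodes and random 8-dimensional features.
    src = th.tensor([0, 1, 1, 2, 2, 3])
    dst = th.tensor([1, 0, 2, 1, 3, 2])
    g = dgl.graph((src, dst))
    feats = th.randn(4, 8)

    model = GRAND(in_dim=8, hid_dim=16, n_class=3, S=2, K=2,
                  node_dropout=0.5, input_droprate=0.2, hidden_droprate=0.2)

    model.train()
    outputs = model(g, feats, training=True)    # list of S log-softmax outputs
    print(len(outputs), outputs[0].shape)       # 2 torch.Size([4, 3])

    model.eval()
    with th.no_grad():
        pred = model(g, feats, training=False)  # single log-softmax output
    print(pred.shape)                           # torch.Size([4, 3])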