"template/vscode:/vscode.git/clone" did not exist on "fb6cbc02fbe0ff8d791413a81558a1fe9725b778"
dgi.py 2.47 KB
Newer Older
Zhengwei's avatar
Zhengwei committed
1
2
3
4
5
6
7
8
9
"""
Deep Graph Infomax in DGL

References
----------
Paper: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI
"""

10
11
import math

Zhengwei's avatar
Zhengwei committed
12
13
14
15
import torch
import torch.nn as nn
from gcn import GCN

16

Zhengwei's avatar
Zhengwei committed
17
18
19
20
class Encoder(nn.Module):
    """GCN encoder that maps node features to node embeddings for DGI.

    When ``corrupt=True`` the rows of ``features`` are randomly permuted before
    encoding. This is the DGI corruption function: every node is fed another
    node's features while the graph stays the same.

    Parameters
    ----------
    g : graph
        Graph object; must expose ``num_nodes()``. It is stored on the module
        and also handed to ``GCN``.
    in_feats, n_hidden, n_layers, activation, dropout
        Passed through to ``gcn.GCN``. ``n_hidden`` is given twice,
        presumably as both the hidden and the output size (confirm against
        ``gcn.GCN``).
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super().__init__()
        self.g = g
        self.conv = GCN(
            g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout
        )

    def forward(self, features, corrupt=False):
        """Encode ``features``, optionally shuffling node rows first.

        ``features`` is assumed to have one row per node of ``self.g``.
        TODO confirm: the permutation length comes from ``g.num_nodes()``.
        """
        if corrupt:
            # Generate the permutation on the features' own device so the
            # index tensor does not need a host-to-device copy on GPU.
            perm = torch.randperm(self.g.num_nodes(), device=features.device)
            features = features[perm]
        return self.conv(features)


class Discriminator(nn.Module):
    """Bilinear scorer between node embeddings and a summary vector.

    Computes ``features @ (weight @ summary)``, where ``weight`` is a learnable
    ``(n_hidden, n_hidden)`` matrix.
    """

    def __init__(self, n_hidden):
        super().__init__()
        # Allocated uninitialised; reset_parameters() fills it immediately.
        self.weight = nn.Parameter(torch.empty(n_hidden, n_hidden))
        self.reset_parameters()

    def uniform(self, size, tensor):
        """Fill ``tensor`` in place from U(-1/sqrt(size), 1/sqrt(size)).

        A ``None`` tensor is silently skipped.
        """
        bound = 1.0 / math.sqrt(size)
        if tensor is not None:
            # nn.init.uniform_ mutates under torch.no_grad(), replacing the
            # discouraged ``tensor.data`` access.
            nn.init.uniform_(tensor, -bound, bound)

    def reset_parameters(self):
        """Re-draw ``weight`` uniformly with bound 1/sqrt(weight.size(0))."""
        size = self.weight.size(0)
        self.uniform(size, self.weight)

    def forward(self, features, summary):
        """Return bilinear scores of ``features`` against ``summary``.

        Assuming ``features`` is ``(N, n_hidden)`` and ``summary`` is
        ``(n_hidden,)`` (as produced by ``DGI.forward``), the result has
        shape ``(N,)``: one raw logit per node.
        """
        return torch.matmul(features, torch.matmul(self.weight, summary))


class DGI(nn.Module):
    """Deep Graph Infomax: encoder, bilinear discriminator and BCE objective.

    The constructor arguments are forwarded unchanged to ``Encoder``.
    ``n_hidden`` also sizes the ``Discriminator``.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super().__init__()
        self.encoder = Encoder(g, in_feats, n_hidden, n_layers, activation, dropout)
        self.discriminator = Discriminator(n_hidden)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, features):
        """Return the DGI training loss for ``features``.

        Real (uncorrupted) embeddings are labelled 1 and corrupted embeddings
        are labelled 0. Both are scored against the sigmoid of the mean real
        embedding. The result is the sum of the two mean-reduced BCE terms.
        """
        # The encoder runs on real inputs first, then on corrupted ones; this
        # order is kept so the random-number consumption is unchanged.
        h_real = self.encoder(features, corrupt=False)
        h_fake = self.encoder(features, corrupt=True)
        graph_summary = torch.sigmoid(h_real.mean(dim=0))

        real_scores = self.discriminator(h_real, graph_summary)
        fake_scores = self.discriminator(h_fake, graph_summary)

        real_term = self.loss(real_scores, torch.ones_like(real_scores))
        fake_term = self.loss(fake_scores, torch.zeros_like(fake_scores))
        return real_term + fake_term


class Classifier(nn.Module):
    """Linear classification head that outputs log-probabilities.

    Applies ``nn.Linear(n_hidden, n_classes)`` and then log-softmax over the
    last dimension.
    """

    def __init__(self, n_hidden, n_classes):
        super().__init__()
        self.fc = nn.Linear(n_hidden, n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer using its own default scheme."""
        self.fc.reset_parameters()

    def forward(self, features):
        """Return log-softmax class scores for each row of ``features``."""
        logits = self.fc(features)
        return logits.log_softmax(dim=-1)