# dgi.py
"""
Deep Graph Infomax in DGL

References
----------
Paper: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI
"""

import math
11
12
13

import numpy as np
import tensorflow as tf
14
from gcn import GCN
15
from tensorflow.keras import layers
16
17
18
19
20
21


class Encoder(layers.Layer):
    """Node encoder that runs a ``GCN`` over (optionally shuffled) features.

    Parameters
    ----------
    g
        Graph object. It is stored on the layer and also handed to ``GCN``;
        its ``number_of_nodes()`` is used to size the corruption permutation.
    in_feats, n_hidden, n_layers, activation, dropout
        Forwarded unchanged to ``GCN``.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super().__init__()
        self.g = g
        # n_hidden is deliberately passed twice (presumably as both the hidden
        # and the output width of the GCN -- confirm against gcn.GCN).
        self.conv = GCN(
            g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout
        )

    def call(self, features, corrupt=False):
        """Encode ``features``; when ``corrupt`` is true, shuffle rows first.

        The corruption draws a NumPy random permutation over the graph's
        nodes and reorders ``features`` along its first axis with it, so the
        graph structure is kept while node/feature alignment is broken.
        """
        if not corrupt:
            return self.conv(features)

        node_count = self.g.number_of_nodes()
        shuffled_rows = np.random.permutation(node_count)
        return self.conv(tf.gather(features, shuffled_rows))


class Discriminator(layers.Layer):
    """Bilinear scorer computing ``features @ (W @ summary)``.

    Parameters
    ----------
    n_hidden
        Side length of the square trainable weight matrix ``W``.
    """

    def __init__(self, n_hidden):
        super().__init__()
        # Entries of W are drawn uniformly from [-1/sqrt(n_hidden),
        # 1/sqrt(n_hidden)].
        bound = 1.0 / math.sqrt(n_hidden)
        initializer = tf.keras.initializers.RandomUniform(-bound, bound)
        self.weight = tf.Variable(
            initial_value=initializer(
                shape=(n_hidden, n_hidden), dtype="float32"
            ),
            trainable=True,
        )

    def call(self, features, summary):
        """Score every row of ``features`` against ``summary``.

        ``summary`` gets a trailing axis appended so it can be multiplied as
        a column; the result therefore keeps a trailing axis as well
        (assumes ``summary`` is a 1-D vector of length ``n_hidden`` --
        TODO confirm with callers).
        """
        summary_column = tf.expand_dims(summary, -1)
        projected_summary = tf.matmul(self.weight, summary_column)
        return tf.matmul(features, projected_summary)


class DGI(tf.keras.Model):
    """Deep Graph Infomax model: encoder + discriminator + contrastive loss.

    Parameters
    ----------
    g, in_feats, n_hidden, n_layers, activation, dropout
        Forwarded unchanged to ``Encoder``.
    n_hidden
        Also used to size the ``Discriminator`` weight matrix.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super().__init__()
        self.encoder = Encoder(
            g, in_feats, n_hidden, n_layers, activation, dropout
        )
        self.discriminator = Discriminator(n_hidden)
        self.loss = tf.nn.sigmoid_cross_entropy_with_logits

    def call(self, features):
        """Return the scalar DGI loss for ``features``.

        The clean and the corrupted features are both encoded. The summary
        vector is the sigmoid of the mean (over axis 0) of the clean
        embeddings. The discriminator scores both sets of embeddings against
        that summary, and the result is the mean sigmoid cross-entropy of the
        clean scores against label 1 plus that of the corrupted scores
        against label 0.
        """
        positive = self.encoder(features, corrupt=False)
        negative = self.encoder(features, corrupt=True)
        summary = tf.nn.sigmoid(tf.reduce_mean(positive, axis=0))

        positive = self.discriminator(positive, summary)
        negative = self.discriminator(negative, summary)

        # Build labels with *_like so they follow the logits' runtime shape
        # and dtype; tf.ones(x.shape) needs a fully-defined static shape and
        # always yields float32, which breaks when a dimension is unknown or
        # the logits use another float type.
        l1 = self.loss(labels=tf.ones_like(positive), logits=positive)
        l2 = self.loss(labels=tf.zeros_like(negative), logits=negative)

        return tf.reduce_mean(l1) + tf.reduce_mean(l2)


class Classifier(layers.Layer):
    """Single dense layer mapping input features to ``n_classes`` outputs.

    Parameters
    ----------
    n_hidden
        Accepted for interface compatibility; not used by this layer.
    n_classes
        Number of output units of the dense layer.
    """

    def __init__(self, n_hidden, n_classes):
        super().__init__()
        # No activation is configured, so the layer's outputs are the raw
        # dense-layer values.
        self.fc = layers.Dense(n_classes)

    def call(self, features):
        """Apply the dense layer to ``features`` and return its output."""
        return self.fc(features)