sgc.py 1.58 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

import torch
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from dgl.nn import SGConv
from dgl.base import dgl_warning


class SGC(nn.Module):
    """Node model built from a single DGL ``SGConv`` layer.

    Node inputs are either the features passed to ``forward`` or, when
    ``embed_size > 0``, a learnable per-node embedding table.
    """

    def __init__(self,
                 data_info: dict,
                 embed_size: int = -1,
                 bias: bool = True, k: int = 2):
        """ Simplifying Graph Convolutional Networks

        Parameters
        ----------
        data_info : dict
            The information about the input dataset. The keys ``"in_size"``
            and ``"out_size"`` are always read; ``"num_nodes"`` is read only
            when ``embed_size > 0``.
        embed_size : int
            The dimension of created embedding table. -1 means using original node embedding
        bias : bool
            If True, adds a learnable bias to the output. Default: ``True``.
        k : int
            Number of hops :math:`K`. Defaults:``2``.
        """
        super().__init__()
        self.data_info = data_info
        self.out_size = data_info["out_size"]
        self.in_size = data_info["in_size"]
        self.embed_size = embed_size
        if embed_size > 0:
            self.embed = nn.Embedding(data_info["num_nodes"], embed_size)
            # forward() feeds the embedding table into the convolution, so
            # the layer's input width must be the embedding width rather
            # than the dataset's raw feature width.
            conv_in_size = embed_size
        else:
            conv_in_size = self.in_size
        # NOTE: cached=True follows the original configuration; per the DGL
        # SGConv documentation the propagated features are then cached after
        # the first call, which assumes the graph/features stay fixed.
        self.sgc = SGConv(conv_in_size, self.out_size, k=k, cached=True,
                          bias=bias, norm=self.normalize)

    def forward(self, g, node_feat, edge_feat=None):
        """Run the SGConv layer on graph ``g``.

        Parameters
        ----------
        g : graph object accepted by ``SGConv``
            The input graph.
        node_feat : tensor
            Input node features. Ignored (with a one-time warning) when the
            model was built with ``embed_size > 0``; the full embedding
            table is used instead.
        edge_feat : optional
            Accepted for interface compatibility; not used by this model.

        Returns
        -------
        The output of the underlying ``SGConv`` layer.
        """
        if self.embed_size > 0:
            dgl_warning("The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.", norepeat=True)
            h = self.embed.weight
        else:
            h = node_feat
        return self.sgc(g, h)

    @staticmethod
    def normalize(h):
        """Standardize ``h`` along dim 0.

        Subtracts the mean over dim 0 and divides by the standard deviation
        over dim 0; ``1e-5`` is added to the divisor to avoid division by
        zero for constant columns.
        """
        return (h - h.mean(0)) / (h.std(0) + 1e-5)