"sgl-kernel/vscode:/vscode.git/clone" did not exist on "de1350ea20530e0744b48d0d50415fa2ff5122cd"
sage.py 2.58 KB
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch.nn as nn
import dgl
from dgl.base import dgl_warning

class GraphSAGE(nn.Module):
    """Node encoder built from a stack of ``dgl.nn.SAGEConv`` layers.

    The first layer consumes either the caller-supplied node features or, when
    ``embed_size > 0``, the weight matrix of an internal ``nn.Embedding`` table.
    Every layer except the last is followed by the activation and by dropout.
    """

    def __init__(self,
                 data_info: dict,
                 embed_size: int = -1,
                 hidden_size: int = 16,
                 num_layers: int = 1,
                 activation: str = "relu",
                 dropout: float = 0.5,
                 aggregator_type: str = "gcn"):
        """Create the embedding table (optional), dropout and convolution stack.

        Parameters
        ----------
        data_info : dict
            Dataset description. ``"num_nodes"`` is read when ``embed_size > 0``,
            ``"in_size"`` is read otherwise, and ``"out_size"`` is read as the
            output width of the final layer.
        embed_size : int
            Width of the learned embedding table. Any value ``<= 0`` (default
            ``-1``) means no table is created and the original node features
            are used.
        hidden_size : int
            Width of every intermediate layer output.
        num_layers : int
            Number of ``SAGEConv`` layers to stack.
        activation : str
            Name of a function looked up on ``torch.nn.functional``.
        dropout : float
            Dropout rate passed to ``nn.Dropout``.
        aggregator_type : str
            Aggregator type forwarded to ``SAGEConv``
            (``mean``, ``gcn``, ``pool``, ``lstm``).
        """
        super().__init__()
        self.data_info = data_info
        self.embed_size = embed_size

        # Width fed into the first convolution: the embedding width when a
        # table is requested, otherwise the dataset's raw feature width.
        if embed_size > 0:
            self.embed = nn.Embedding(data_info["num_nodes"], embed_size)
            first_width = embed_size
        else:
            first_width = data_info["in_size"]

        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = getattr(nn.functional, activation)

        # Layer widths: first_width -> hidden -> ... -> hidden -> out_size.
        # "out_size" is only looked up while building the final layer.
        final_idx = num_layers - 1
        for depth in range(num_layers):
            fan_in = first_width if depth == 0 else hidden_size
            fan_out = data_info["out_size"] if depth == final_idx else hidden_size
            conv = dgl.nn.SAGEConv(fan_in, fan_out, aggregator_type)
            self.layers.append(conv)

    def _run_layers(self, graphs, h):
        """Apply the layers pairwise with ``graphs``; activate and drop out after all but the last layer."""
        last_layer = len(self.layers) - 1
        for idx, (conv, g) in enumerate(zip(self.layers, graphs)):
            h = conv(g, h)
            if idx != last_layer:
                h = self.dropout(self.activation(h))
        return h

    def forward(self, graph, node_feat, edge_feat=None):
        """Run every layer on the same ``graph``.

        Parameters
        ----------
        graph
            Graph object handed unchanged to each layer.
        node_feat
            Input node features; ignored (with a one-time warning) when the
            model was built with ``embed_size > 0``.
        edge_feat
            Accepted for interface compatibility; not used here.

        Returns
        -------
        The output of the last layer (no activation or dropout applied to it).
        """
        use_table = self.embed_size > 0
        if use_table:
            dgl_warning("The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.", norepeat=True)
        h = self.embed.weight if use_table else node_feat

        # Dropout is applied to the input before the first convolution.
        h = self.dropout(h)
        return self._run_layers([graph] * len(self.layers), h)

    def forward_block(self, blocks, node_feat, edge_feat=None):
        """Run layer ``i`` on ``blocks[i]``.

        Layers and blocks are paired positionally; iteration stops at the
        shorter of the two sequences.

        NOTE(review): unlike ``forward``, this path neither reads the embedding
        table nor applies dropout to the input -- presumably the caller passes
        features of the width the first layer expects; confirm against callers.

        Parameters
        ----------
        blocks
            Iterable of per-layer graph objects.
        node_feat
            Input features for the first layer.
        edge_feat
            Accepted for interface compatibility; not used here.

        Returns
        -------
        The output of the last layer that was applied.
        """
        return self._run_layers(blocks, node_feat)