# gnn.py (2.82 KB)
1
2
import torch
import torch.nn as nn
3
from conv import GNN_node, GNN_node_Virtualnode
4

5
6
7
8
9
10
11
from dgl.nn.pytorch import (
    AvgPooling,
    GlobalAttentionPooling,
    MaxPooling,
    Set2Set,
    SumPooling,
)
12
13
14


class GNN(nn.Module):
    """Graph-level predictor: node encoder -> graph readout -> linear head.

    The node encoder (``GNN_node`` or ``GNN_node_Virtualnode`` from the local
    ``conv`` module) produces per-node embeddings, a DGL pooling layer reduces
    them to one vector per graph, and a single ``nn.Linear`` maps that vector
    to ``num_tasks`` outputs.
    """

    # Readout names accepted by the ``graph_pooling`` constructor argument.
    _POOLING_TYPES = ("sum", "mean", "max", "attention", "set2set")

    def __init__(
        self,
        num_tasks=1,
        num_layers=5,
        emb_dim=300,
        gnn_type="gin",
        virtual_node=True,
        residual=False,
        drop_ratio=0,
        JK="last",
        graph_pooling="sum",
    ):
        """Build the node encoder, the graph readout and the prediction head.

        Args:
            num_tasks (int): number of labels to be predicted.
            num_layers (int): number of GNN layers; must be at least 2.
            emb_dim (int): embedding width used by the encoder, the readout
                and the input of the prediction head.
            gnn_type (str): forwarded unchanged to the node encoder.
            virtual_node (bool): whether to add virtual node or not, i.e.
                use ``GNN_node_Virtualnode`` instead of ``GNN_node``.
            residual (bool): forwarded unchanged to the node encoder.
            drop_ratio (float): forwarded unchanged to the node encoder.
            JK (str): forwarded unchanged to the node encoder.
            graph_pooling (str): readout type, one of ``"sum"``, ``"mean"``,
                ``"max"``, ``"attention"`` or ``"set2set"``.

        Raises:
            ValueError: if ``num_layers`` is smaller than 2 or
                ``graph_pooling`` is not a supported readout name.
        """
        super().__init__()

        # Validate cheap arguments first so a bad configuration fails before
        # any sub-module is allocated.
        if num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        if graph_pooling not in self._POOLING_TYPES:
            raise ValueError("Invalid graph pooling type.")

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        encoder_cls = GNN_node_Virtualnode if virtual_node else GNN_node
        self.gnn_node = encoder_cls(
            num_layers,
            emb_dim,
            JK=JK,
            drop_ratio=drop_ratio,
            residual=residual,
            gnn_type=gnn_type,
        )

        ### Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = SumPooling()
        elif graph_pooling == "mean":
            self.pool = AvgPooling()
        elif graph_pooling == "max":
            # Must be an instance: forward() calls self.pool(g, h_node), and
            # storing the bare class would turn that into a constructor call.
            self.pool = MaxPooling()
        elif graph_pooling == "attention":
            self.pool = GlobalAttentionPooling(
                gate_nn=nn.Sequential(
                    nn.Linear(emb_dim, 2 * emb_dim),
                    nn.BatchNorm1d(2 * emb_dim),
                    nn.ReLU(),
                    nn.Linear(2 * emb_dim, 1),
                )
            )
        else:
            # Only "set2set" can reach this branch (name validated above).
            self.pool = Set2Set(emb_dim, n_iters=2, n_layers=2)

        # The head is sized 2 * emb_dim for set2set and emb_dim for every
        # other readout.
        readout_dim = 2 * emb_dim if graph_pooling == "set2set" else emb_dim
        self.graph_pred_linear = nn.Linear(readout_dim, num_tasks)

    def forward(self, g, x, edge_attr):
        """Predict ``num_tasks`` values for the graph(s) in ``g``.

        Args:
            g: graph object handed to both the node encoder and the pooling
                layer (presumably a, possibly batched, DGLGraph -- confirm
                against the caller).
            x: node input features, forwarded to the node encoder.
            edge_attr: edge input features, forwarded to the node encoder.

        Returns:
            The output of the linear head. In training mode it is returned
            as is; in evaluation mode it is clamped to the range [0, 50].
        """
        h_node = self.gnn_node(g, x, edge_attr)
        h_graph = self.pool(g, h_node)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        # NOTE(review): the [0, 50] bounds are hard-coded; presumably they
        # reflect the valid range of the regression target -- confirm.
        return torch.clamp(output, min=0, max=50)