ggnn_gc.py 1.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
Gated Graph Neural Network module for graph classification tasks
"""
from dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling
import torch
from torch import nn


class GraphClsGGNN(nn.Module):
    """Graph classifier built on a Gated Graph Neural Network (GGNN).

    Node annotations are zero-padded up to ``out_feats``, propagated with
    ``GatedGraphConv`` for ``n_steps`` steps, concatenated with the original
    annotations, pooled per graph with ``GlobalAttentionPooling`` and mapped
    to ``num_cls`` class logits by a linear layer.
    """

    def __init__(self,
                 annotation_size,
                 out_feats,
                 n_steps,
                 n_etypes,
                 num_cls):
        """Build the GGNN, the attention-pooling readout and the classifier.

        Args:
            annotation_size: Width of the per-node ``'annotation'`` feature.
            out_feats: Hidden size of the GGNN. Must be >= ``annotation_size``
                because annotations are zero-padded up to this width.
            n_steps: Number of propagation steps given to ``GatedGraphConv``.
            n_etypes: Number of edge types given to ``GatedGraphConv``.
            num_cls: Number of output classes.

        Raises:
            ValueError: If ``out_feats`` is smaller than ``annotation_size``.
        """
        super().__init__()

        # Fail fast: forward() pads annotations from annotation_size up to
        # out_feats, which is impossible when out_feats is the smaller one.
        if out_feats < annotation_size:
            raise ValueError(
                "out_feats ({}) must be >= annotation_size ({})".format(
                    out_feats, annotation_size))

        self.annotation_size = annotation_size
        self.out_feats = out_feats

        self.ggnn = GatedGraphConv(in_feats=out_feats,
                                   out_feats=out_feats,
                                   n_steps=n_steps,
                                   n_etypes=n_etypes)

        # Both the pooling gate and the classifier see the concatenation
        # [ggnn_output, annotation] built in forward().
        readout_size = annotation_size + out_feats
        pooling_gate_nn = nn.Linear(readout_size, 1)
        self.pooling = GlobalAttentionPooling(pooling_gate_nn)
        self.output_layer = nn.Linear(readout_size, num_cls)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, graph, labels=None):
        """Classify the graph(s) in ``graph``.

        Edge types are read from ``graph.edata['type']`` and node annotations
        from ``graph.ndata['annotation']``. The graph is not modified, so the
        same graph may be passed to ``forward`` more than once.

        Args:
            graph: DGL graph (presumably a batched graph) carrying
                ``'annotation'`` node data and ``'type'`` edge data.
            labels: Optional targets. They are passed to
                ``nn.CrossEntropyLoss`` together with the logits.

        Returns:
            ``(loss, preds)`` when ``labels`` is given, otherwise ``preds``.
            ``preds`` is the argmax of the logits over the last dimension.

        Raises:
            ValueError: If the annotation's last dimension differs from
                ``annotation_size``.
        """
        # Read (not pop) so the caller's graph keeps its features.
        etypes = graph.edata['type']
        annotation = graph.ndata['annotation'].float()

        if annotation.size(-1) != self.annotation_size:
            raise ValueError(
                "expected annotation size {}, got {}".format(
                    self.annotation_size, annotation.size(-1)))

        # Zero-pad annotations on the feature axis up to the GGNN width.
        # new_zeros matches annotation's dtype (float32) and device.
        pad_width = self.out_feats - self.annotation_size
        zero_pad = annotation.new_zeros((annotation.size(0), pad_width))
        h0 = torch.cat([annotation, zero_pad], -1)
        out = self.ggnn(graph, h0, etypes)

        # Append the raw annotation to the propagated node states.
        out = torch.cat([out, annotation], -1)

        # GlobalAttentionPooling yields one readout vector per graph.
        out = self.pooling(graph, out)

        logits = self.output_layer(out)
        preds = torch.argmax(logits, -1)

        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return loss, preds
        return preds