"""
Gated Graph Neural Network module for node selection tasks
"""
import dgl
import torch
from dgl.nn.pytorch import GatedGraphConv
from torch import nn


class NodeSelectionGGNN(nn.Module):
    def __init__(self, annotation_size, out_feats, n_steps, n_etypes):
        super(NodeSelectionGGNN, self).__init__()

        self.annotation_size = annotation_size
        self.out_feats = out_feats

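        # Gated message passing for n_steps rounds, with a separate edge
        # weight matrix per edge type.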
        self.ggnn = GatedGraphConv(
            in_feats=out_feats,
            out_feats=out_feats,
            n_steps=n_steps,
            n_etypes=n_etypes,
        )
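
        # Scores each node by feeding the propagated state concatenated with
        # its original annotation through a linear layer; selection is a
        # classification over each graph's nodes.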
        self.output_layer = nn.Linear(annotation_size + out_feats, 1)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, graph, labels=None):
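        # The (batched) input graph is expected to carry edge types in
        # edata["type"] and initial node annotations in ndata["annotation"].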
        etypes = graph.edata.pop("type")
        annotation = graph.ndata.pop("annotation").float()

        assert annotation.size()[-1] == self.annotation_size

        node_num = graph.number_of_nodes()

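        # Pad the annotations with zeros so the initial hidden state has
        # out_feats dimensions (assumes out_feats >= annotation_size).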
        zero_pad = torch.zeros(
            [node_num, self.out_feats - self.annotation_size],
            dtype=torch.float,
            device=annotation.device,
        )

        h1 = torch.cat([annotation, zero_pad], -1)
        out = self.ggnn(graph, h1, etypes)

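        # One logit per node; storing the logits on the graph makes it easy
        # to recover them per graph after unbatching.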
        all_logits = self.output_layer(
            torch.cat([out, annotation], -1)
        ).squeeze(-1)
        graph.ndata["logits"] = all_logits
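
        # Split the batched graph back into individual graphs so the argmax /
        # cross-entropy below are computed over each graph's own nodes.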
        batch_g = dgl.unbatch(graph)

        preds = []
        if labels is not None:
            loss = 0.0
        for i, g in enumerate(batch_g):
            logits = g.ndata["logits"]
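            # Predict the node with the highest logit; during training, the
            # labelled target node is the class index for cross-entropy over
            # this graph's nodes.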
            preds.append(torch.argmax(logits))
            if labels is not None:
                logits = logits.unsqueeze(0)
                y = labels[i].unsqueeze(0)
                loss += self.loss_fn(logits, y)

        if labels is not None:
            loss /= float(len(batch_g))
            return loss, preds
        return preds
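

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the toy graph,
    # feature sizes (annotation_size=2, out_feats=16), and target node below
    # are illustrative assumptions only.
    g = dgl.graph(([0, 1, 2], [1, 2, 3]), num_nodes=4)
    g.ndata["annotation"] = torch.tensor(
        [[1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]]
    )
    # A single edge type, so every edge gets type 0.
    g.edata["type"] = torch.zeros(g.number_of_edges(), dtype=torch.long)

    model = NodeSelectionGGNN(
        annotation_size=2, out_feats=16, n_steps=5, n_etypes=1
    )
    labels = torch.tensor([3])  # one target node index per graph in the batch
    loss, preds = model(dgl.batch([g]), labels)
    print(loss.item(), preds)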