"""
Gated Graph Sequence Neural Network for sequence outputs
"""

import torch
import torch.nn.functional as F

from dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling
from torch import nn


class GGSNN(nn.Module):
    def __init__(
        self,
        annotation_size,
        out_feats,
        n_steps,
        n_etypes,
        max_seq_length,
        num_cls,
    ):
        super(GGSNN, self).__init__()

        self.annotation_size = annotation_size
        self.out_feats = out_feats
        self.max_seq_length = max_seq_length

        # Gated graph convolution: propagates node states for n_steps over
        # n_etypes edge types.
        self.ggnn = GatedGraphConv(
            in_feats=out_feats,
            out_feats=out_feats,
            n_steps=n_steps,
            n_etypes=n_etypes,
        )

        # Maps the concatenated [GGNN output, annotation] back to a new
        # annotation for the next output step.
        self.annotation_out_layer = nn.Linear(
            annotation_size + out_feats, annotation_size
        )

        # Attention pooling collapses node states into one graph-level vector
        # per step; a linear head maps it to class logits.
        pooling_gate_nn = nn.Linear(annotation_size + out_feats, 1)
        self.pooling = GlobalAttentionPooling(pooling_gate_nn)

        self.output_layer = nn.Linear(annotation_size + out_feats, num_cls)
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(self, graph, seq_lengths, ground_truth=None):
        """Predict a sequence of graph-level labels, one per output step.

        Returns (loss, preds) when ground_truth is given, otherwise preds.
        """
        # Edge types and initial node annotations come from the batched graph.
        etypes = graph.edata.pop("type")
        annotation = graph.ndata.pop("annotation").float()

        assert annotation.size()[-1] == self.annotation_size

        node_num = graph.number_of_nodes()

        # Emit one set of class logits per output position.
        all_logits = []
        for _ in range(self.max_seq_length):
            # Zero-pad annotations so the node state matches the GGNN input
            # width (requires out_feats >= annotation_size).
            zero_pad = torch.zeros(
                [node_num, self.out_feats - self.annotation_size],
                dtype=torch.float,
                device=annotation.device,
            )

            # Propagate the padded annotations through the gated graph conv,
            # then re-attach the (un-padded) annotations to the node states.
            h1 = torch.cat([annotation.detach(), zero_pad], -1)
            out = self.ggnn(graph, h1, etypes)
            out = torch.cat([out, annotation], -1)

            # Pool node states into a graph representation and classify it.
            logits = self.pooling(graph, out)
            logits = self.output_layer(logits)
            all_logits.append(logits)

            # Update the annotations that feed the next output step.
            annotation = self.annotation_out_layer(out)
            annotation = F.softmax(annotation, -1)

        # [batch, max_seq_length, num_cls] logits and greedy per-step predictions.
        all_logits = torch.stack(all_logits, 1)
        preds = torch.argmax(all_logits, -1)
        if ground_truth is not None:
            loss = sequence_loss(all_logits, ground_truth, seq_lengths)
            return loss, preds
        return preds


def sequence_loss(logits, ground_truth, seq_length=None):
    """Per-step cross entropy, averaged over the valid positions of each sequence."""

    def sequence_mask(length):
        # Mask of shape [batch, max_length] with 1.0 at positions < length.
        max_length = logits.size(1)
        batch_size = logits.size(0)
        range_tensor = torch.arange(
            0, max_length, dtype=length.dtype, device=length.device
        )
        range_tensor = torch.stack([range_tensor] * batch_size, 0)
        expanded_length = torch.stack([length] * max_length, -1)
        mask = (range_tensor < expanded_length).float()
        return mask

    # CrossEntropyLoss expects [batch, num_cls, seq], so move the class dim to axis 1.
    loss = nn.CrossEntropyLoss(reduction="none")(
        logits.permute((0, 2, 1)), ground_truth
    )

    if seq_length is None:
        loss = loss.mean()
    else:
        # Zero out padded positions, average per sequence, then over the batch.
        mask = sequence_mask(seq_length)
        loss = (loss * mask).sum(-1) / seq_length.float()
        loss = loss.mean()
    return loss
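

# A minimal usage sketch (not part of the original module): it assumes a DGL
# graph whose edges carry a long-typed "type" feature and whose nodes carry an
# "annotation" feature, which is what forward() above pops off the graph.
# All sizes below are illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    import dgl

    annotation_size, out_feats, num_cls, max_seq_length = 4, 8, 5, 3

    # Two tiny graphs batched together; a single edge type (id 0) is used.
    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
    g2 = dgl.graph(([0, 0], [1, 2]), num_nodes=3)
    batched = dgl.batch([g1, g2])
    batched.edata["type"] = torch.zeros(batched.number_of_edges(), dtype=torch.long)
    batched.ndata["annotation"] = torch.randint(
        0, 2, (batched.number_of_nodes(), annotation_size)
    )

    model = GGSNN(
        annotation_size=annotation_size,
        out_feats=out_feats,
        n_steps=2,
        n_etypes=1,
        max_seq_length=max_seq_length,
        num_cls=num_cls,
    )

    seq_lengths = torch.tensor([3, 2])
    ground_truth = torch.randint(0, num_cls, (2, max_seq_length))
    loss, preds = model(batched, seq_lengths, ground_truth)
    print(loss.item(), preds.shape)  # preds has shape [2, max_seq_length]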