import numpy as np
import torch
import torch.nn as nn

import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.base import DGLError
from dgl.nn.functional import edge_softmax


class WeightedGATConv(dglnn.GATConv):
    """
    GAT convolution with edge weights. This module inherits from DGL's
    ``GATConv`` for the traffic prediction task; it multiplies the attention
    scores by a per-edge weight (read from ``graph.edata["weight"]``) when
    aggregating node features.
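
    Example (a minimal usage sketch; the graph, feature sizes, and the
    per-edge ``edata["weight"]`` values below are illustrative assumptions,
    not values from any particular dataset):

    >>> import dgl, torch
    >>> g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
    >>> g.edata["weight"] = torch.rand(g.num_edges())  # one scalar per edge
    >>> conv = WeightedGATConv(in_feats=8, out_feats=16, num_heads=4)
    >>> out = conv(g, torch.rand(g.num_nodes(), 8))  # (num_nodes, num_heads, out_feats)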
    """

    def forward(self, graph, feat, get_attention=False):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, "fc_src"):
                    feat_src = self.fc(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
                else:
                    feat_src = self.fc_src(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc_dst(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            # NOTE: the GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition"; the two approaches are mathematically equivalent:
            # decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much more efficient because we do not need
            # to materialize [Wh_i || Wh_j] on edges, which is not memory-efficient.
            # Plus, the addition can use DGL's built-in function u_add_v,
            # which further speeds up computation and saves memory footprint.
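            # feat_src / feat_dst have shape (N, num_heads, out_feats); el and er
            # below are the per-node, per-head terms a_l^T Wh_i and a_r^T Wh_j,
            # each of shape (N, num_heads, 1).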
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({"ft": feat_src, "el": el})
            graph.dstdata.update({"er": er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))
            # compute softmax
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            # compute weighted attention
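            # edata["a"] has shape (E, num_heads, 1); permuting to (num_heads, 1, E)
            # lets the per-edge weight in edata["weight"] (typically shape (E,))
            # broadcast across heads, after which the result is permuted back.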
            graph.edata["a"] = (
                graph.edata["a"].permute(1, 2, 0) * graph.edata["weight"]
            ).permute(2, 0, 1)
            # message passing
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(
                    h_dst.shape[0], -1, self._out_feats
                )
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata["a"]
            else:
                return rst


class GatedGAT(nn.Module):
    """Gated Graph Attention module, it is a general purpose
    graph attention module proposed in the GaAN paper, which applies it
    to the traffic prediction task.

    Parameters
    ----------
    in_feats : int
        number of input features
    out_feats : int
        number of output features
    map_feats : int
        intermediate feature size for the gate computation
    num_heads : int
        number of heads for multi-head attention
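
    Example (a minimal usage sketch; the graph and sizes below are
    illustrative assumptions, and the graph must carry the per-edge
    ``edata["weight"]`` expected by ``WeightedGATConv``):

    >>> g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
    >>> g.edata["weight"] = torch.rand(g.num_edges())
    >>> layer = GatedGAT(in_feats=8, out_feats=16, map_feats=12, num_heads=4)
    >>> out = layer(g, torch.rand(g.num_nodes(), 8))  # (num_nodes, out_feats)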
    """

    def __init__(self, in_feats, out_feats, map_feats, num_heads):
        super(GatedGAT, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.map_feats = map_feats
        self.num_heads = num_heads
        self.gatlayer = WeightedGATConv(
            self.in_feats, self.out_feats, self.num_heads
        )
        self.gate_fn = nn.Linear(
            2 * self.in_feats + self.map_feats, self.num_heads
        )
        self.gate_m = nn.Linear(self.in_feats, self.map_feats)
        self.merger_layer = nn.Linear(
            self.in_feats + self.out_feats, self.out_feats
        )

    def forward(self, g, x):
        with g.local_scope():
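            # GaAN gate inputs: the node's own feature x, the max-pooled projected
            # neighbour features (max_z), and the mean of neighbour features (mean_z);
            # gate_fn maps their concatenation to one gate value per head.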
            g.ndata["x"] = x
            g.ndata["z"] = self.gate_m(x)
            g.update_all(fn.copy_u("x", "x"), fn.mean("x", "mean_z"))
            g.update_all(fn.copy_u("z", "z"), fn.max("z", "max_z"))
            nft = torch.cat(
                [g.ndata["x"], g.ndata["max_z"], g.ndata["mean_z"]], dim=1
            )
            gate = self.gate_fn(nft).sigmoid()
            attn_out = self.gatlayer(g, x)
            node_num = g.num_nodes()
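            # Apply the per-head gates: gate is (N, num_heads) and attn_out is
            # (N, num_heads, out_feats); the flatten/transpose trick scales each
            # head's output by its gate, then reshapes back to (N, num_heads, out_feats).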
            gated_out = (
                (gate.view(-1) * attn_out.view(-1, self.out_feats).T).T
            ).view(node_num, self.num_heads, self.out_feats)
            gated_out = gated_out.mean(1)
            merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
            return merge