import dgl
import dgl.function as fn
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
from dgl.base import DGLError
from dgl.nn.functional import edge_softmax


class WeightedGATConv(dglnn.GATConv):
    """
Chen Sirui's avatar
Chen Sirui committed
13
14
    This model inherit from dgl GATConv for traffic prediction task,
    it add edge weight when aggregating the node feature.
15
    """

    def forward(self, graph, feat, get_attention=False):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, "fc_src"):
                    feat_src = self.fc(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
                else:
                    feat_src = self.fc_src(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc_dst(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition", the two approaches are mathematically equivalent:
            # We decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # the addition can be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and reduces the memory footprint.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
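            # el/er above are per-node attention terms of shape (N, num_heads, 1),
            # obtained from features of shape (N, num_heads, out_feats) and the
            # attention parameters attn_l/attn_r of shape (1, num_heads, out_feats)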
            graph.srcdata.update({"ft": feat_src, "el": el})
            graph.dstdata.update({"er": er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))
            # compute softmax
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            # scale the attention scores by the per-edge scalar weight stored in
            # graph.edata["weight"]; the permutes broadcast the weight across heads
            graph.edata["a"] = (
                graph.edata["a"].permute(1, 2, 0) * graph.edata["weight"]
            ).permute(2, 0, 1)
            # message passing
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(
                    h_dst.shape[0], -1, self._out_feats
                )
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata["a"]
            else:
                return rst
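

# A minimal usage sketch for WeightedGATConv (not part of the original module;
# the graph, feature sizes, and function name are illustrative). The layer
# expects a scalar weight per edge stored in g.edata["weight"], and self-loops
# avoid the zero-in-degree check.
def _weighted_gat_example():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g = dgl.add_self_loop(g)
    g.edata["weight"] = torch.rand(g.num_edges())
    layer = WeightedGATConv(in_feats=8, out_feats=16, num_heads=4)
    feat = torch.randn(g.num_nodes(), 8)
    # rst: (num_nodes, num_heads, out_feats); attn: (num_edges, num_heads, 1)
    rst, attn = layer(g, feat, get_attention=True)
    return rst, attn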


class GatedGAT(nn.Module):
    """Gated Graph Attention module, it is a general purpose
Chen Sirui's avatar
Chen Sirui committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
    graph attention module proposed in paper GaAN. The paper use
    it for traffic prediction task
    Parameter
    ==========
    in_feats : int
        number of input feature

    out_feats : int
        number of output feature

    map_feats : int
        intermediate feature size for gate computation

    num_heads : int
        number of head for multihead attention
116
    """

    def __init__(self, in_feats, out_feats, map_feats, num_heads):
        super(GatedGAT, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.map_feats = map_feats
        self.num_heads = num_heads
        self.gatlayer = WeightedGATConv(
            self.in_feats, self.out_feats, self.num_heads
        )
        self.gate_fn = nn.Linear(
            2 * self.in_feats + self.map_feats, self.num_heads
        )
        self.gate_m = nn.Linear(self.in_feats, self.map_feats)
        self.merger_layer = nn.Linear(
            self.in_feats + self.out_feats, self.out_feats
        )

    def forward(self, g, x):
        with g.local_scope():
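            # GaAN-style gate: each node's gate is built from its own feature,
            # the max-pooled projected neighbor features (z), and the
            # mean-pooled raw neighbor features (x)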
            g.ndata["x"] = x
            g.ndata["z"] = self.gate_m(x)
            g.update_all(fn.copy_u("x", "x"), fn.mean("x", "mean_z"))
            g.update_all(fn.copy_u("z", "z"), fn.max("z", "max_z"))
            nft = torch.cat(
                [g.ndata["x"], g.ndata["max_z"], g.ndata["mean_z"]], dim=1
            )
            gate = self.gate_fn(nft).sigmoid()
            attn_out = self.gatlayer(g, x)
            node_num = g.num_nodes()
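            # apply the per-head gates to the multi-head attention output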
            gated_out = (
                (gate.view(-1) * attn_out.view(-1, self.out_feats).T).T
            ).view(node_num, self.num_heads, self.out_feats)
            gated_out = gated_out.mean(1)
            merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
            return merge
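

# A minimal usage sketch for GatedGAT (not part of the original module; the
# graph, feature sizes, and function name are illustrative). The inner
# WeightedGATConv reads g.edata["weight"], so the graph must carry per-edge
# scalar weights.
def _gated_gat_example():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g = dgl.add_self_loop(g)
    g.edata["weight"] = torch.rand(g.num_edges())
    model = GatedGAT(in_feats=8, out_feats=16, map_feats=32, num_heads=4)
    x = torch.randn(g.num_nodes(), 8)
    out = model(g, x)  # (num_nodes, out_feats)
    return out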