dcrnn.py 4.04 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
import dgl
import dgl.function as fn
Chen Sirui's avatar
Chen Sirui committed
3
4
5
6
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
7
from dgl.base import DGLError
Chen Sirui's avatar
Chen Sirui committed
8
9
10


class DiffConv(nn.Module):
    """Diffusion convolution layer from the DCRNN paper.

    The layer propagates node features over several pre-computed diffusion
    graphs (successive powers of degree-normalised weighted adjacency
    matrices, built by :meth:`attach_graph`), projects each propagated
    result with its own bias-free linear layer, adds a projected zero-hop
    (identity) term, scales every term by a learnable scalar and averages
    them. It can be used e.g. for traffic prediction or epidemic modelling.

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    k : int
        Number of diffusion steps (``>= 1``). Each selected direction
        contributes its first ``k - 1`` diffusion graphs; the identity term
        plays the role of step 0.
    in_graph_list : list of DGLGraph
        In-direction diffusion graphs (second element returned by
        :meth:`attach_graph`). Needs at least ``k - 1`` graphs when used.
    out_graph_list : list of DGLGraph
        Out-direction diffusion graphs (first element returned by
        :meth:`attach_graph`). Needs at least ``k - 1`` graphs when used.
    dir : str, one of ``"both"``, ``"in"``, ``"out"``
        Direction(s) of diffusion. The paper uses both directions (default).

    Raises
    ------
    DGLError
        If ``dir`` is not an accepted value or ``k < 1``.
    """

    def __init__(
        self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir="both"
    ):
        super(DiffConv, self).__init__()
        if dir not in ("both", "in", "out"):
            raise DGLError(
                f"dir must be one of 'both', 'in' or 'out', got {dir!r}"
            )
        if k < 1:
            raise DGLError(f"k must be >= 1, got {k}")
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.dir = dir
        # Powers 1..k-1 per direction (power 0 is the identity term), so
        # using both directions needs twice as many graphs as using one.
        steps_per_dir = k - 1
        self.num_graphs = 2 * steps_per_dir if dir == "both" else steps_per_dir
        # One projection per diffusion graph plus a dedicated last one for the
        # identity term, matching the num_graphs + 1 merger weights below.
        self.project_fcs = nn.ModuleList(
            nn.Linear(self.in_feats, self.out_feats, bias=False)
            for _ in range(self.num_graphs + 1)
        )
        self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))
        self.in_graph_list = in_graph_list
        self.out_graph_list = out_graph_list

    @staticmethod
    def attach_graph(g, k):
        """Pre-compute the out- and in-direction diffusion graphs of ``g``.

        Parameters
        ----------
        g : DGLGraph
            Graph whose ``edata["weight"]`` holds one weight per edge.
        k : int
            Number of powers to build per direction.

        Returns
        -------
        (list of DGLGraph, list of DGLGraph)
            ``(out_graph_list, in_graph_list)``. ``out_graph_list[i]`` holds
            the ``(i + 1)``-th power of the weighted adjacency normalised by
            out-degree; ``in_graph_list[i]`` the same for the transposed
            adjacency normalised by in-degree. Each list has ``max(k, 1)``
            graphs on ``g.device`` with float32 ``edata["weight"]``.
        """
        device = g.device
        wadj, ind, outd = DiffConv.get_weight_matrix(g)
        # Normalise once; every further power reuses the same step matrix.
        out_step = DiffConv._normalize_by_degree(wadj, outd)
        in_step = DiffConv._normalize_by_degree(wadj.T, ind)
        out_graph_list = DiffConv._diffusion_graphs(out_step, k, device)
        in_graph_list = DiffConv._diffusion_graphs(in_step, k, device)
        return out_graph_list, in_graph_list

    @staticmethod
    def get_weight_matrix(g):
        """Return ``(weighted_adj, in_degrees, out_degrees)`` of ``g``.

        ``weighted_adj`` is ``g.adj(scipy_fmt="coo")`` with its data replaced
        by ``g.edata["weight"]``; the degrees are edge counts from
        ``g.in_degrees()`` / ``g.out_degrees()`` (not weighted sums).
        """
        adj = g.adj(scipy_fmt="coo")
        ind = g.in_degrees()
        outd = g.out_degrees()
        # NOTE(review): assumes the COO entries come back in edge-ID order so
        # that edge weights line up with them — confirm for the DGL version used.
        adj.data = g.edata["weight"].cpu().numpy().reshape(-1)
        return adj, ind, outd

    @staticmethod
    def diffuse(progress_g, weighted_adj, degree):
        """Return a graph holding ``adj(progress_g) @ normalise(weighted_adj)``.

        Kept for external callers; :meth:`attach_graph` computes the powers
        directly in scipy space.
        """
        progress_adj = progress_g.adj(scipy_fmt="coo")
        progress_adj.data = progress_g.edata["weight"].cpu().numpy()
        step = DiffConv._normalize_by_degree(weighted_adj, degree)
        return DiffConv._to_weighted_graph(progress_adj @ step, progress_g.device)

    @staticmethod
    def _normalize_by_degree(weighted_adj, degree):
        """Scale column ``j`` of ``weighted_adj`` by ``1 / degree[j]``.

        Same result as the broadcast ``weighted_adj / degree`` for positive
        degrees, but stays sparse and maps zero-degree columns to 0 instead
        of producing inf/NaN edge weights.
        NOTE(review): whether column scaling matches the paper's D^-1 W
        depends on the orientation returned by ``g.adj`` — confirm.
        """
        if torch.is_tensor(degree):
            degree = degree.cpu().numpy()
        degree = np.asarray(degree, dtype=np.float64)
        inv_degree = np.zeros_like(degree)
        positive = degree > 0
        inv_degree[positive] = 1.0 / degree[positive]
        return sparse.csr_matrix(weighted_adj) @ sparse.diags(inv_degree)

    @staticmethod
    def _to_weighted_graph(adj, device):
        """Build a DGL graph from sparse ``adj`` with float32 ``edata["weight"]``."""
        graph = dgl.from_scipy(sparse.coo_matrix(adj), eweight_name="weight").to(
            device
        )
        graph.edata["weight"] = graph.edata["weight"].float().to(device)
        return graph

    @staticmethod
    def _diffusion_graphs(step, k, device):
        """Return graphs for ``step ** 1 .. step ** max(k, 1)``."""
        # Powers are accumulated in float64 scipy space; only the stored graph
        # weights are cast to float32.
        power = sparse.csr_matrix(step)
        graphs = [DiffConv._to_weighted_graph(power, device)]
        for _ in range(k - 1):
            power = power @ step
            graphs.append(DiffConv._to_weighted_graph(power, device))
        return graphs

    def _selected_graphs(self):
        """Return the diffusion graphs used by :meth:`forward`, in order.

        The first ``k - 1`` graphs of each selected direction; for
        ``dir == "both"`` the in-graphs come before the out-graphs.

        Raises
        ------
        DGLError
            If a selected graph list holds fewer than ``k - 1`` graphs.
        """
        steps_per_dir = self.k - 1
        if self.dir == "in":
            sources = [self.in_graph_list]
        elif self.dir == "out":
            sources = [self.out_graph_list]
        else:  # "both" (validated in __init__)
            sources = [self.in_graph_list, self.out_graph_list]
        graphs = []
        for graph_list in sources:
            if len(graph_list) < steps_per_dir:
                raise DGLError(
                    f"DiffConv(k={self.k}) needs at least {steps_per_dir} "
                    f"diffusion graphs per direction, got {len(graph_list)}; "
                    "build them with DiffConv.attach_graph(g, k)"
                )
            graphs.extend(graph_list[:steps_per_dir])
        return graphs

    def forward(self, g, x):
        """Apply diffusion convolution to node features ``x``.

        Parameters
        ----------
        g : DGLGraph
            Ignored; kept for interface compatibility. Propagation uses the
            graphs given to the constructor.
        x : torch.Tensor
            Node features, presumably ``(N, in_feats)`` with ``N`` the node
            count of the attached graphs (TODO confirm for other ranks).

        Returns
        -------
        torch.Tensor
            Shape ``(M, out_feats)``; ``M == N`` for 2-D ``x``.
        """
        terms = []
        for graph, project in zip(self._selected_graphs(), self.project_fcs[:-1]):
            # local_scope keeps the temporary "n"/"feat" fields off the
            # shared, pre-computed graph.
            with graph.local_scope():
                graph.ndata["n"] = project(x)
                graph.update_all(
                    fn.u_mul_e("n", "weight", "e"), fn.sum("e", "feat")
                )
                terms.append(graph.ndata["feat"])
        # Zero-hop (identity) term with its own projection.
        terms.append(self.project_fcs[-1](x))
        stacked = torch.stack(terms).view(len(terms), -1, self.out_feats)
        # Scale each term by its learnable weight, then average over terms.
        return (self.merger.view(-1, 1, 1) * stacked).mean(0)