# dcrnn.py: diffusion convolution (DiffConv) layer for DCRNN, implemented with DGL.
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.base import DGLError


class DiffConv(nn.Module):
    """Diffusion convolution layer from the DCRNN paper.

    The layer projects the node features with one bias-free linear map per
    diffusion graph, propagates each projection over its pre-computed
    weighted graph (see :meth:`attach_graph`), appends a zero-step term
    computed from the raw input, and merges all terms with a learnable
    weight vector.  It can be used for traffic prediction or epidemic
    modelling.

    Parameters
    ----------
    in_feats : int
        Number of input features.
    out_feats : int
        Number of output features.
    k : int
        Number of diffusion steps, counting the zero-step term.  The layer
        uses the first ``k - 1`` graphs of every active direction, so ``k``
        must be at least 2.
    in_graph_list : list
        Diffusion graphs for the "in" direction, e.g. the second value
        returned by :meth:`attach_graph`.
    out_graph_list : list
        Diffusion graphs for the "out" direction, e.g. the first value
        returned by :meth:`attach_graph`.
    dir : str, one of ``"both"``, ``"in"``, ``"out"``
        Direction(s) of the diffusion convolution.  Defaults to ``"both"``
        as in the paper.

    Raises
    ------
    DGLError
        If ``dir`` is not one of the supported values or ``k`` is below 2.
    """

    _DIRECTIONS = ("both", "in", "out")

    def __init__(
        self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir="both"
    ):
        super().__init__()
        # Fail early: an unknown direction used to surface only in forward()
        # as an UnboundLocalError.
        if dir not in self._DIRECTIONS:
            raise DGLError(
                "Invalid dir value. Must be one of {}, but got {!r}.".format(
                    self._DIRECTIONS, dir
                )
            )
        # With k < 2 there is no projection layer at all, and the zero-step
        # term in forward() needs one.
        if k < 2:
            raise DGLError(
                "k must be at least 2, but got {!r}.".format(k)
            )
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.dir = dir
        # k - 1 diffusion graphs per active direction.  The original
        # condition was inverted (k - 1 for "both", 2k - 2 for one direction),
        # which indexed past the end of a single-direction graph list for
        # k > 2 and never touched the out-direction graphs for "both".
        num_directions = 2 if self.dir == "both" else 1
        self.num_graphs = num_directions * (self.k - 1)
        self.project_fcs = nn.ModuleList(
            [
                nn.Linear(self.in_feats, self.out_feats, bias=False)
                for _ in range(self.num_graphs)
            ]
        )
        # One merge weight per diffusion graph plus one for the zero-step term.
        self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))
        self.in_graph_list = in_graph_list
        self.out_graph_list = out_graph_list

    @staticmethod
    def _graph_from_adj(adj, device):
        """Build a DGL graph on ``device`` with float edge weights from ``adj``."""
        graph = dgl.from_scipy(
            sparse.coo_matrix(adj), eweight_name="weight"
        ).to(device)
        graph.edata["weight"] = graph.edata["weight"].float().to(device)
        return graph

    @staticmethod
    def attach_graph(g, k):
        """Pre-compute ``k`` diffusion graphs for each direction of ``g``.

        Parameters
        ----------
        g : DGLGraph
            Graph carrying the edge feature ``"weight"``.
        k : int
            Number of graphs to build per direction.

        Returns
        -------
        tuple of (list, list)
            ``(out_graph_list, in_graph_list)``.  Each list starts with the
            degree-normalised weighted graph and is extended ``k - 1`` times
            via :meth:`diffuse`.
        """
        device = g.device
        wadj, ind, outd = DiffConv.get_weight_matrix(g)

        # NOTE(review): the degree vector is broadcast against the matrix by
        # NumPy/SciPy rules; confirm this matches the intended normalisation
        # and that no node has degree zero.
        out_graph_list = [
            DiffConv._graph_from_adj(wadj / outd.cpu().numpy(), device)
        ]
        for _ in range(k - 1):
            out_graph_list.append(
                DiffConv.diffuse(out_graph_list[-1], wadj, outd)
            )

        # Same construction on the transposed weights with in-degrees.
        in_graph_list = [
            DiffConv._graph_from_adj(wadj.T / ind.cpu().numpy(), device)
        ]
        for _ in range(k - 1):
            in_graph_list.append(
                DiffConv.diffuse(in_graph_list[-1], wadj.T, ind)
            )
        return out_graph_list, in_graph_list

    @staticmethod
    def get_weight_matrix(g):
        """Return the weighted adjacency of ``g`` and its degree vectors.

        Returns
        -------
        tuple
            ``(adj, ind, outd)``: the SciPy COO adjacency of ``g`` whose data
            array is replaced by ``g.edata["weight"]``, the in-degrees and the
            out-degrees of ``g``.
        """
        adj = g.adj(scipy_fmt="coo")
        ind = g.in_degrees()
        outd = g.out_degrees()
        weight = g.edata["weight"]
        # Assumes the COO entries follow edge-ID order so that the weights
        # line up with the stored entries -- TODO confirm.
        adj.data = weight.cpu().numpy()
        return adj, ind, outd

    @staticmethod
    def diffuse(progress_g, weighted_adj, degree):
        """Advance a diffusion graph by one more step.

        Multiplies the weighted adjacency of ``progress_g`` with
        ``weighted_adj / degree`` and returns the product as a new weighted
        graph on the same device as ``progress_g``.
        """
        device = progress_g.device
        progress_adj = progress_g.adj(scipy_fmt="coo")
        progress_adj.data = progress_g.edata["weight"].cpu().numpy()
        step_adj = progress_adj @ (weighted_adj / degree.cpu().numpy())
        return DiffConv._graph_from_adj(step_adj, device)

    def _active_graphs(self):
        """Return the first ``k - 1`` graphs of every active direction."""
        steps = self.k - 1
        if self.dir == "both":
            graphs = list(self.in_graph_list[:steps]) + list(
                self.out_graph_list[:steps]
            )
        elif self.dir == "in":
            graphs = list(self.in_graph_list[:steps])
        else:  # "out"; other values are rejected in __init__
            graphs = list(self.out_graph_list[:steps])
        if len(graphs) != self.num_graphs:
            raise DGLError(
                "Expected {} diffusion graphs for dir={!r} and k={}, "
                "but got {}.".format(
                    self.num_graphs, self.dir, self.k, len(graphs)
                )
            )
        return graphs

    def forward(self, g, x):
        """Apply the diffusion convolution to the node features ``x``.

        Parameters
        ----------
        g : DGLGraph
            Not used by the computation; the graphs passed to the constructor
            are used instead.  Kept so the call signature stays unchanged.
        x : torch.Tensor
            Node features whose last dimension is ``in_feats``.

        Returns
        -------
        torch.Tensor
            Mean over all terms of the merger-weighted features; the last
            dimension is ``out_feats``.

        Raises
        ------
        DGLError
            If the stored graph lists hold fewer graphs than required.
        """
        feat_list = []
        for graph, project in zip(self._active_graphs(), self.project_fcs):
            # local_scope keeps the temporary node/edge data off the stored
            # graphs once the block exits.
            with graph.local_scope():
                graph.ndata["n"] = project(x)
                graph.update_all(
                    fn.u_mul_e("n", "weight", "e"), fn.sum("e", "feat")
                )
                feat_list.append(graph.ndata["feat"])
        # Zero-step term.  NOTE(review): it reuses the projection of the last
        # diffusion graph instead of owning a separate one; kept as is.
        feat_list.append(self.project_fcs[-1](x))
        # Stack the terms as [num_terms, -1, out_feats].
        feat_list = torch.cat(feat_list).view(
            len(feat_list), -1, self.out_feats
        )
        # Scale every term by its merge weight, then average over the terms.
        ret = (
            (self.merger * feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
        )
        return ret