import torch as th
import torch.nn as nn
from torch.nn import LSTM

from dgl.nn import GATConv


class GeniePathConv(nn.Module):
    """One GeniePath layer: a GAT "breadth" step followed by an LSTM "depth" step.

    Parameters
    ----------
    in_dim : int
        Size of the input node features.
    hid_dim : int
        Per-head output size of the breadth (GAT) step.
    out_dim : int
        Hidden size of the depth (LSTM) step.
    num_heads : int
        Number of attention heads in the GAT step (default 1).
    residual : bool
        Whether the GAT step uses a residual connection (default False).
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_heads=1, residual=False):
        super(GeniePathConv, self).__init__()
        # Breadth function: multi-head graph attention over neighbours.
        self.breadth_func = GATConv(
            in_dim, hid_dim, num_heads=num_heads, residual=residual
        )
        # Depth function: LSTM consuming one "layer step" per call.
        self.depth_func = LSTM(hid_dim, out_dim)

    def forward(self, graph, x, h, c):
        """Apply breadth then depth; returns (features, (h, c))."""
        # GATConv emits one output per head: (N, num_heads, hid_dim).
        attn_out = th.tanh(self.breadth_func(graph, x))
        # Collapse the head dimension by averaging.
        pooled = attn_out.mean(dim=1)
        # Treat the node batch as the LSTM batch with a length-1 sequence.
        out, (h, c) = self.depth_func(pooled.unsqueeze(0), (h, c))
        return out.squeeze(0), (h, c)


class GeniePath(nn.Module):
    """GeniePath: a stack of GeniePathConv layers sharing one LSTM state.

    Parameters
    ----------
    in_dim : int
        Input node-feature size.
    out_dim : int
        Output size of the final linear projection.
    hid_dim : int
        Hidden size used by every intermediate layer (default 16).
    num_layers : int
        Number of stacked GeniePathConv layers (default 2).
    num_heads : int
        Attention heads per GATConv (default 1).
    residual : bool
        Whether GATConv uses residual connections (default False).
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hid_dim=16,
        num_layers=2,
        num_heads=1,
        residual=False,
    ):
        super(GeniePath, self).__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        # All intermediate layers map hid_dim -> hid_dim so the shared
        # LSTM state (h, c) keeps a constant shape across the stack.
        self.layers = nn.ModuleList(
            GeniePathConv(
                hid_dim,
                hid_dim,
                hid_dim,
                num_heads=num_heads,
                residual=residual,
            )
            for _ in range(num_layers)
        )

    def forward(self, graph, x):
        """Run all layers over ``graph``; x is (N, in_dim) -> (N, out_dim)."""
        # Allocate the LSTM state directly on the input's device instead of
        # creating it on CPU and copying with .to() (saves one transfer).
        h = th.zeros(1, x.shape[0], self.hid_dim, device=x.device)
        c = th.zeros(1, x.shape[0], self.hid_dim, device=x.device)

        x = self.linear1(x)
        for layer in self.layers:
            x, (h, c) = layer(graph, x, h, c)
        x = self.linear2(x)

        return x


class GeniePathLazy(nn.Module):
    """Lazy GeniePath variant: all breadth (GAT) steps run on the first-layer
    features, then their cached outputs are fed through the depth LSTMs.

    Parameters
    ----------
    in_dim : int
        Input node-feature size.
    out_dim : int
        Output size of the final linear projection.
    hid_dim : int
        Hidden size of every breadth/depth step (default 16).
    num_layers : int
        Number of breadth/depth pairs (default 2).
    num_heads : int
        Attention heads per GATConv (default 1).
    residual : bool
        Whether GATConv uses residual connections (default False).
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hid_dim=16,
        num_layers=2,
        num_heads=1,
        residual=False,
    ):
        super(GeniePathLazy, self).__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        # nn.Linear for consistency with the rest of the file
        # (was th.nn.Linear — same class, inconsistent spelling).
        self.linear2 = nn.Linear(hid_dim, out_dim)
        self.breaths = nn.ModuleList()
        self.depths = nn.ModuleList()
        for i in range(num_layers):
            self.breaths.append(
                GATConv(
                    hid_dim, hid_dim, num_heads=num_heads, residual=residual
                )
            )
            # Each depth LSTM consumes [breadth output ‖ previous depth output],
            # hence the doubled input size.
            self.depths.append(LSTM(hid_dim * 2, hid_dim))

    def forward(self, graph, x):
        """Run all breadth steps, then all depth steps; (N, in_dim) -> (N, out_dim)."""
        # Allocate the LSTM state directly on the input's device instead of
        # creating it on CPU and copying with .to() (saves one transfer).
        h = th.zeros(1, x.shape[0], self.hid_dim, device=x.device)
        c = th.zeros(1, x.shape[0], self.hid_dim, device=x.device)

        x = self.linear1(x)
        h_tmps = []
        for layer in self.breaths:
            # GATConv output is (N, num_heads, hid_dim); average the heads.
            h_tmps.append(th.mean(th.tanh(layer(graph, x)), dim=1))
        x = x.unsqueeze(0)
        for h_tmp, layer in zip(h_tmps, self.depths):
            in_cat = th.cat((h_tmp.unsqueeze(0), x), -1)
            x, (h, c) = layer(in_cat, (h, c))
        x = self.linear2(x[0])

        return x