# model.py: GeniePath and GeniePathLazy models built from DGL GATConv ("breadth") and LSTM ("depth") layers.
import torch as th
import torch.nn as nn
from dgl.nn import GATConv
from torch.nn import LSTM


class GeniePathConv(nn.Module):
    """One GeniePath layer: a GATConv "breadth" step followed by an LSTM "depth" step.

    Args:
        in_dim: Input feature size of the GATConv layer.
        hid_dim: Per-head output size of the GATConv layer; also the LSTM input size.
        out_dim: LSTM hidden size, i.e. the size of the returned features and of ``h``/``c``.
        num_heads: Number of GAT attention heads; their outputs are averaged.
        residual: Forwarded unchanged to ``GATConv``.
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_heads=1, residual=False):
        super().__init__()
        self.breadth_func = GATConv(in_dim, hid_dim, num_heads=num_heads, residual=residual)
        self.depth_func = LSTM(hid_dim, out_dim)

    def forward(self, graph, x, h, c):
        """Run one breadth step and one depth step.

        Args:
            graph: Graph passed straight to ``GATConv``.
            x: Node features; the LSTM batch dimension is taken from them, so
                presumably ``(N, in_dim)`` — confirm against callers.
            h: LSTM hidden state, ``(1, N, out_dim)`` for this single-layer LSTM.
            c: LSTM cell state, same shape as ``h``.

        Returns:
            ``(x, (h, c))``: ``x`` is ``(N, out_dim)``. ``(h, c)`` is the updated
            LSTM state, meant to be fed to the next layer.
        """
        # Breadth: GATConv output is assumed to be (N, num_heads, hid_dim)
        # (per DGL docs); tanh, then average over the heads axis (dim 1).
        x = th.tanh(self.breadth_func(graph, x))
        x = th.mean(x, dim=1)
        # Depth: a single LSTM time step. nn.LSTM defaults to batch_first=False,
        # so unsqueeze(0) gives (seq_len=1, N, hid_dim), with nodes as the batch.
        x, (h, c) = self.depth_func(x.unsqueeze(0), (h, c))
        # Drop the length-1 sequence axis.
        return x[0], (h, c)


class GeniePath(nn.Module):
    """GeniePath: linear input projection, stacked GeniePathConv layers, linear readout.

    A single LSTM state ``(h, c)`` is threaded through all layers, so each
    layer's depth step continues from the state left by the previous layer.

    Args:
        in_dim: Input node-feature size.
        out_dim: Output size of the final linear layer.
        hid_dim: Hidden size used by the projection and by every layer.
        num_layers: Number of stacked GeniePathConv layers.
        num_heads: GAT heads per layer.
        residual: Forwarded to each layer's GATConv.
    """

    def __init__(self, in_dim, out_dim, hid_dim=16, num_layers=2, num_heads=1, residual=False):
        super().__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        self.layers = nn.ModuleList(
            GeniePathConv(hid_dim, hid_dim, hid_dim, num_heads=num_heads, residual=residual)
            for _ in range(num_layers)
        )

    def forward(self, graph, x):
        """Return per-node outputs ``(N, out_dim)`` for node features ``x`` of shape ``(N, in_dim)``."""
        x = self.linear1(x)
        # Zero initial LSTM state, created from the projected features so that it
        # matches their device AND dtype. A hard-coded float32 state breaks models
        # converted with e.g. .double().
        h = x.new_zeros(1, x.shape[0], self.hid_dim)
        c = x.new_zeros(1, x.shape[0], self.hid_dim)
        for layer in self.layers:
            x, (h, c) = layer(graph, x, h, c)
        return self.linear2(x)


class GeniePathLazy(nn.Module):
    """GeniePath-Lazy: run all breadth (GATConv) steps first, then all depth (LSTM) steps.

    Every GATConv layer reads the same projected features. The LSTM stack then
    consumes, layer by layer, that layer's breadth output concatenated with the
    running LSTM output.

    Args:
        in_dim: Input node-feature size.
        out_dim: Output size of the final linear layer.
        hid_dim: Hidden size used by the projection, the GATConv layers and the LSTMs.
        num_layers: Number of breadth/depth layer pairs.
        num_heads: GAT heads per breadth layer.
        residual: Forwarded to each GATConv.
    """

    def __init__(self, in_dim, out_dim, hid_dim=16, num_layers=2, num_heads=1, residual=False):
        super().__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        # NOTE: "breaths" (sic) is kept as-is because it names the state_dict
        # keys of existing checkpoints.
        self.breaths = nn.ModuleList()
        self.depths = nn.ModuleList()
        for _ in range(num_layers):
            self.breaths.append(GATConv(hid_dim, hid_dim, num_heads=num_heads, residual=residual))
            # Input is [breadth output, previous LSTM output] concatenated -> 2 * hid_dim.
            self.depths.append(LSTM(hid_dim * 2, hid_dim))

    def forward(self, graph, x):
        """Return per-node outputs ``(N, out_dim)`` for node features ``x`` of shape ``(N, in_dim)``."""
        x = self.linear1(x)
        # Zero initial LSTM state matching the projected features' device and dtype.
        h = x.new_zeros(1, x.shape[0], self.hid_dim)
        c = x.new_zeros(1, x.shape[0], self.hid_dim)

        # Breadth pass: each GATConv sees the same projected features. Its output
        # is assumed to be (N, num_heads, hid_dim); heads (dim 1) are averaged after tanh.
        h_tmps = [th.mean(th.tanh(breadth(graph, x)), dim=1) for breadth in self.breaths]

        # Depth pass: one LSTM time step per layer, nodes as the batch dimension
        # (nn.LSTM defaults to batch_first=False, so tensors are (1, N, features)).
        x = x.unsqueeze(0)
        for h_tmp, depth in zip(h_tmps, self.depths):
            in_cat = th.cat((h_tmp.unsqueeze(0), x), -1)
            x, (h, c) = depth(in_cat, (h, c))
        return self.linear2(x[0])