"vscode:/vscode.git/clone" did not exist on "fab7a9500360498d703f6068cb59497ba55fce62"
models.py 1.2 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl.function as fn
KounianhuaDu's avatar
KounianhuaDu committed
2
3
4
import torch as th
import torch.nn as nn
import torch.nn.functional as F
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
5

KounianhuaDu's avatar
KounianhuaDu committed
6

7
class Layer(nn.Module):
    """One message-passing layer: mean-aggregate, concatenate, project.

    Each node averages the messages arriving over its in-edges (``fn.mean``).
    The result is concatenated with the node's own input feature and fed
    through a single linear map. Because of that concatenation, the linear
    map takes ``2 * in_dim`` inputs.

    Args:
        in_dim: size of the last dimension of the input node features.
        out_dim: size of the last dimension of the produced node features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        concat_dim = 2 * in_dim  # [aggregated neighbours | self feature]
        # Attribute name kept as ``layer`` so parameter / state_dict keys
        # ("layer.weight", "layer.bias") stay stable.
        self.layer = nn.Linear(concat_dim, out_dim, bias=True)

    def forward(self, graph, feat, eweight=None):
        """Run one round of message passing on ``graph``.

        Args:
            graph: DGL graph whose nodes carry ``feat``.
            feat: node feature tensor; its last dimension must equal ``in_dim``.
            eweight: optional per-edge weights. When given, each source feature
                is multiplied by its edge weight before averaging. Presumably
                one scalar per edge, e.g. shape (E,) or (E, 1); TODO confirm
                against callers.

        Returns:
            Tensor produced by the linear map, with ``out_dim`` as its last
            dimension.
        """
        # local_scope: the "h"/"ew" fields written below are discarded on exit,
        # so the caller's graph is not mutated.
        with graph.local_scope():
            graph.ndata["h"] = feat
            if eweight is not None:
                graph.edata["ew"] = eweight
                message = fn.u_mul_e("h", "ew", "m")
            else:
                message = fn.copy_u("h", "m")
            # Aggregated result overwrites ndata["h"].
            graph.update_all(message, fn.mean("m", "h"))
            neighbour_mean = graph.ndata["h"]
            return self.layer(th.cat([neighbour_mean, feat], dim=-1))
KounianhuaDu's avatar
KounianhuaDu committed
25

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
26

27
28
29
30
31
32
class Model(nn.Module):
    """Stack of three ``Layer`` blocks with ReLU between consecutive layers.

    No activation is applied after the last layer, so the output is the raw
    result of ``out_layer`` (presumably logits; confirm with the caller).

    Args:
        in_dim: size of the input node features.
        out_dim: size of the output node representation.
        hid_dim: width of the two hidden representations (default 40).
    """

    def __init__(self, in_dim, out_dim, hid_dim=40):
        super().__init__()
        # Attribute names and registration order (in -> hid -> out) are kept
        # so parameter names stay stable.
        self.in_layer = Layer(in_dim, hid_dim)
        self.hid_layer = Layer(hid_dim, hid_dim)
        self.out_layer = Layer(hid_dim, out_dim)

    def forward(self, graph, feat, eweight=None):
        """Apply the three layers to ``graph``.

        Args:
            graph: DGL graph passed unchanged to every layer.
            feat: input node features; cast to float32 via ``.float()`` first.
            eweight: optional per-edge weights shared by all three layers.

        Returns:
            Output of ``out_layer`` with no final activation.
        """
        hidden = feat.float()
        # The first two layers are each followed by a ReLU.
        for stage in (self.in_layer, self.hid_layer):
            hidden = F.relu(stage(graph, hidden, eweight))
        return self.out_layer(graph, hidden, eweight)