import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class Layer(nn.Module):
    """Mean-aggregation message-passing layer with optional edge weights."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # The linear layer takes the aggregated neighbor representation
        # concatenated with the node's own features, hence in_dim * 2.
        self.layer = nn.Linear(in_dim * 2, out_dim, bias=True)

    def forward(self, graph, feat, eweight=None):
        with graph.local_scope():
            graph.ndata['h'] = feat

            # Mean-aggregate neighbor features; if edge weights are provided,
            # scale each message by its weight before averaging.
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
            else:
                graph.edata['ew'] = eweight
                graph.update_all(fn.u_mul_e('h', 'ew', 'm'), fn.mean('m', 'h'))

            # Concatenate the aggregated neighborhood representation with the
            # node's own input features (a GraphSAGE-style skip connection).
            h = self.layer(th.cat([graph.ndata['h'], feat], dim=-1))
            return h


class Model(nn.Module):
    """A three-layer GNN: stacked Layer blocks with ReLU in between."""
    def __init__(self, in_dim, out_dim, hid_dim=40):
        super().__init__()
        self.in_layer = Layer(in_dim, hid_dim)
        self.hid_layer = Layer(hid_dim, hid_dim)
        self.out_layer = Layer(hid_dim, out_dim)

    def forward(self, graph, feat, eweight=None):
        h = self.in_layer(graph, feat.float(), eweight)
        h = F.relu(h)
        h = self.hid_layer(graph, h, eweight)
        h = F.relu(h)
        h = self.out_layer(graph, h, eweight)
        return h
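

# Minimal usage sketch (not part of the original file). The toy graph, the
# feature/output dimensions, and the random edge weights below are
# illustrative assumptions, not values from the original repo.
if __name__ == "__main__":
    import dgl

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3 nodes in a directed cycle
    feat = th.randn(g.num_nodes(), 5)       # assumed 5-dim node features

    # A single Layer maps 5-dim inputs to 8-dim outputs.
    layer = Layer(in_dim=5, out_dim=8)
    print(layer(g, feat).shape)             # torch.Size([3, 8])

    # The full model, without and with hypothetical per-edge weights.
    model = Model(in_dim=5, out_dim=2)
    out = model(g, feat)                    # unweighted mean aggregation
    ew = th.rand(g.num_edges(), 1)          # edge weights shaped (E, 1)
    out_w = model(g, feat, eweight=ew)      # edge-weighted aggregation
    print(out.shape, out_w.shape)           # both torch.Size([3, 2])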