"examples/pytorch/rgcn-hetero/entity_classify.py" did not exist on "65e1ba4f604e7ad4764aaec573bb7638da1ac333"
mgcn.py 2.6 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Molecular GCN model proposed by Kearnes et al. (2016).
We use the description from "Neural Message Passing for Quantum Chemistry" (Gilmer et al., 2017), Sec. 2.
The model has an edge representation e_vw that is updated during message passing.
The message function is:
    - M(h_v, h_w, e_vw) = e_vw
The update function is:
    - U_v(h_v, m_v) = Affine(Affine(h_v) || m_v)
The edge update function is:
    - U_e(e_vw, h_v, h_w) = Affine(ReLU(W_e || e_vw) || Affine(h_v || h_w))
Here || denotes concatenation and Affine is a learned linear layer.
"""
import torch as T
import torch.nn as nn
import torch.nn.functional as F

import dgl

class NodeUpdateModule(nn.Module):
    """Node update U_v(h_v, m_v) = ReLU(Affine(ReLU(Affine(h_v)) || m_v)).

    Args:
        hv_dims: size of the node feature ``hv``; also the output size.
        he_dims: size of each incoming message (messages are edge features
            ``he``). Defaults to ``hv_dims`` so the original one-argument
            constructor keeps working.
    """
    def __init__(self, hv_dims, he_dims=None):
        # nn.Module must be initialised before any sub-module is assigned.
        super(NodeUpdateModule, self).__init__()
        if he_dims is None:
            he_dims = hv_dims
        self.net1 = nn.Sequential(
                nn.Linear(hv_dims, hv_dims),
                nn.ReLU()
                )
        # net2 sees the transformed node feature concatenated with the
        # aggregated message, hence the hv_dims + he_dims input size.
        self.net2 = nn.Sequential(
                nn.Linear(hv_dims + he_dims, hv_dims),
                nn.ReLU()
                )

    def forward(self, node, msgs):
        """Return ``{'hv': new_h}`` computed from ``node['hv']`` and ``msgs``.

        ``msgs`` is a sequence of same-shaped message tensors. They are
        averaged element-wise; ``T.stack`` raises on an empty sequence.
        Concatenation is along the last dimension, so a plain feature vector
        or features with leading batch dimensions are both accepted.
        """
        # NOTE(review): messages are averaged here; confirm mean (rather than
        # sum) is the intended aggregation.
        m = T.stack(msgs).mean(0)
        new_h = self.net2(T.cat([self.net1(node['hv']), m], dim=-1))
        return {'hv': new_h}

class MessageModule(nn.Module):
    """Message function M(h_v, h_w, e_vw) = e_vw: pass the edge feature on."""
    def __init__(self):
        # Needed even without parameters: nn.Module.__call__ and sub-module
        # registration rely on state set up by nn.Module.__init__.
        super(MessageModule, self).__init__()

    def forward(self, src, dst, edge):
        """Return ``edge['he']`` unchanged; ``src`` and ``dst`` are unused."""
        return edge['he']

class EdgeUpdateModule(nn.Module):
    """Edge update as a sum of three ReLU(Affine(.)) terms.

    new_he = ReLU(A1(src_hv)) + ReLU(A2(dst_hv)) + ReLU(A3(he))

    NOTE(review): this additive form differs from the concatenation-based
    U_e given in the module docstring; confirm which one is intended.

    Args:
        he_dims: size of the edge feature ``he``; also the output size.
        hv_dims: size of the node feature ``hv``. Defaults to ``he_dims`` so
            the original one-argument constructor keeps working.
    """
    def __init__(self, he_dims, hv_dims=None):
        # nn.Module must be initialised before any sub-module is assigned.
        super(EdgeUpdateModule, self).__init__()
        if hv_dims is None:
            hv_dims = he_dims
        # net1 / net2 transform the source / destination node features.
        self.net1 = nn.Sequential(
                nn.Linear(hv_dims, he_dims),
                nn.ReLU()
                )
        self.net2 = nn.Sequential(
                nn.Linear(hv_dims, he_dims),
                nn.ReLU()
                )
        # net3 transforms the current edge feature.
        self.net3 = nn.Sequential(
                nn.Linear(he_dims, he_dims),
                nn.ReLU()
                )

    def forward(self, src, dst, edge):
        """Return ``{'he': new_he}`` from ``src['hv']``, ``dst['hv']``, ``edge['he']``."""
        new_he = self.net1(src['hv']) + self.net2(dst['hv']) + self.net3(edge['he'])
        return {'he': new_he}

class EdgeModule(nn.Module):
    """Edge function that alternates between messaging and edge updating.

    Calls 1, 3, 5, ... delegate to MessageModule (returning the message).
    Calls 2, 4, 6, ... delegate to EdgeUpdateModule (returning ``{'he': ...}``).

    Args:
        he_dims: size of the edge feature ``he``; forwarded to
            EdgeUpdateModule.
    """
    def __init__(self, he_dims):
        # nn.Module must be initialised before any sub-module is assigned.
        super(EdgeModule, self).__init__()
        # use a flag to trigger either message module or edge update module.
        self.is_msg = True
        self.msg_mod = MessageModule()
        self.upd_mod = EdgeUpdateModule(he_dims)

    def forward(self, src, dst, edge):
        # NOTE(review): the flag flips on *every* call. If the graph invokes
        # this function once per edge (rather than once per phase), message
        # and update calls will interleave across edges -- confirm against
        # how dgl calls registered message functions.
        use_msg = self.is_msg
        self.is_msg = not self.is_msg
        if use_msg:
            return self.msg_mod(src, dst, edge)
        return self.upd_mod(src, dst, edge)

def train(g, feat_dims=None, num_iter=10):
    """Skeleton training loop for the molecular GCN (not yet complete).

    Args:
        g: graph passed to ``dgl.DGLGraph``.
        feat_dims: feature size used for both the node feature ``hv`` and the
            edge feature ``he``. Required: the modules cannot be built
            without it.
        num_iter: number of message-passing rounds (default 10).

    Raises:
        ValueError: if ``feat_dims`` is not given.
    """
    # TODO(minjie): finish the complete training algorithm.
    if feat_dims is None:
        raise ValueError("train() requires feat_dims to build the modules")
    g = dgl.DGLGraph(g)
    g.register_message_func(EdgeModule(feat_dims))
    g.register_update_func(NodeUpdateModule(feat_dims))
    # TODO(minjie): init hv and he
    for _ in range(num_iter):
        # The first call triggers the message function and updates all nodes.
        g.update_all()
        # The second call (send_all) updates all the edge features.
        g.send_all()