"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a7e9f85e21dde12f2f2489702ea82db40ebb31d2"
Commit bce9dd6e authored by zzhang-cn's avatar zzhang-cn
Browse files

Merge branch 'master' of https://github.com/zzhang-cn/dgl

parents e7c51805 fb214b30
"""Molecular GCN model proposed by Kearnes et al. (2016).
We use the description from "Neural Message Passing for Quantum Chemistry" Sec.2.
The model has an edge representation e_vw that is updated during message passing.
The message function is:
- M(h_v, h_w, e_vw) = e_vw
The update function is:
- U_v(h_v, m_v) = Affine(Affine(h_v) || m_v)
The edge update function is:
- U_e(e_vw, h_v, h_w) = Affine(ReLU(W_e || e_vw) || Affine(h_v || h_w))
"""
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import dgl
class NodeUpdateModule(nn.Module):
def __init__(self, hv_dims):
self.net1 = nn.Sequential(
nn.Linear(hv_dims),
nn.ReLU()
)
self.net2 = nn.Sequential(
nn.Linear(hv_dims),
nn.ReLU()
)
def forward(self, node, msgs):
m = T.stack(msgs).mean(0)
new_h = self.net2(T.cat(self.net1(node['hv']), m))
return {'hv' : new_h}
class MessageModule(nn.Module):
def __init__(self):
pass
def forward(self, src, dst, edge):
return edge['he']
class EdgeUpdateModule(nn.Module):
def __init__(self, he_dims):
self.net1 = nn.Sequential(
nn.Linear(he_dims),
nn.ReLU()
)
self.net2 = nn.Sequential(
nn.Linear(he_dims),
nn.ReLU()
)
self.net3 = nn.Sequential(
nn.Linear(he_dims),
nn.ReLU()
)
def forward(self, src, dst, edge):
new_he = self.net1(src['hv']) + self.net2(dst['hv']) + self.net3(edge['he'])
return {'he' : new_he}
class EdgeModule(nn.Module):
def __init__(self, he_dims):
# use a flag to trigger either message module or edge update module.
self.is_msg = True
self.msg_mod = MessageModule()
self.upd_mod = EdgeUpdateModule()
def forward(self, src, dst, edge):
if self.is_msg:
self.is_msg = not self.is_msg
return self.msg_mod(src, dst, edge)
else:
self.is_msg = not self.is_msg
return self.upd_mod(src, dst, edge)
def train(g):
# TODO(minjie): finish the complete training algorithm.
g = dgl.DGLGraph(g)
g.register_message_func(EdgeModule())
g.register_update_func(NodeUpdateModule())
# TODO(minjie): init hv and he
num_iter = 10
for i in range(num_iter):
# The first call triggers message function and update all the nodes.
g.update_all()
# The second sendall updates all the edge features.
g.send_all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment