Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fb214b30
You need to sign in or sign up before continuing.
Commit
fb214b30
authored
Jun 18, 2018
by
Minjie Wang
Browse files
Add mgcn example for discussion
parent
2455934a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
81 additions
and
0 deletions
+81
-0
examples/pytorch/mgcn.py
examples/pytorch/mgcn.py
+81
-0
No files found.
examples/pytorch/mgcn.py
0 → 100644
View file @
fb214b30
"""Molecular GCN model proposed by Kearnes et al. (2016).
We use the description from "Neural Message Passing for Quantum Chemistry" Sec.2.
The model has an edge representation e_vw that is updated during message passing.
The message function is:
- M(h_v, h_w, e_vw) = e_vw
The update function is:
- U_v(h_v, m_v) = Affine(Affine(h_v) || m_v)
The edge update function is:
- U_e(e_vw, h_v, h_w) = Affine(ReLU(W_e || e_vw) || Affine(h_v || h_w))
"""
import
torch
as
T
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
class
NodeUpdateModule
(
nn
.
Module
):
def
__init__
(
self
,
hv_dims
):
self
.
net1
=
nn
.
Sequential
(
nn
.
Linear
(
hv_dims
),
nn
.
ReLU
()
)
self
.
net2
=
nn
.
Sequential
(
nn
.
Linear
(
hv_dims
),
nn
.
ReLU
()
)
def
forward
(
self
,
node
,
msgs
):
m
=
T
.
stack
(
msgs
).
mean
(
0
)
new_h
=
self
.
net2
(
T
.
cat
(
self
.
net1
(
node
[
'hv'
]),
m
))
return
{
'hv'
:
new_h
}
class
MessageModule
(
nn
.
Module
):
def
__init__
(
self
):
pass
def
forward
(
self
,
src
,
dst
,
edge
):
return
edge
[
'he'
]
class
EdgeUpdateModule
(
nn
.
Module
):
def
__init__
(
self
,
he_dims
):
self
.
net1
=
nn
.
Sequential
(
nn
.
Linear
(
he_dims
),
nn
.
ReLU
()
)
self
.
net2
=
nn
.
Sequential
(
nn
.
Linear
(
he_dims
),
nn
.
ReLU
()
)
self
.
net3
=
nn
.
Sequential
(
nn
.
Linear
(
he_dims
),
nn
.
ReLU
()
)
def
forward
(
self
,
src
,
dst
,
edge
):
new_he
=
self
.
net1
(
src
[
'hv'
])
+
self
.
net2
(
dst
[
'hv'
])
+
self
.
net3
(
edge
[
'he'
])
return
{
'he'
:
new_he
}
class
EdgeModule
(
nn
.
Module
):
def
__init__
(
self
,
he_dims
):
# use a flag to trigger either message module or edge update module.
self
.
is_msg
=
True
self
.
msg_mod
=
MessageModule
()
self
.
upd_mod
=
EdgeUpdateModule
()
def
forward
(
self
,
src
,
dst
,
edge
):
if
self
.
is_msg
:
self
.
is_msg
=
not
self
.
is_msg
return
self
.
msg_mod
(
src
,
dst
,
edge
)
else
:
self
.
is_msg
=
not
self
.
is_msg
return
self
.
upd_mod
(
src
,
dst
,
edge
)
def
train
(
g
):
# TODO(minjie): finish the complete training algorithm.
g
=
dgl
.
DGLGraph
(
g
)
g
.
register_message_func
(
EdgeModule
())
g
.
register_update_func
(
NodeUpdateModule
())
# TODO(minjie): init hv and he
num_iter
=
10
for
i
in
range
(
num_iter
):
# The first call triggers message function and update all the nodes.
g
.
update_all
()
# The second sendall updates all the edge features.
g
.
send_all
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment