Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
bce9dd6e
Commit
bce9dd6e
authored
Jun 19, 2018
by
zzhang-cn
Browse files
Merge branch 'master' of
https://github.com/zzhang-cn/dgl
parents
e7c51805
fb214b30
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
81 additions
and
0 deletions
+81
-0
examples/pytorch/mgcn.py
examples/pytorch/mgcn.py
+81
-0
No files found.
examples/pytorch/mgcn.py
0 → 100644
View file @
bce9dd6e
"""Molecular GCN model proposed by Kearnes et al. (2016).
We use the description from "Neural Message Passing for Quantum Chemistry" Sec.2.
The model has an edge representation e_vw that is updated during message passing.
The message function is:
- M(h_v, h_w, e_vw) = e_vw
The update function is:
- U_v(h_v, m_v) = Affine(Affine(h_v) || m_v)
The edge update function is:
- U_e(e_vw, h_v, h_w) = Affine(ReLU(W_e || e_vw) || Affine(h_v || h_w))
"""
import
torch
as
T
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
dgl
class
NodeUpdateModule
(
nn
.
Module
):
def
__init__
(
self
,
hv_dims
):
self
.
net1
=
nn
.
Sequential
(
nn
.
Linear
(
hv_dims
),
nn
.
ReLU
()
)
self
.
net2
=
nn
.
Sequential
(
nn
.
Linear
(
hv_dims
),
nn
.
ReLU
()
)
def
forward
(
self
,
node
,
msgs
):
m
=
T
.
stack
(
msgs
).
mean
(
0
)
new_h
=
self
.
net2
(
T
.
cat
(
self
.
net1
(
node
[
'hv'
]),
m
))
return
{
'hv'
:
new_h
}
class
MessageModule
(
nn
.
Module
):
def
__init__
(
self
):
pass
def
forward
(
self
,
src
,
dst
,
edge
):
return
edge
[
'he'
]
class
EdgeUpdateModule
(
nn
.
Module
):
def
__init__
(
self
,
he_dims
):
self
.
net1
=
nn
.
Sequential
(
nn
.
Linear
(
he_dims
),
nn
.
ReLU
()
)
self
.
net2
=
nn
.
Sequential
(
nn
.
Linear
(
he_dims
),
nn
.
ReLU
()
)
self
.
net3
=
nn
.
Sequential
(
nn
.
Linear
(
he_dims
),
nn
.
ReLU
()
)
def
forward
(
self
,
src
,
dst
,
edge
):
new_he
=
self
.
net1
(
src
[
'hv'
])
+
self
.
net2
(
dst
[
'hv'
])
+
self
.
net3
(
edge
[
'he'
])
return
{
'he'
:
new_he
}
class
EdgeModule
(
nn
.
Module
):
def
__init__
(
self
,
he_dims
):
# use a flag to trigger either message module or edge update module.
self
.
is_msg
=
True
self
.
msg_mod
=
MessageModule
()
self
.
upd_mod
=
EdgeUpdateModule
()
def
forward
(
self
,
src
,
dst
,
edge
):
if
self
.
is_msg
:
self
.
is_msg
=
not
self
.
is_msg
return
self
.
msg_mod
(
src
,
dst
,
edge
)
else
:
self
.
is_msg
=
not
self
.
is_msg
return
self
.
upd_mod
(
src
,
dst
,
edge
)
def
train
(
g
):
# TODO(minjie): finish the complete training algorithm.
g
=
dgl
.
DGLGraph
(
g
)
g
.
register_message_func
(
EdgeModule
())
g
.
register_update_func
(
NodeUpdateModule
())
# TODO(minjie): init hv and he
num_iter
=
10
for
i
in
range
(
num_iter
):
# The first call triggers message function and update all the nodes.
g
.
update_all
()
# The second sendall updates all the edge features.
g
.
send_all
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment