Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7a50be64
Commit
7a50be64
authored
Apr 23, 2018
by
zzhang-cn
Browse files
tensor/nn
parent
05547b37
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
74 additions
and
37 deletions
+74
-37
mx.py
mx.py
+74
-37
No files found.
mx.py
View file @
7a50be64
import
networkx
as
nx
import
networkx
as
nx
from
networkx.classes.graph
import
Graph
#from networkx.classes.graph import Graph
from
networkx.classes.digraph
import
DiGraph
import
torch
as
th
#import torch.nn.functional as F
import
torch.nn
as
nn
from
torch.autograd
import
Variable
as
Var
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
# TODO: loss functions and training
class
mx_Graph
(
Graph
):
class
mx_Graph
(
Di
Graph
):
def
__init__
(
self
,
*
args
,
**
kargs
):
def
__init__
(
self
,
*
args
,
**
kargs
):
super
(
mx_Graph
,
self
).
__init__
(
*
args
,
**
kargs
)
super
(
mx_Graph
,
self
).
__init__
(
*
args
,
**
kargs
)
self
.
set_msg_func
()
self
.
set_msg_func
()
self
.
set_gather_func
()
self
.
set_reduction_func
()
self
.
set_update_func
()
self
.
set_update_func
()
self
.
set_readout_func
()
self
.
set_readout_func
()
self
.
init_reprs
()
self
.
init_reprs
()
...
@@ -26,47 +32,58 @@ class mx_Graph(Graph):
...
@@ -26,47 +32,58 @@ class mx_Graph(Graph):
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
return
self
.
nodes
[
u
][
name
]
return
self
.
nodes
[
u
][
name
]
def
set_reduction_func
(
self
):
def
_default_reduction_func
(
x_s
):
out
=
th
.
stack
(
x_s
)
out
=
th
.
sum
(
out
,
dim
=
0
)
return
out
self
.
_reduction_func
=
_default_reduction_func
def
set_gather_func
(
self
,
u
=
None
):
pass
def
set_msg_func
(
self
,
func
=
None
,
u
=
None
):
def
set_msg_func
(
self
,
func
=
None
,
u
=
None
):
"""Function that gathers messages from neighbors"""
"""Function that gathers messages from neighbors"""
def
_default_msg_func
(
u
):
def
_default_msg_func
(
u
):
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
msg_gathered
=
0
msg_gathered
=
[]
for
v
in
self
.
adj
[
u
]:
for
v
in
self
.
pred
[
u
]:
x
=
self
.
get_repr
(
v
)
x
=
self
.
get_repr
(
v
)
if
x
is
not
None
:
if
x
is
not
None
:
msg_gathered
+=
x
msg_gathered
.
append
(
x
)
return
msg_gathered
return
self
.
_reduction_func
(
msg_gathered
)
# TODO: per node message function
# TODO: per node message function
# TODO: 'sum' should be a separate function
# TODO: 'sum' should be a separate function
if
func
==
None
:
if
func
==
None
:
self
.
msg_func
=
_default_msg_func
self
.
_
msg_func
=
_default_msg_func
else
:
else
:
self
.
msg_func
=
func
self
.
_
msg_func
=
func
def
set_update_func
(
self
,
func
=
None
,
u
=
None
):
def
set_update_func
(
self
,
func
=
None
,
u
=
None
):
"""
"""
Update function upon receiving an aggregate
Update function upon receiving an aggregate
message from a node's neighbor
message from a node's neighbor
"""
"""
def
_default_update_func
(
u
,
m
):
def
_default_update_func
(
x
,
m
):
h_new
=
self
.
nodes
[
u
][
'h'
]
+
m
return
x
+
m
self
.
set_repr
(
u
,
h_new
)
# TODO: per node update function
# TODO: per node update function
if
func
==
None
:
if
func
==
None
:
self
.
update_func
=
_default_update_func
self
.
_
update_func
=
_default_update_func
else
:
else
:
self
.
update_func
=
func
self
.
_
update_func
=
func
def
set_readout_func
(
self
,
func
=
None
):
def
set_readout_func
(
self
,
func
=
None
):
"""Readout function of the whole graph"""
"""Readout function of the whole graph"""
def
_default_readout_func
():
def
_default_readout_func
():
readout
=
0
valid_hs
=
[]
for
n
in
self
.
nodes
:
for
x
in
self
.
nodes
:
readout
+=
self
.
nodes
[
n
][
'h'
]
h
=
self
.
get_repr
(
x
)
return
readout
if
h
is
not
None
:
valid_hs
.
append
(
h
)
return
self
.
_reduction_func
(
valid_hs
)
#
if
func
==
None
:
if
func
==
None
:
self
.
readout_func
=
_default_readout_func
self
.
readout_func
=
_default_readout_func
else
:
else
:
...
@@ -78,15 +95,21 @@ class mx_Graph(Graph):
...
@@ -78,15 +95,21 @@ class mx_Graph(Graph):
def
update_to
(
self
,
u
):
def
update_to
(
self
,
u
):
"""Pull messages from 1-step away neighbors of u"""
"""Pull messages from 1-step away neighbors of u"""
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
m
=
self
.
msg_func
(
u
=
u
)
m
=
self
.
_msg_func
(
u
=
u
)
self
.
update_func
(
u
,
m
)
x
=
self
.
get_repr
(
u
)
# TODO: ugly hack
if
x
is
None
:
y
=
self
.
_update_func
(
m
)
else
:
y
=
self
.
_update_func
(
x
,
m
)
self
.
set_repr
(
u
,
y
)
def
update_from
(
self
,
u
):
def
update_from
(
self
,
u
):
"""Update u's 1-step away neighbors"""
"""Update u's 1-step away neighbors"""
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
# TODO: this asks v to pull from nodes other than
# TODO: this asks v to pull from nodes other than
# TODO: u, is this a good thing?
# TODO: u, is this a good thing?
for
v
in
self
.
adj
[
u
]:
for
v
in
self
.
succ
[
u
]:
self
.
update_to
(
v
)
self
.
update_to
(
v
)
def
print_all
(
self
):
def
print_all
(
self
):
...
@@ -95,25 +118,39 @@ class mx_Graph(Graph):
...
@@ -95,25 +118,39 @@ class mx_Graph(Graph):
print
()
print
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
th
.
random
.
manual_seed
(
0
)
''': this makes a digraph with double edges
tg = nx.path_graph(10)
tg = nx.path_graph(10)
g = mx_Graph(tg)
g = mx_Graph(tg)
g.print_all()
g.print_all()
tr
=
nx
.
balanced_tree
(
2
,
3
)
# this makes a uni-edge tree
tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
m_tr = mx_Graph(tr)
m_tr = mx_Graph(tr)
m_tr.print_all()
m_tr.print_all()
'''
print
(
"testing GRU update"
)
g
=
mx_Graph
(
nx
.
path_graph
(
3
))
g
=
mx_Graph
(
nx
.
path_graph
(
3
))
g
.
set_update_func
(
nn
.
GRUCell
(
4
,
4
))
for
n
in
g
:
for
n
in
g
:
g
.
set_repr
(
n
,
int
(
n
)
+
10
)
g
.
set_repr
(
n
,
Var
(
th
.
rand
(
2
,
4
)))
g
.
print_all
()
print
(
g
.
readout
())
print
(
"
\t
**before:"
);
g
.
print_all
()
g
.
update_from
(
0
)
print
(
"before update:
\t
"
,
g
.
nodes
[
0
])
g
.
update_from
(
1
)
g
.
update_to
(
0
)
print
(
"
\t
**after:"
);
g
.
print_all
()
print
(
'after update:
\t
'
,
g
.
nodes
[
0
])
g
.
print_all
()
print
(
"
\n
testing fwd update"
)
g
.
clear
()
print
(
g
.
readout
())
g
.
add_path
([
0
,
1
,
2
])
g
.
init_reprs
()
fwd_net
=
nn
.
Sequential
(
nn
.
Linear
(
4
,
4
),
nn
.
ReLU
())
g
.
set_update_func
(
fwd_net
)
g
.
set_repr
(
0
,
Var
(
th
.
rand
(
2
,
4
)))
print
(
"
\t
**before:"
);
g
.
print_all
()
g
.
update_from
(
0
)
g
.
update_from
(
1
)
print
(
"
\t
**after:"
);
g
.
print_all
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment