Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
75a43e4a
Unverified
Commit
75a43e4a
authored
May 11, 2018
by
Zheng Zhang
Committed by
GitHub
May 11, 2018
Browse files
Merge pull request #7 from zzhang-cn/tensor
Tensor
parents
447d16bd
fa900595
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
613 additions
and
37 deletions
+613
-37
graph.py
graph.py
+151
-0
model.py
model.py
+260
-0
mx.py
mx.py
+74
-37
mx_scalar.py
mx_scalar.py
+128
-0
No files found.
graph.py
0 → 100644
View file @
75a43e4a
import
networkx
as
nx
import
torch
as
T
import
torch.nn
as
NN
from
util
import
*
class
DiGraph
(
NN
.
Module
):
'''
Reserved attributes:
* state: node state vectors during message passing iterations
edges does not have "state vectors"; the "state" field is reserved for storing messages
* tag: node-/edge-specific feature tensors or other data
'''
def
__init__
(
self
,
graph
):
NN
.
Module
.
__init__
(
self
)
self
.
G
=
graph
self
.
message_funcs
=
[]
self
.
update_funcs
=
[]
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
return
self
.
G
.
nodes
()
if
nodes
==
'all'
else
nodes
def
_edges_or_all
(
self
,
edges
=
'all'
):
return
self
.
G
.
edges
()
if
edges
==
'all'
else
edges
def
_node_tag_name
(
self
,
v
):
return
'(%s)'
%
v
def
_edge_tag_name
(
self
,
u
,
v
):
return
'(%s, %s)'
%
(
min
(
u
,
v
),
max
(
u
,
v
))
def
zero_node_state
(
self
,
state_dims
,
batch_size
=
None
,
nodes
=
'all'
):
shape
=
(
[
batch_size
]
+
list
(
state_dims
)
if
batch_size
is
not
None
else
state_dims
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
self
.
G
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
))
def
init_node_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
nodes
=
'all'
,
args
=
()):
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
self
.
G
.
node
[
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
register_parameter
(
self
.
_node_tag_name
(
v
),
self
.
G
.
node
[
v
][
'tag'
])
def
init_edge_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
edges
=
'all'
,
args
=
()):
edges
=
self
.
_edges_or_all
(
edges
)
for
u
,
v
in
edges
:
self
.
G
[
u
][
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
register_parameter
(
self
.
_edge_tag_name
(
u
,
v
),
self
.
G
[
u
][
v
][
'tag'
])
def
remove_node_tag
(
self
,
nodes
=
'all'
):
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
delattr
(
self
,
self
.
_node_tag_name
(
v
))
del
self
.
G
.
node
[
v
][
'tag'
]
def
remove_edge_tag
(
self
,
edges
=
'all'
):
edges
=
self
.
_edges_or_all
(
edges
)
for
u
,
v
in
edges
:
delattr
(
self
,
self
.
_edge_tag_name
(
u
,
v
))
del
self
.
G
[
u
][
v
][
'tag'
]
@
property
def
node
(
self
):
return
self
.
G
.
node
@
property
def
edges
(
self
):
return
self
.
G
.
edges
def
edge_tags
(
self
):
for
u
,
v
in
self
.
G
.
edges
():
yield
self
.
G
[
u
][
v
][
'tag'
]
def
node_tags
(
self
):
for
v
in
self
.
G
.
nodes
():
yield
self
.
G
.
node
[
v
][
'tag'
]
def
states
(
self
):
for
v
in
self
.
G
.
nodes
():
yield
self
.
G
.
node
[
v
][
'state'
]
def
named_edge_tags
(
self
):
for
u
,
v
in
self
.
G
.
edges
():
yield
((
u
,
v
),
self
.
G
[
u
][
v
][
'tag'
])
def
named_node_tags
(
self
):
for
v
in
self
.
G
.
nodes
():
yield
(
v
,
self
.
G
.
node
[
v
][
'tag'
])
def
named_states
(
self
):
for
v
in
self
.
G
.
nodes
():
yield
(
v
,
self
.
G
.
node
[
v
][
'state'
])
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batched
=
False
):
'''
batched: whether to do a single batched computation instead of iterating
message function: accepts source state tensor and edge tag tensor, and
returns a message tensor
'''
self
.
message_funcs
.
append
((
self
.
_edges_or_all
(
edges
),
message_func
,
batched
))
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batched
=
False
):
'''
batched: whether to do a single batched computation instead of iterating
update function: accepts a node attribute dictionary (including state and tag),
and a list of tuples (source node, target node, edge attribute dictionary)
'''
self
.
update_funcs
.
append
((
self
.
_nodes_or_all
(
nodes
),
update_func
,
batched
))
def
draw
(
self
):
from
networkx.drawing.nx_agraph
import
graphviz_layout
pos
=
graphviz_layout
(
self
.
G
,
prog
=
'dot'
)
nx
.
draw
(
self
.
G
,
pos
,
with_labels
=
True
)
def
step
(
self
):
# update message
for
ebunch
,
f
,
batched
in
self
.
message_funcs
:
if
batched
:
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
source
=
T
.
stack
([
self
.
G
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
edge_tags
=
[
self
.
G
[
u
][
v
].
get
(
'tag'
,
None
)
for
u
,
v
in
ebunch
]
if
all
(
t
is
None
for
t
in
edge_tags
):
edge_tag
=
None
else
:
edge_tag
=
T
.
stack
([
self
.
G
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
message
=
f
(
source
,
edge_tag
)
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
self
.
G
[
u
][
v
][
'state'
]
=
message
[
i
]
else
:
for
u
,
v
in
ebunch
:
self
.
G
[
u
][
v
][
'state'
]
=
f
(
self
.
G
.
node
[
u
][
'state'
],
self
.
G
[
u
][
v
][
'tag'
]
)
# update state
# TODO: does it make sense to batch update the nodes?
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
for
v
in
vbunch
:
self
.
G
.
node
[
v
][
'state'
]
=
f
(
self
.
G
.
node
[
v
],
list
(
self
.
G
.
in_edges
(
v
,
data
=
True
)))
model.py
0 → 100644
View file @
75a43e4a
import
torch
as
T
import
torch.nn
as
NN
import
torch.nn.init
as
INIT
import
torch.nn.functional
as
F
import
numpy
as
NP
import
numpy.random
as
RNG
from
util
import
*
from
glimpse
import
create_glimpse
from
zoneout
import
ZoneoutLSTMCell
from
collections
import
namedtuple
import
os
from
graph
import
DiGraph
import
networkx
as
nx
no_msg
=
os
.
getenv
(
'NOMSG'
,
False
)
def
build_cnn
(
**
config
):
cnn_list
=
[]
filters
=
config
[
'filters'
]
kernel_size
=
config
[
'kernel_size'
]
in_channels
=
config
.
get
(
'in_channels'
,
3
)
final_pool_size
=
config
[
'final_pool_size'
]
for
i
in
range
(
len
(
filters
)):
module
=
NN
.
Conv2d
(
in_channels
if
i
==
0
else
filters
[
i
-
1
],
filters
[
i
],
kernel_size
,
padding
=
tuple
((
_
-
1
)
//
2
for
_
in
kernel_size
),
)
INIT
.
xavier_uniform_
(
module
.
weight
)
INIT
.
constant_
(
module
.
bias
,
0
)
cnn_list
.
append
(
module
)
if
i
<
len
(
filters
)
-
1
:
cnn_list
.
append
(
NN
.
LeakyReLU
())
cnn_list
.
append
(
NN
.
AdaptiveMaxPool2d
(
final_pool_size
))
return
NN
.
Sequential
(
*
cnn_list
)
class
TreeGlimpsedClassifier
(
NN
.
Module
):
def
__init__
(
self
,
n_children
=
2
,
n_depth
=
3
,
h_dims
=
128
,
node_tag_dims
=
128
,
edge_tag_dims
=
128
,
n_classes
=
10
,
steps
=
5
,
filters
=
[
16
,
32
,
64
,
128
,
256
],
kernel_size
=
(
3
,
3
),
final_pool_size
=
(
2
,
2
),
glimpse_type
=
'gaussian'
,
glimpse_size
=
(
15
,
15
),
):
'''
Basic idea:
* We detect objects through an undirected graphical model.
* The graphical model consists of a balanced tree of latent variables h
* Each h is then connected to a bbox variable b and a class variable y
* b of the root is fixed to cover the entire canvas
* All other h, b and y are updated through message passing
* The loss function should be either (not completed yet)
* multiset loss, or
* maximum bipartite matching (like Order Matters paper)
'''
NN
.
Module
.
__init__
(
self
)
self
.
n_children
=
n_children
self
.
n_depth
=
n_depth
self
.
h_dims
=
h_dims
self
.
node_tag_dims
=
node_tag_dims
self
.
edge_tag_dims
=
edge_tag_dims
self
.
h_dims
=
h_dims
self
.
n_classes
=
n_classes
self
.
glimpse
=
create_glimpse
(
glimpse_type
,
glimpse_size
)
self
.
steps
=
steps
self
.
cnn
=
build_cnn
(
filters
=
filters
,
kernel_size
=
kernel_size
,
final_pool_size
=
final_pool_size
,
)
# Create graph of latent variables
G
=
nx
.
balanced_tree
(
self
.
n_children
,
self
.
n_depth
)
nx
.
relabel_nodes
(
G
,
{
i
:
'h%d'
%
i
for
i
in
range
(
len
(
G
.
nodes
()))},
False
)
self
.
h_nodes_list
=
h_nodes_list
=
list
(
G
.
nodes
)
for
h
in
h_nodes_list
:
G
.
node
[
h
][
'type'
]
=
'h'
b_nodes_list
=
[
'b%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
y_nodes_list
=
[
'y%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
self
.
b_nodes_list
=
b_nodes_list
self
.
y_nodes_list
=
y_nodes_list
hy_edge_list
=
[(
h
,
y
)
for
h
,
y
in
zip
(
h_nodes_list
,
y_nodes_list
)]
hb_edge_list
=
[(
h
,
b
)
for
h
,
b
in
zip
(
h_nodes_list
,
b_nodes_list
)]
yh_edge_list
=
[(
y
,
h
)
for
y
,
h
in
zip
(
y_nodes_list
,
h_nodes_list
)]
bh_edge_list
=
[(
b
,
h
)
for
b
,
h
in
zip
(
b_nodes_list
,
h_nodes_list
)]
G
.
add_nodes_from
(
b_nodes_list
,
type
=
'b'
)
G
.
add_nodes_from
(
y_nodes_list
,
type
=
'y'
)
G
.
add_edges_from
(
hy_edge_list
)
G
.
add_edges_from
(
hb_edge_list
)
self
.
G
=
DiGraph
(
nx
.
DiGraph
(
G
))
hh_edge_list
=
[(
u
,
v
)
for
u
,
v
in
self
.
G
.
edges
()
if
self
.
G
.
node
[
u
][
'type'
]
==
self
.
G
.
node
[
v
][
'type'
]
==
'h'
]
self
.
G
.
init_node_tag_with
(
node_tag_dims
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
))
self
.
G
.
init_edge_tag_with
(
edge_tag_dims
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
),
edges
=
hy_edge_list
+
hb_edge_list
+
bh_edge_list
)
self
.
G
.
init_edge_tag_with
(
h_dims
*
n_classes
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
),
edges
=
yh_edge_list
)
# y -> h. An attention over embeddings dynamically generated through edge tags
self
.
G
.
register_message_func
(
self
.
_y_to_h
,
edges
=
yh_edge_list
,
batched
=
True
)
# b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
self
.
bh_1
=
NN
.
Linear
(
self
.
glimpse
.
att_params
,
h_dims
)
self
.
bh_2
=
NN
.
Linear
(
edge_tag_dims
,
h_dims
)
self
.
bh_all
=
NN
.
Linear
(
2
*
h_dims
+
filters
[
-
1
]
*
NP
.
prod
(
final_pool_size
),
h_dims
)
self
.
G
.
register_message_func
(
self
.
_b_to_h
,
edges
=
bh_edge_list
,
batched
=
True
)
# h -> h. Just passes h itself
self
.
G
.
register_message_func
(
self
.
_h_to_h
,
edges
=
hh_edge_list
,
batched
=
True
)
# h -> b. Concatenates h with edge tag and go through MLP.
# Produces Δb
self
.
hb
=
NN
.
Linear
(
h_dims
+
edge_tag_dims
,
self
.
glimpse
.
att_params
)
self
.
G
.
register_message_func
(
self
.
_h_to_b
,
edges
=
hb_edge_list
,
batched
=
True
)
# h -> y. Concatenates h with edge tag and go through MLP.
# Produces Δy
self
.
hy
=
NN
.
Linear
(
h_dims
+
edge_tag_dims
,
self
.
n_classes
)
self
.
G
.
register_message_func
(
self
.
_h_to_y
,
edges
=
hy_edge_list
,
batched
=
True
)
# b update: just adds the original b by Δb
self
.
G
.
register_update_func
(
self
.
_update_b
,
nodes
=
b_nodes_list
,
batched
=
False
)
# y update: also adds y by Δy
self
.
G
.
register_update_func
(
self
.
_update_y
,
nodes
=
y_nodes_list
,
batched
=
False
)
# h update: simply adds h by the average messages and then passes it through ReLU
self
.
G
.
register_update_func
(
self
.
_update_h
,
nodes
=
h_nodes_list
,
batched
=
False
)
def
_y_to_h
(
self
,
source
,
edge_tag
):
'''
source: (n_yh_edges, batch_size, 10) logits
edge_tag: (n_yh_edges, edge_tag_dims)
'''
n_yh_edges
,
batch_size
,
_
=
source
.
shape
w
=
edge_tag
.
reshape
(
n_yh_edges
,
1
,
self
.
n_classes
,
self
.
h_dims
)
w
=
w
.
expand
(
n_yh_edges
,
batch_size
,
self
.
n_classes
,
self
.
h_dims
)
source
=
source
[:,
:,
None
,
:]
return
(
F
.
softmax
(
source
)
@
w
).
reshape
(
n_yh_edges
,
batch_size
,
self
.
h_dims
)
def
_b_to_h
(
self
,
source
,
edge_tag
):
'''
source: (n_bh_edges, batch_size, 6) bboxes
edge_tag: (n_bh_edges, edge_tag_dims)
'''
n_bh_edges
,
batch_size
,
_
=
source
.
shape
# FIXME: really using self.x is a bad design here
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
source
,
_
=
self
.
glimpse
.
rescale
(
source
,
False
)
_source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
m_b
=
T
.
relu
(
self
.
bh_1
(
_source
))
m_t
=
T
.
relu
(
self
.
bh_2
(
edge_tag
))
m_t
=
m_t
[:,
None
,
:].
expand
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
m_t
=
m_t
.
reshape
(
-
1
,
self
.
h_dims
)
# glimpse takes batch dimension first, glimpse dimension second.
# here, the dimension of @source is n_bh_edges (# of glimpses), then
# batch size, so we transpose them
g
=
self
.
glimpse
(
self
.
x
,
source
.
transpose
(
0
,
1
)).
transpose
(
0
,
1
)
grows
,
gcols
=
g
.
size
()[
-
2
:]
g
=
g
.
reshape
(
n_bh_edges
*
batch_size
,
nchan
,
grows
,
gcols
)
phi
=
self
.
cnn
(
g
).
reshape
(
n_bh_edges
*
batch_size
,
-
1
)
# TODO: add an attribute (g) to h
m
=
self
.
bh_all
(
T
.
cat
([
m_b
,
m_t
,
phi
],
1
))
m
=
m
.
reshape
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
return
m
def
_h_to_h
(
self
,
source
,
edge_tag
):
return
source
def
_h_to_b
(
self
,
source
,
edge_tag
):
n_hb_edges
,
batch_size
,
_
=
source
.
shape
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
.
expand
(
n_hb_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_hb_edges
*
batch_size
,
-
1
)
db
=
self
.
hb
(
I
)
return
db
.
reshape
(
n_hb_edges
,
batch_size
,
-
1
)
def
_h_to_y
(
self
,
source
,
edge_tag
):
n_hy_edges
,
batch_size
,
_
=
source
.
shape
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
.
expand
(
n_hy_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_hy_edges
*
batch_size
,
-
1
)
dy
=
self
.
hy
(
I
)
return
dy
.
reshape
(
n_hy_edges
,
batch_size
,
-
1
)
def
_update_b
(
self
,
b
,
b_n
):
return
b
[
'state'
]
+
b_n
[
0
][
2
][
'state'
]
def
_update_y
(
self
,
y
,
y_n
):
return
y
[
'state'
]
+
y_n
[
0
][
2
][
'state'
]
def
_update_h
(
self
,
h
,
h_n
):
m
=
T
.
stack
([
e
[
2
][
'state'
]
for
e
in
h_n
]).
mean
(
0
)
return
T
.
relu
(
h
[
'state'
]
+
m
)
def
forward
(
self
,
x
,
y
=
None
):
self
.
x
=
x
batch_size
=
x
.
shape
[
0
]
self
.
G
.
zero_node_state
((
self
.
h_dims
,),
batch_size
,
nodes
=
self
.
h_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
n_classes
,),
batch_size
,
nodes
=
self
.
y_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
glimpse
.
att_params
,),
batch_size
,
nodes
=
self
.
b_nodes_list
)
for
t
in
range
(
self
.
steps
):
self
.
G
.
step
()
# We don't change b of the root
self
.
G
.
node
[
'b0'
][
'state'
].
zero_
()
self
.
y_pre
=
T
.
stack
(
[
self
.
G
.
node
[
'y%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
-
1
,
self
.
n_nodes
-
self
.
n_leaves
-
1
,
-
1
)],
1
)
self
.
v_B
=
T
.
stack
(
[
self
.
glimpse
.
rescale
(
self
.
G
.
node
[
'b%d'
%
i
][
'state'
],
False
)[
0
]
for
i
in
range
(
self
.
n_nodes
)],
1
,
)
self
.
y_logprob
=
F
.
log_softmax
(
self
.
y_pre
)
return
self
.
G
.
node
[
'h0'
][
'state'
]
@
property
def
n_nodes
(
self
):
return
(
self
.
n_children
**
self
.
n_depth
-
1
)
//
(
self
.
n_children
-
1
)
@
property
def
n_leaves
(
self
):
return
self
.
n_children
**
(
self
.
n_depth
-
1
)
mx.py
View file @
75a43e4a
import
networkx
as
nx
from
networkx.classes.graph
import
Graph
#from networkx.classes.graph import Graph
from
networkx.classes.digraph
import
DiGraph
import
torch
as
th
#import torch.nn.functional as F
import
torch.nn
as
nn
from
torch.autograd
import
Variable
as
Var
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
class
mx_Graph
(
Graph
):
class
mx_Graph
(
Di
Graph
):
def
__init__
(
self
,
*
args
,
**
kargs
):
super
(
mx_Graph
,
self
).
__init__
(
*
args
,
**
kargs
)
self
.
set_msg_func
()
self
.
set_gather_func
()
self
.
set_reduction_func
()
self
.
set_update_func
()
self
.
set_readout_func
()
self
.
init_reprs
()
...
...
@@ -26,47 +32,58 @@ class mx_Graph(Graph):
assert
u
in
self
.
nodes
return
self
.
nodes
[
u
][
name
]
def
set_reduction_func
(
self
):
def
_default_reduction_func
(
x_s
):
out
=
th
.
stack
(
x_s
)
out
=
th
.
sum
(
out
,
dim
=
0
)
return
out
self
.
_reduction_func
=
_default_reduction_func
def
set_gather_func
(
self
,
u
=
None
):
pass
def
set_msg_func
(
self
,
func
=
None
,
u
=
None
):
"""Function that gathers messages from neighbors"""
def
_default_msg_func
(
u
):
assert
u
in
self
.
nodes
msg_gathered
=
0
for
v
in
self
.
adj
[
u
]:
msg_gathered
=
[]
for
v
in
self
.
pred
[
u
]:
x
=
self
.
get_repr
(
v
)
if
x
is
not
None
:
msg_gathered
+=
x
return
msg_gathered
msg_gathered
.
append
(
x
)
return
self
.
_reduction_func
(
msg_gathered
)
# TODO: per node message function
# TODO: 'sum' should be a separate function
if
func
==
None
:
self
.
msg_func
=
_default_msg_func
self
.
_
msg_func
=
_default_msg_func
else
:
self
.
msg_func
=
func
self
.
_
msg_func
=
func
def
set_update_func
(
self
,
func
=
None
,
u
=
None
):
"""
Update function upon receiving an aggregate
message from a node's neighbor
"""
def
_default_update_func
(
u
,
m
):
h_new
=
self
.
nodes
[
u
][
'h'
]
+
m
self
.
set_repr
(
u
,
h_new
)
def
_default_update_func
(
x
,
m
):
return
x
+
m
# TODO: per node update function
if
func
==
None
:
self
.
update_func
=
_default_update_func
self
.
_
update_func
=
_default_update_func
else
:
self
.
update_func
=
func
self
.
_
update_func
=
func
def
set_readout_func
(
self
,
func
=
None
):
"""Readout function of the whole graph"""
def
_default_readout_func
():
readout
=
0
for
n
in
self
.
nodes
:
readout
+=
self
.
nodes
[
n
][
'h'
]
return
readout
valid_hs
=
[]
for
x
in
self
.
nodes
:
h
=
self
.
get_repr
(
x
)
if
h
is
not
None
:
valid_hs
.
append
(
h
)
return
self
.
_reduction_func
(
valid_hs
)
#
if
func
==
None
:
self
.
readout_func
=
_default_readout_func
else
:
...
...
@@ -78,15 +95,21 @@ class mx_Graph(Graph):
def
update_to
(
self
,
u
):
"""Pull messages from 1-step away neighbors of u"""
assert
u
in
self
.
nodes
m
=
self
.
msg_func
(
u
=
u
)
self
.
update_func
(
u
,
m
)
m
=
self
.
_msg_func
(
u
=
u
)
x
=
self
.
get_repr
(
u
)
# TODO: ugly hack
if
x
is
None
:
y
=
self
.
_update_func
(
m
)
else
:
y
=
self
.
_update_func
(
x
,
m
)
self
.
set_repr
(
u
,
y
)
def
update_from
(
self
,
u
):
"""Update u's 1-step away neighbors"""
assert
u
in
self
.
nodes
# TODO: this asks v to pull from nodes other than
# TODO: u, is this a good thing?
for
v
in
self
.
adj
[
u
]:
for
v
in
self
.
succ
[
u
]:
self
.
update_to
(
v
)
def
print_all
(
self
):
...
...
@@ -95,25 +118,39 @@ class mx_Graph(Graph):
print
()
if
__name__
==
'__main__'
:
th
.
random
.
manual_seed
(
0
)
''': this makes a digraph with double edges
tg = nx.path_graph(10)
g = mx_Graph(tg)
g.print_all()
tr
=
nx
.
balanced_tree
(
2
,
3
)
# this makes a uni-edge tree
tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
m_tr = mx_Graph(tr)
m_tr.print_all()
'''
print
(
"testing GRU update"
)
g
=
mx_Graph
(
nx
.
path_graph
(
3
))
g
.
set_update_func
(
nn
.
GRUCell
(
4
,
4
))
for
n
in
g
:
g
.
set_repr
(
n
,
int
(
n
)
+
10
)
g
.
print_all
()
print
(
g
.
readout
())
print
(
"before update:
\t
"
,
g
.
nodes
[
0
])
g
.
update_to
(
0
)
print
(
'after update:
\t
'
,
g
.
nodes
[
0
])
g
.
print_all
()
print
(
g
.
readout
())
g
.
set_repr
(
n
,
Var
(
th
.
rand
(
2
,
4
)))
print
(
"
\t
**before:"
);
g
.
print_all
()
g
.
update_from
(
0
)
g
.
update_from
(
1
)
print
(
"
\t
**after:"
);
g
.
print_all
()
print
(
"
\n
testing fwd update"
)
g
.
clear
()
g
.
add_path
([
0
,
1
,
2
])
g
.
init_reprs
()
fwd_net
=
nn
.
Sequential
(
nn
.
Linear
(
4
,
4
),
nn
.
ReLU
())
g
.
set_update_func
(
fwd_net
)
g
.
set_repr
(
0
,
Var
(
th
.
rand
(
2
,
4
)))
print
(
"
\t
**before:"
);
g
.
print_all
()
g
.
update_from
(
0
)
g
.
update_from
(
1
)
print
(
"
\t
**after:"
);
g
.
print_all
()
mx_scalar.py
0 → 100644
View file @
75a43e4a
import
networkx
as
nx
from
networkx.classes.graph
import
Graph
from
networkx.classes.digraph
import
DiGraph
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
class
mx_Graph
(
DiGraph
):
def
__init__
(
self
,
*
args
,
**
kargs
):
super
(
mx_Graph
,
self
).
__init__
(
*
args
,
**
kargs
)
self
.
set_msg_func
()
self
.
set_update_func
()
self
.
set_readout_func
()
self
.
init_reprs
()
def
init_reprs
(
self
,
h_init
=
None
):
for
n
in
self
.
nodes
:
self
.
set_repr
(
n
,
h_init
)
def
set_repr
(
self
,
u
,
h_u
,
name
=
'h'
):
assert
u
in
self
.
nodes
kwarg
=
{
name
:
h_u
}
self
.
add_node
(
u
,
**
kwarg
)
def
get_repr
(
self
,
u
,
name
=
'h'
):
assert
u
in
self
.
nodes
return
self
.
nodes
[
u
][
name
]
def
set_msg_func
(
self
,
func
=
None
,
u
=
None
):
"""Function that gathers messages from neighbors"""
def
_default_msg_func
(
u
):
assert
u
in
self
.
nodes
msg_gathered
=
0
for
v
in
self
.
adj
[
u
]:
x
=
self
.
get_repr
(
v
)
if
x
is
not
None
:
msg_gathered
+=
x
return
msg_gathered
# TODO: per node message function
# TODO: 'sum' should be a separate function
if
func
==
None
:
self
.
msg_func
=
_default_msg_func
else
:
self
.
msg_func
=
func
def
set_update_func
(
self
,
func
=
None
,
u
=
None
):
"""
Update function upon receiving an aggregate
message from a node's neighbor
"""
def
_default_update_func
(
u
,
m
):
if
self
.
nodes
[
u
][
'h'
]
is
None
:
h_new
=
m
else
:
h_new
=
self
.
nodes
[
u
][
'h'
]
+
m
self
.
set_repr
(
u
,
h_new
)
# TODO: per node update function
if
func
==
None
:
self
.
update_func
=
_default_update_func
else
:
self
.
update_func
=
func
def
set_readout_func
(
self
,
func
=
None
):
"""Readout function of the whole graph"""
def
_default_readout_func
():
readout
=
0
for
n
in
self
.
nodes
:
readout
+=
self
.
nodes
[
n
][
'h'
]
return
readout
if
func
==
None
:
self
.
readout_func
=
_default_readout_func
else
:
self
.
readout_func
=
func
def
readout
(
self
):
return
self
.
readout_func
()
def
update_to
(
self
,
u
):
"""Pull messages from 1-step away neighbors of u"""
assert
u
in
self
.
nodes
m
=
self
.
msg_func
(
u
=
u
)
self
.
update_func
(
u
,
m
)
def
update_from
(
self
,
u
):
"""Update u's 1-step away neighbors"""
assert
u
in
self
.
nodes
# TODO: this asks v to pull from nodes other than
# TODO: u, is this a good thing?
for
v
in
self
.
adj
[
u
]:
self
.
update_to
(
v
)
def
print_all
(
self
):
for
n
in
self
.
nodes
:
print
(
n
,
self
.
nodes
[
n
])
print
()
if
__name__
==
'__main__'
:
tg
=
nx
.
path_graph
(
10
)
g
=
mx_Graph
(
tg
)
g
.
print_all
()
tr
=
nx
.
balanced_tree
(
2
,
3
)
m_tr
=
mx_Graph
(
tr
)
m_tr
.
print_all
()
g
=
mx_Graph
(
nx
.
path_graph
(
3
))
for
n
in
g
:
g
.
set_repr
(
n
,
int
(
n
)
+
10
)
g
.
print_all
()
print
(
g
.
readout
())
print
(
"before update:
\t
"
,
g
.
nodes
[
0
])
g
.
update_to
(
0
)
print
(
'after update:
\t
'
,
g
.
nodes
[
0
])
g
.
print_all
()
print
(
g
.
readout
())
g
=
mx_Graph
(
nx
.
bfs_tree
(
nx
.
path_graph
(
3
),
0
))
g
.
set_repr
(
0
,
10
)
g
.
print_all
()
g
.
update_from
(
0
)
g
.
print_all
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment