Project: OpenDAS / dgl

Commit 51391012, authored May 05, 2018 by Gan Quan
Commit message: graph and model change
Parent: 196e6a92

Showing 2 changed files with 121 additions and 47 deletions (graph.py: +14, -8; model.py: +107, -39).
graph.py (+14, -8)

```diff
 import networkx as nx
 import torch as T
 import torch.nn as NN
 from util import *


 class DiGraph(nx.DiGraph, NN.Module):
     '''
@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module):
     def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
         nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)

-    def add_edges_from(self, ebunch, tag=tag, attr_dict=None, **attr):
+    def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
         nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)

     def _nodes_or_all(self, nodes='all'):
@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module):
         nodes = self._nodes_or_all(nodes)
         for v in nodes:
-            self.node[v]['state'] = T.zeros(shape)
+            self.node[v]['state'] = tovar(T.zeros(shape))

     def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
         nodes = self._nodes_or_all(nodes)
         for v in nodes:
-            self.node[v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
+            self.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
             self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])

     def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
         edges = self._edges_or_all(edges)
         for u, v in edges:
-            self[u][v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
+            self[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
             self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])

     def remove_node_tag(self, nodes='all'):
@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module):
         '''
         batched: whether to do a single batched computation instead of iterating
         update function: accepts a node attribute dictionary (including state and tag),
-            and a dictionary of edge attribute dictionaries
+            and a list of tuples (source node, target node, edge attribute dictionary)
         '''
         self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))
```
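The rewritten docstring pins down the update-function contract: the callback receives the target node's attribute dictionary plus the list of (source, target, edge-attribute-dict) tuples that `in_edges(v, data=True)` produces (see the next hunk). A minimal sketch of a conforming callback; the name `mean_pool_update` is hypothetical, and `register_update_func` is an assumed name for the method whose `self.update_funcs.append(...)` we see above:

```python
import torch as T

def mean_pool_update(node, in_edges):
    # node: the target node's attribute dict, e.g. {'state': ..., 'tag': ...}
    # in_edges: list of (source, target, edge_attr_dict) tuples
    incoming = T.stack([e[2]['state'] for e in in_edges])
    return T.relu(node['state'] + incoming.mean(0))
```

model.py's `_update_h` below follows exactly this shape.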
```diff
@@ -126,7 +127,11 @@ class DiGraph(nx.DiGraph, NN.Module):
             # FIXME: need to optimize since we are repeatedly stacking and
             # unpacking
             source = T.stack([self.node[u]['state'] for u, _ in ebunch])
-            edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
+            edge_tags = [self[u][v]['tag'] for u, v in ebunch]
+            if all(t is None for t in edge_tags):
+                edge_tag = None
+            else:
+                edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
             message = f(source, edge_tag)
             for i, (u, v) in enumerate(ebunch):
                 self[u][v]['state'] = message[i]
@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module):
         # update state
         # TODO: does it make sense to batch update the nodes?
-        for v, f in self.update_funcs:
-            self.node[v]['state'] = f(self.node[v], self[v])
+        for vbunch, f, batched in self.update_funcs:
+            for v in vbunch:
+                self.node[v]['state'] = f(self.node[v], self.in_edges(v, data=True))
```
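Pieced together from the calls model.py makes on `DiGraph` (`zero_node_state`, `register_message_func`, `step`), a rough end-to-end sketch of the message-passing loop this file implements; `register_update_func` is an assumed name, and the `zero_node_state` call mirrors model.py's usage rather than a documented signature:

```python
import torch as T
from graph import DiGraph

g = DiGraph()
g.add_edge('u', 'v')
g.zero_node_state((4,), 2, nodes='all')       # (state shape, batch size), per model.py

def message(source, edge_tag):
    # source: stacked source-node states; edge_tag is None when no edge on
    # this message func carries a tag, which the new edge_tags check permits
    return source

def update(node, in_edges):
    # in_edges: (source, target, edge_attr_dict) tuples from in_edges(data=True)
    return node['state'] + sum(e[2]['state'] for e in in_edges)

g.register_message_func(message, edges=[('u', 'v')], batched=True)
g.register_update_func(update, nodes=['v'])   # hypothetical registration method
g.step()                                      # assumed: messages flow, then states update
```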
model.py (+107, -39)

```diff
 import torch as T
 import torch.nn as NN
 import networkx as nx
 import torch.nn.init as INIT
 import torch.nn.functional as F
 import numpy as NP
 import numpy.random as RNG
 from util import *
 from glimpse import create_glimpse
 from zoneout import ZoneoutLSTMCell
 from collections import namedtuple
 import os
+from graph import DiGraph
+import networkx as nx

 no_msg = os.getenv('NOMSG', False)


+def build_cnn(**config):
+    cnn_list = []
+    filters = config['filters']
+    kernel_size = config['kernel_size']
+    in_channels = config.get('in_channels', 3)
+    final_pool_size = config['final_pool_size']
+
+    for i in range(len(filters)):
+        module = NN.Conv2d(
+            in_channels if i == 0 else filters[i - 1],
+            filters[i],
+            kernel_size,
+            padding=tuple((_ - 1) // 2 for _ in kernel_size),
+        )
+        INIT.xavier_uniform(module.weight)
+        INIT.constant(module.bias, 0)
+        cnn_list.append(module)
+        if i < len(filters) - 1:
+            cnn_list.append(NN.LeakyReLU())
+    cnn_list.append(NN.AdaptiveMaxPool2d(final_pool_size))
+
+    return NN.Sequential(*cnn_list)
```
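`build_cnn` chains stride-1 convolutions whose `(k - 1) // 2` padding preserves spatial size, so the final `AdaptiveMaxPool2d` alone fixes the output resolution. A quick shape check with the constructor defaults below; the 64x64 input is an arbitrary choice, and `xavier_uniform`/`constant` are the pre-underscore spellings from the 2018-era PyTorch this code targets:

```python
import torch as T

cnn = build_cnn(filters=[16, 32, 64, 128, 256],
                kernel_size=(3, 3),
                final_pool_size=(2, 2))
x = T.zeros(1, 3, 64, 64)       # (batch, channels, rows, cols)
phi = cnn(x)
print(phi.shape)                # (1, 256, 2, 2): flattens to 256 * 4 = 1024 features
```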
```diff
 class TreeGlimpsedClassifier(NN.Module):
     def __init__(self,
@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module):
-                 h_dims=128,
                  node_tag_dims=128,
                  edge_tag_dims=128,
+                 h_dims=128,
+                 n_classes=10,
+                 steps=5,
+                 filters=[16, 32, 64, 128, 256],
+                 kernel_size=(3, 3),
+                 final_pool_size=(2, 2),
+                 glimpse_type='gaussian',
+                 glimpse_size=(15, 15),
                  ):
         '''
         Basic idea:
@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module):
         self.edge_tag_dims = edge_tag_dims
         self.h_dims = h_dims
         self.n_classes = n_classes
+        self.glimpse = create_glimpse(glimpse_type, glimpse_size)
+        self.steps = steps
+        self.cnn = build_cnn(
+                filters=filters,
+                kernel_size=kernel_size,
+                final_pool_size=final_pool_size,
+                )

         # Create graph of latent variables
         G = nx.balanced_tree(self.n_children, self.n_depth)
         nx.relabel_nodes(
                 G,
-                {i: 'h%d' % i for i in range(self.G.nodes())},
+                {i: 'h%d' % i for i in range(len(G.nodes()))},
                 False
                 )
-        h_nodes_list = G.nodes()
+        self.h_nodes_list = h_nodes_list = G.nodes()
         for h in h_nodes_list:
             G.node[h]['type'] = 'h'
         b_nodes_list = ['b%d' % i for i in range(len(h_nodes_list))]
         y_nodes_list = ['y%d' % i for i in range(len(h_nodes_list))]
+        self.b_nodes_list = b_nodes_list
+        self.y_nodes_list = y_nodes_list
         hy_edge_list = [(h, y) for h, y in zip(h_nodes_list, y_nodes_list)]
-        hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, y_nodes_list)]
+        hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, b_nodes_list)]
         yh_edge_list = [(y, h) for y, h in zip(y_nodes_list, h_nodes_list)]
+        bh_edge_list = [(b, h) for b, h in zip(b_nodes_list, h_nodes_list)]
```
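What the relabeled latent graph looks like, for concreteness; `n_children=2, n_depth=2` are illustrative stand-ins for constructor values elided from this diff:

```python
import networkx as nx

G = nx.balanced_tree(2, 2)                      # 7 nodes labeled 0..6
mapping = {i: 'h%d' % i for i in range(len(G.nodes()))}
nx.relabel_nodes(G, mapping, False)             # copy=False: relabel in place
print(sorted(G.nodes()))                        # ['h0', 'h1', ..., 'h6']
```

Each `h` node then gets a companion `b` node (glimpse bounding-box parameters) and `y` node (class scores), wired up by the four edge lists above.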
```diff
@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module):
                 edge_tag_dims,
                 T.nn.init.uniform_,
                 args=(-.01, .01),
-                edges=hy_edge_list + hb_edge_list + yh_edge_list, bh_edge_list
+                edges=hy_edge_list + hb_edge_list + bh_edge_list
                 )
+        self.G.init_edge_tag_with(
+                h_dims * n_classes,
+                T.nn.init.uniform_,
+                args=(-.01, .01),
+                edges=yh_edge_list
+                )

         # y -> h. An attention over embeddings dynamically generated through edge tags
         self.yh_emb = NN.Sequential(
                 NN.Linear(edge_tag_dims, h_dims),
                 NN.ReLU(),
                 NN.Linear(h_dims, n_classes * h_dims),
                 )
         self.G.register_message_func(self._y_to_h, edges=yh_edge_list, batched=True)

         # b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
         self.bh_1 = NN.Linear(self.glimpse.att_params, h_dims)
         self.bh_2 = NN.Linear(edge_tag_dims, h_dims)
-        self.bh_all = NN.Linear(3 * h_dims, h_dims)
+        self.bh_all = NN.Linear(2 * h_dims + filters[-1] * NP.prod(final_pool_size), h_dims)
         self.G.register_message_func(self._b_to_h, edges=bh_edge_list, batched=True)

         # h -> h. Just passes h itself
@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module):
         # h -> b. Concatenates h with edge tag and go through MLP.
         # Produces Δb
-        self.hb = NN.Linear(hidden_layers + edge_tag_dims, self.glimpse.att_params)
+        self.hb = NN.Linear(h_dims + edge_tag_dims, self.glimpse.att_params)
         self.G.register_message_func(self._h_to_b, edges=hb_edge_list, batched=True)

         # h -> y. Concatenates h with edge tag and go through MLP.
         # Produces Δy
-        self.hy = NN.Linear(hidden_layers + edge_tag_dims, self.n_classes)
+        self.hy = NN.Linear(h_dims + edge_tag_dims, self.n_classes)
         self.G.register_message_func(self._h_to_y, edges=hy_edge_list, batched=True)

         # b update: just adds the original b by Δb
```
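The resized `bh_all` now also ingests the CNN features `phi` computed in `_b_to_h` (next hunks), replacing the old fixed `3 * h_dims` width. With the constructor defaults, the input width works out as:

```python
import numpy as NP

h_dims, filters, final_pool_size = 128, [16, 32, 64, 128, 256], (2, 2)
width = 2 * h_dims + filters[-1] * NP.prod(final_pool_size)
print(width)    # 256 + 1024 = 1280 (m_b and m_t contribute h_dims each)
```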
```diff
@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module):
         '''
         n_yh_edges, batch_size, _ = source.shape
-        if not self._yh_emb_cached:
-            self._yh_emb_cached = True
-            self._yh_emb_w = self.yh_emb(edge_tag)
-            self._yh_emb_w = self._yh_emb_w.reshape(n_yh_edges, self.n_classes, self.h_dims)
-        w = self._yh_emb_w[:, None]
+        w = edge_tag.reshape(n_yh_edges, 1, self.n_classes, self.h_dims)
         w = w.expand(n_yh_edges, batch_size, self.n_classes, self.h_dims)
         source = source[:, :, None, :]
         return (F.softmax(source) @ w).reshape(n_yh_edges, batch_size, self.h_dims)
```
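`_y_to_h` no longer caches an embedding of the tag through `yh_emb`: the yh edge tag, now initialized with `h_dims * n_classes` entries, is reshaped directly into a per-edge class-embedding matrix, and the y state acts as attention weights over its rows. A shape walk-through with toy sizes (all sizes illustrative; `dim=-1` is passed explicitly where the 2018-era code relied on softmax's implicit dimension):

```python
import torch as T
import torch.nn.functional as F

n_yh_edges, batch_size, n_classes, h_dims = 3, 2, 10, 16
source = T.randn(n_yh_edges, batch_size, n_classes)       # y node states
edge_tag = T.randn(n_yh_edges, n_classes * h_dims)        # per-edge tag

w = edge_tag.reshape(n_yh_edges, 1, n_classes, h_dims)
w = w.expand(n_yh_edges, batch_size, n_classes, h_dims)
att = F.softmax(source[:, :, None, :], dim=-1)            # (n, b, 1, n_classes)
msg = (att @ w).reshape(n_yh_edges, batch_size, h_dims)   # (n, b, h_dims)
print(msg.shape)                                          # torch.Size([3, 2, 16])
```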
```diff
@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module):
         edge_tag: (n_bh_edges, edge_tag_dims)
         '''
         n_bh_edges, batch_size, _ = source.shape
+        # FIXME: really using self.x is a bad design here
         _, nchan, nrows, ncols = self.x.size()
-        source = source.reshape(-1, self.glimpse.att_params)
-        m_b = T.relu(self.bh_1(source))
+        _source = source.reshape(-1, self.glimpse.att_params)
+        m_b = T.relu(self.bh_1(_source))
         m_t = T.relu(self.bh_2(edge_tag))
         m_t = m_t[:, None, :].expand(n_bh_edges, batch_size, self.h_dims)
         m_t = m_t.reshape(-1, self.h_dims)
@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module):
         # here, the dimension of @source is n_bh_edges (# of glimpses), then
         # batch size, so we transpose them
         g = self.glimpse(self.x, source.transpose(0, 1)).transpose(0, 1)
-        g = g.reshape(n_bh_edges * batch_size, nchan, nrows, ncols)
+        grows, gcols = g.size()[-2:]
+        g = g.reshape(n_bh_edges * batch_size, nchan, grows, gcols)
+        phi = self.cnn(g).reshape(n_bh_edges * batch_size, -1)
+        # TODO: add an attribute (g) to h
         m = self.bh_all(T.cat([m_b, m_t, phi], 1))
         m = m.reshape(n_bh_edges, batch_size, self.h_dims)
```
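`_b_to_h` (and `_h_to_b`/`_h_to_y` below) keep flattening the `(n_edges, batch)` leading dimensions into one before a `Linear` layer and restoring them afterwards; the pattern in isolation, with toy sizes:

```python
import torch as T
import torch.nn as NN

n_edges, batch_size, in_dims, out_dims = 3, 2, 8, 5
layer = NN.Linear(in_dims, out_dims)

x = T.randn(n_edges, batch_size, in_dims)
y = layer(x.reshape(-1, in_dims))             # (n_edges * batch, out_dims)
y = y.reshape(n_edges, batch_size, out_dims)  # split the merged dim back out
print(y.shape)                                # torch.Size([3, 2, 5])
```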
```diff
@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module):
         edge_tag = edge_tag[:, None]
         edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
         I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
-        b = self.hb(I)
-        return db
+        db = self.hb(I)
+        return db.reshape(n_hb_edges, batch_size, -1)

     def _h_to_y(self, source, edge_tag):
         n_hy_edges, batch_size, _ = source.shape
         edge_tag = edge_tag[:, None]
-        edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
-        I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
-        y = self.hy(I)
-        return dy
+        edge_tag = edge_tag.expand(n_hy_edges, batch_size, self.edge_tag_dims)
+        I = T.cat([source, edge_tag], -1).reshape(n_hy_edges * batch_size, -1)
+        dy = self.hy(I)
+        return dy.reshape(n_hy_edges, batch_size, -1)

     def _update_b(self, b, b_n):
-        return b['state'] + list(b_n.values())[0]['state']
+        return b['state'] + b_n[0][2]['state']

     def _update_y(self, y, y_n):
-        return y['state'] + list(y_n.values())[0]['state']
+        return y['state'] + y_n[0][2]['state']

     def _update_h(self, h, h_n):
-        m = T.stack([e['state'] for e in h_n]).mean(0)
-        return T.relu(h + m)
+        m = T.stack([e[2]['state'] for e in h_n]).mean(0)
+        return T.relu(h['state'] + m)
```
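The update callbacks switch from dict-style neighbor access to tuple indexing (`e[2]['state']`) because graph.py's step loop now hands them `self.in_edges(v, data=True)`. Plain networkx shows why index 2 is the edge-attribute dict:

```python
import networkx as nx

g = nx.DiGraph()
g.add_edge('u', 'v', state=1.0)
g.add_edge('w', 'v', state=2.0)

triples = list(g.in_edges('v', data=True))
print(triples)                            # [('u', 'v', {'state': 1.0}), ('w', 'v', {'state': 2.0})]
print([e[2]['state'] for e in triples])   # [1.0, 2.0]
```

The same hunk continues below into forward.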
```diff
-    def forward(self, x):
+    def forward(self, x, y=None):
         self.x = x
         batch_size = x.shape[0]
-        self.G.zero_node_state(self.h_dims, batch_size, nodes=h_nodes_list)
+        self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
+        self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)

         full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params)
         for v in self.G.nodes():
-            if G.node[v]['type'] == 'b':
+            if self.G.node[v]['type'] == 'b':
                 # Initialize bbox variables to cover the entire canvas
                 self.G.node[v]['state'] = full

-        self._yh_emb_cached = False

         for t in range(self.steps):
             self.G.step()
             # We don't change b of the root
             self.G.node['b0']['state'] = full

+        self.y_pre = T.stack(
+                [self.G.node['y%d' % i]['state']
+                 for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)],
+                1)
+        self.v_B = T.stack(
+                [self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)],
+                1,
+                )
+        self.y_logprob = F.log_softmax(self.y_pre)

         return self.G.node['h0']['state']

+    @property
+    def n_nodes(self):
+        return (self.n_children ** self.n_depth - 1) // (self.n_children - 1)

+    @property
+    def n_leaves(self):
+        return self.n_children ** (self.n_depth - 1)
```
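The two new properties are closed forms for a perfect `n_children`-ary tree with `n_depth` levels (a geometric series for the node count). A quick check with illustrative values; `n_children` and `n_depth` themselves are set in parts of `__init__` this diff elides:

```python
n_children, n_depth = 2, 3
n_nodes = (n_children ** n_depth - 1) // (n_children - 1)   # 1 + 2 + 4 = 7
n_leaves = n_children ** (n_depth - 1)                      # the last level: 4
print(n_nodes, n_leaves)                                    # 7 4
```

These sizes are what `forward` uses to slice out the leaf y states for `self.y_pre` and all b states for `self.v_B`.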