Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
572b289e
Unverified
Commit
572b289e
authored
May 08, 2018
by
Zheng Zhang
Committed by
GitHub
May 08, 2018
Browse files
Merge pull request #4 from BarclayII/gq-pytorch
more changes
parents
83e84e67
51391012
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
47 deletions
+121
-47
graph.py
graph.py
+14
-8
model.py
model.py
+107
-39
No files found.
graph.py
View file @
572b289e
import
networkx
as
nx
import
torch
as
T
import
torch.nn
as
NN
from
util
import
*
class
DiGraph
(
nx
.
DiGraph
,
NN
.
Module
):
'''
...
...
@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module):
def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
    """Add a directed edge u -> v, storing an optional ``tag`` attribute.

    Thin wrapper over ``nx.DiGraph.add_edge`` that forwards ``tag`` so every
    edge carries a (possibly None) tag alongside any other attributes.
    """
    nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)
def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
    """Add every edge in ``ebunch``, attaching the same optional ``tag``.

    The pre-fix signature used ``tag=tag`` as the default, which is a
    NameError at class-definition time; the default must be ``None``.
    """
    nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
...
...
@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
self
.
node
[
v
][
'state'
]
=
T
.
zeros
(
shape
)
self
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
)
)
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
    """Create a trainable ``'tag'`` Parameter for each node and register it.

    Parameters
    ----------
    shape : tuple
        Shape of the tag tensor created for every node.
    init_func : callable
        In-place initializer applied to the new Parameter; called as
        ``init_func(param, *args)``.
    dtype : torch dtype, default T.float32
    nodes : iterable or 'all'
        Nodes to initialize; ``'all'`` expands via ``self._nodes_or_all``.
    args : tuple
        Extra positional arguments forwarded to ``init_func``.
    """
    nodes = self._nodes_or_all(nodes)
    for v in nodes:
        # tovar is a project helper from util — presumably wraps the tensor
        # in a Variable (pre-0.4 torch style); confirm against util.py.
        self.node[v]['tag'] = init_func(
            NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
        # Register so the tag shows up in Module.parameters() for the optimizer.
        self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
    """Create a trainable ``'tag'`` Parameter for each edge and register it.

    Mirrors ``init_node_tag_with`` but operates on (u, v) edge pairs; the
    parameter is stored in the edge attribute dict and registered under the
    name produced by ``self._edge_tag_name(u, v)``.
    """
    edges = self._edges_or_all(edges)
    for u, v in edges:
        # tovar is a project helper from util — presumably wraps the tensor
        # in a Variable (pre-0.4 torch style); confirm against util.py.
        self[u][v]['tag'] = init_func(
            NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
        self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])
def
remove_node_tag
(
self
,
nodes
=
'all'
):
...
...
@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module):
'''
batched: whether to do a single batched computation instead of iterating
update function: accepts a node attribute dictionary (including state and tag),
and a
dictionary of
edge attribute dictionar
ies
and a
list of tuples (source node, target node,
edge attribute dictionar
y)
'''
self
.
update_funcs
.
append
((
self
.
_nodes_or_all
(
nodes
),
update_func
,
batched
))
...
...
@@ -126,7 +127,11 @@ class DiGraph(nx.DiGraph, NN.Module):
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
source
=
T
.
stack
([
self
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
edge_tag
=
T
.
stack
([
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
edge_tags
=
[
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
]
if
all
(
t
is
None
for
t
in
edge_tags
):
edge_tag
=
None
else
:
edge_tag
=
T
.
stack
([
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
message
=
f
(
source
,
edge_tag
)
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
self
[
u
][
v
][
'state'
]
=
message
[
i
]
...
...
@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module):
# update state
# TODO: does it make sense to batch update the nodes?
for
v
,
f
in
self
.
update_funcs
:
self
.
node
[
v
][
'state'
]
=
f
(
self
.
node
[
v
],
self
[
v
])
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
for
v
in
vbunch
:
self
.
node
[
v
][
'state'
]
=
f
(
self
.
node
[
v
],
self
.
in_edges
(
v
,
data
=
True
))
model.py
View file @
572b289e
import
torch
as
T
import
torch.nn
as
NN
import
networkx
as
nx
import
torch.nn.init
as
INIT
import
torch.nn.functional
as
F
import
numpy
as
NP
import
numpy.random
as
RNG
from
util
import
*
from
glimpse
import
create_glimpse
from
zoneout
import
ZoneoutLSTMCell
from
collections
import
namedtuple
import
os
from
graph
import
DiGraph
import
networkx
as
nx
no_msg
=
os
.
getenv
(
'NOMSG'
,
False
)
def build_cnn(**config):
    """Build a small CNN: Conv2d(+LeakyReLU) stack ending in adaptive max-pool.

    Config keys
    -----------
    filters : list[int]
        Output channels of each conv layer.
    kernel_size : tuple[int, int]
        Kernel size shared by all conv layers; "same" padding is derived
        as (k - 1) // 2 per dimension.
    in_channels : int, optional (default 3)
        Channels of the input image.
    final_pool_size : tuple[int, int]
        Output spatial size of the trailing AdaptiveMaxPool2d.

    Returns
    -------
    NN.Sequential
    """
    cnn_list = []
    filters = config['filters']
    kernel_size = config['kernel_size']
    in_channels = config.get('in_channels', 3)
    final_pool_size = config['final_pool_size']

    for i in range(len(filters)):
        module = NN.Conv2d(
            in_channels if i == 0 else filters[i - 1],
            filters[i],
            kernel_size,
            padding=tuple((k - 1) // 2 for k in kernel_size),
        )
        # NOTE: the original used the deprecated INIT.xavier_uniform /
        # INIT.constant, which were renamed with a trailing underscore and
        # later removed; the in-place variants are the supported spelling.
        INIT.xavier_uniform_(module.weight)
        INIT.constant_(module.bias, 0)
        cnn_list.append(module)
        # No activation after the last conv; the pool follows directly.
        if i < len(filters) - 1:
            cnn_list.append(NN.LeakyReLU())
    cnn_list.append(NN.AdaptiveMaxPool2d(final_pool_size))

    return NN.Sequential(*cnn_list)
class
TreeGlimpsedClassifier
(
NN
.
Module
):
def
__init__
(
self
,
...
...
@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module):
h_dims
=
128
,
node_tag_dims
=
128
,
edge_tag_dims
=
128
,
h_dims
=
128
,
n_classes
=
10
,
steps
=
5
,
filters
=
[
16
,
32
,
64
,
128
,
256
],
kernel_size
=
(
3
,
3
),
final_pool_size
=
(
2
,
2
),
glimpse_type
=
'gaussian'
,
glimpse_size
=
(
15
,
15
),
):
'''
Basic idea:
...
...
@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module):
self
.
edge_tag_dims
=
edge_tag_dims
self
.
h_dims
=
h_dims
self
.
n_classes
=
n_classes
self
.
glimpse
=
create_glimpse
(
glimpse_type
,
glimpse_size
)
self
.
steps
=
steps
self
.
cnn
=
build_cnn
(
filters
=
filters
,
kernel_size
=
kernel_size
,
final_pool_size
=
final_pool_size
,
)
# Create graph of latent variables
G
=
nx
.
balanced_tree
(
self
.
n_children
,
self
.
n_depth
)
nx
.
relabel_nodes
(
G
,
{
i
:
'h%d'
%
i
for
i
in
range
(
self
.
G
.
nodes
())},
{
i
:
'h%d'
%
i
for
i
in
range
(
len
(
G
.
nodes
())
)
},
False
)
h_nodes_list
=
G
.
nodes
()
self
.
h_nodes_list
=
h_nodes_list
=
G
.
nodes
()
for
h
in
h_nodes_list
:
G
.
node
[
h
][
'type'
]
=
'h'
b_nodes_list
=
[
'b%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
y_nodes_list
=
[
'y%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
self
.
b_nodes_list
=
b_nodes_list
self
.
y_nodes_list
=
y_nodes_list
hy_edge_list
=
[(
h
,
y
)
for
h
,
y
in
zip
(
h_nodes_list
,
y_nodes_list
)]
hb_edge_list
=
[(
h
,
b
)
for
h
,
b
in
zip
(
h_nodes_list
,
y
_nodes_list
)]
hb_edge_list
=
[(
h
,
b
)
for
h
,
b
in
zip
(
h_nodes_list
,
b
_nodes_list
)]
yh_edge_list
=
[(
y
,
h
)
for
y
,
h
in
zip
(
y_nodes_list
,
h_nodes_list
)]
bh_edge_list
=
[(
b
,
h
)
for
b
,
h
in
zip
(
b_nodes_list
,
h_nodes_list
)]
...
...
@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag_dims
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
),
edges
=
hy_edge_list
+
hb_edge_list
+
yh_edge_list
,
bh_edge_list
edges
=
hy_edge_list
+
hb_edge_list
+
bh_edge_list
)
self
.
G
.
init_edge_tag_with
(
h_dims
*
n_classes
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
),
edges
=
yh_edge_list
)
# y -> h. An attention over embeddings dynamically generated through edge tags
self
.
yh_emb
=
NN
.
Sequential
(
NN
.
Linear
(
edge_tag_dims
,
h_dims
),
NN
.
ReLU
(),
NN
.
Linear
(
h_dims
,
n_classes
*
h_dims
),
)
self
.
G
.
register_message_func
(
self
.
_y_to_h
,
edges
=
yh_edge_list
,
batched
=
True
)
# b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
self
.
bh_1
=
NN
.
Linear
(
self
.
glimpse
.
att_params
,
h_dims
)
self
.
bh_2
=
NN
.
Linear
(
edge_tag_dims
,
h_dims
)
self
.
bh_all
=
NN
.
Linear
(
3
*
h_dims
,
h_dims
)
self
.
bh_all
=
NN
.
Linear
(
2
*
h_dims
+
filters
[
-
1
]
*
NP
.
prod
(
final_pool_size
)
,
h_dims
)
self
.
G
.
register_message_func
(
self
.
_b_to_h
,
edges
=
bh_edge_list
,
batched
=
True
)
# h -> h. Just passes h itself
...
...
@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module):
# h -> b. Concatenates h with edge tag and go through MLP.
# Produces Δb
self
.
hb
=
NN
.
Linear
(
h
idden_layer
s
+
edge_tag_dims
,
self
.
glimpse
.
att_params
)
self
.
hb
=
NN
.
Linear
(
h
_dim
s
+
edge_tag_dims
,
self
.
glimpse
.
att_params
)
self
.
G
.
register_message_func
(
self
.
_h_to_b
,
edges
=
hb_edge_list
,
batched
=
True
)
# h -> y. Concatenates h with edge tag and go through MLP.
# Produces Δy
self
.
hy
=
NN
.
Linear
(
h
idden_layer
s
+
edge_tag_dims
,
self
.
n_classes
)
self
.
hy
=
NN
.
Linear
(
h
_dim
s
+
edge_tag_dims
,
self
.
n_classes
)
self
.
G
.
register_message_func
(
self
.
_h_to_y
,
edges
=
hy_edge_list
,
batched
=
True
)
# b update: just adds the original b by Δb
...
...
@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module):
'''
n_yh_edges
,
batch_size
,
_
=
source
.
shape
if
not
self
.
_yh_emb_cached
:
self
.
_yh_emb_cached
=
True
self
.
_yh_emb_w
=
self
.
yh_emb
(
edge_tag
)
self
.
_yh_emb_w
=
self
.
_yh_emb_w
.
reshape
(
n_yh_edges
,
self
.
n_classes
,
self
.
h_dims
)
w
=
self
.
_yh_emb_w
[:,
None
]
w
=
edge_tag
.
reshape
(
n_yh_edges
,
1
,
self
.
n_classes
,
self
.
h_dims
)
w
=
w
.
expand
(
n_yh_edges
,
batch_size
,
self
.
n_classes
,
self
.
h_dims
)
source
=
source
[:,
:,
None
,
:]
return
(
F
.
softmax
(
source
)
@
w
).
reshape
(
n_yh_edges
,
batch_size
,
self
.
h_dims
)
...
...
@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag: (n_bh_edges, edge_tag_dims)
'''
n_bh_edges
,
batch_size
,
_
=
source
.
shape
# FIXME: really using self.x is a bad design here
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
_
source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
m_b
=
T
.
relu
(
self
.
bh_1
(
source
))
m_b
=
T
.
relu
(
self
.
bh_1
(
_
source
))
m_t
=
T
.
relu
(
self
.
bh_2
(
edge_tag
))
m_t
=
m_t
[:,
None
,
:].
expand
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
m_t
=
m_t
.
reshape
(
-
1
,
self
.
h_dims
)
...
...
@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module):
# here, the dimension of @source is n_bh_edges (# of glimpses), then
# batch size, so we transpose them
g
=
self
.
glimpse
(
self
.
x
,
source
.
transpose
(
0
,
1
)).
transpose
(
0
,
1
)
g
=
g
.
reshape
(
n_bh_edges
*
batch_size
,
nchan
,
nrows
,
ncols
)
grows
,
gcols
=
g
.
size
()[
-
2
:]
g
=
g
.
reshape
(
n_bh_edges
*
batch_size
,
nchan
,
grows
,
gcols
)
phi
=
self
.
cnn
(
g
).
reshape
(
n_bh_edges
*
batch_size
,
-
1
)
# TODO: add an attribute (g) to h
m
=
self
.
bh_all
(
T
.
cat
([
m_b
,
m_t
,
phi
],
1
))
m
=
m
.
reshape
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
...
...
@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
.
expand
(
n_hb_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_hb_edges
*
batch_size
,
-
1
)
b
=
self
.
hb
(
I
)
return
db
d
b
=
self
.
hb
(
I
)
return
db
.
reshape
(
n_hb_edges
,
batch_size
,
-
1
)
def
_h_to_y
(
self
,
source
,
edge_tag
):
n_hy_edges
,
batch_size
,
_
=
source
.
shape
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
.
expand
(
n_h
b
_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_h
b
_edges
*
batch_size
,
-
1
)
y
=
self
.
hy
(
I
)
return
dy
edge_tag
=
edge_tag
.
expand
(
n_h
y
_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_h
y
_edges
*
batch_size
,
-
1
)
d
y
=
self
.
hy
(
I
)
return
dy
.
reshape
(
n_hy_edges
,
batch_size
,
-
1
)
def
_update_b
(
self
,
b
,
b_n
):
return
b
[
'state'
]
+
list
(
b_n
.
values
())[
0
][
'state'
]
return
b
[
'state'
]
+
b_n
[
0
][
2
][
'state'
]
def
_update_y
(
self
,
y
,
y_n
):
return
y
[
'state'
]
+
list
(
y_n
.
values
())[
0
][
'state'
]
return
y
[
'state'
]
+
y_n
[
0
][
2
][
'state'
]
def
_update_h
(
self
,
h
,
h_n
):
m
=
T
.
stack
([
e
[
'state'
]
for
e
in
h_n
]).
mean
(
0
)
return
T
.
relu
(
h
+
m
)
m
=
T
.
stack
([
e
[
2
][
'state'
]
for
e
in
h_n
]).
mean
(
0
)
return
T
.
relu
(
h
[
'state'
]
+
m
)
def forward(self, x, y=None):
    """Run ``self.steps`` rounds of message passing over the latent graph.

    Parameters
    ----------
    x : tensor of shape (batch_size, channels, rows, cols)
        Input images; stashed on ``self.x`` because the b->h message
        function reads it (flagged FIXME in the original).
    y : optional
        Unused here; accepted for interface compatibility.

    Returns
    -------
    The root h node's state, shape (batch_size, h_dims).

    Side effects: sets ``self.y_pre`` (leaf y states), ``self.v_B``
    (all b states) and ``self.y_logprob``.

    The pre-fix version referenced bare ``G`` (NameError) and passed a
    non-tuple shape to ``zero_node_state``; this is the corrected version.
    """
    self.x = x
    batch_size = x.shape[0]

    # Reset latent states for h and y nodes.
    self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
    self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)

    full = self.glimpse.full().unsqueeze(0).expand(
        batch_size, self.glimpse.att_params)
    for v in self.G.nodes():
        if self.G.node[v]['type'] == 'b':
            # Initialize bbox variables to cover the entire canvas
            self.G.node[v]['state'] = full

    self._yh_emb_cached = False
    for t in range(self.steps):
        self.G.step()
        # We don't change b of the root
        self.G.node['b0']['state'] = full

    # Leaf y states, ordered from the last node backwards over the leaves.
    self.y_pre = T.stack(
        [self.G.node['y%d' % i]['state']
         for i in range(self.n_nodes - 1,
                        self.n_nodes - self.n_leaves - 1,
                        -1)],
        1)
    self.v_B = T.stack(
        [self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)],
        1)
    # NOTE(review): log_softmax without an explicit dim is deprecated;
    # kept as in the original to preserve behavior — confirm intended axis.
    self.y_logprob = F.log_softmax(self.y_pre)

    return self.G.node['h0']['state']
@property
def n_nodes(self):
    """Total node count of a full n_children-ary tree of depth n_depth.

    Geometric series: (k**d - 1) // (k - 1) with k = n_children, d = n_depth.
    """
    return (self.n_children ** self.n_depth - 1) // (self.n_children - 1)
@property
def n_leaves(self):
    """Leaf count of the tree: n_children ** (n_depth - 1)."""
    return self.n_children ** (self.n_depth - 1)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment