Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
572b289e
Unverified
Commit
572b289e
authored
May 08, 2018
by
Zheng Zhang
Committed by
GitHub
May 08, 2018
Browse files
Merge pull request #4 from BarclayII/gq-pytorch
more changes
parents
83e84e67
51391012
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
47 deletions
+121
-47
graph.py
graph.py
+14
-8
model.py
model.py
+107
-39
No files found.
graph.py
View file @
572b289e
import
networkx
as
nx
import
networkx
as
nx
import
torch
as
T
import
torch
as
T
import
torch.nn
as
NN
import
torch.nn
as
NN
from
util
import
*
class
DiGraph
(
nx
.
DiGraph
,
NN
.
Module
):
class
DiGraph
(
nx
.
DiGraph
,
NN
.
Module
):
'''
'''
...
@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module):
def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
    """Add a directed edge (u, v), storing ``tag`` as edge data.

    Thin wrapper over nx.DiGraph.add_edge; the base class is named
    explicitly because this class also inherits from NN.Module, which
    takes no part in graph mutation.
    """
    nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)
def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
    """Add every edge in ``ebunch``, attaching the same ``tag`` to each.

    Mirrors nx.DiGraph.add_edges_from. ``tag`` defaults to None — the
    earlier revision used the self-referential default ``tag=tag``,
    which raises NameError the moment the class body is executed.
    """
    nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
...
@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
self
.
node
[
v
][
'state'
]
=
T
.
zeros
(
shape
)
self
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
)
)
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
    """Create a learnable 'tag' Parameter of ``shape`` on each selected node.

    For every node in ``nodes`` (or all nodes when 'all'), a fresh
    NN.Parameter wrapping ``tovar(T.zeros(shape, dtype=dtype))`` is passed to
    ``init_func`` (with ``*args``); the result is stored under the node's
    'tag' attribute and registered on the module so optimizers can find it.
    NOTE(review): assumes init_func returns the (in-place initialized)
    Parameter itself, since the result is handed to register_parameter.
    """
    for node_id in self._nodes_or_all(nodes):
        fresh = NN.Parameter(tovar(T.zeros(shape, dtype=dtype)))
        tag = init_func(fresh, *args)
        self.node[node_id]['tag'] = tag
        self.register_parameter(self._node_tag_name(node_id), tag)
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
    """Create a learnable 'tag' Parameter of ``shape`` on each selected edge.

    Edge analogue of init_node_tag_with: for every (u, v) in ``edges`` (or
    all edges when 'all'), builds a fresh NN.Parameter from
    ``tovar(T.zeros(shape, dtype=dtype))``, passes it to ``init_func`` with
    ``*args``, stores the result as the edge's 'tag' attribute and registers
    it under self._edge_tag_name(u, v).
    NOTE(review): assumes init_func returns the Parameter it was given.
    """
    for u, v in self._edges_or_all(edges):
        fresh = NN.Parameter(tovar(T.zeros(shape, dtype=dtype)))
        tag = init_func(fresh, *args)
        self[u][v]['tag'] = tag
        self.register_parameter(self._edge_tag_name(u, v), tag)
def
remove_node_tag
(
self
,
nodes
=
'all'
):
def
remove_node_tag
(
self
,
nodes
=
'all'
):
...
@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module):
'''
'''
batched: whether to do a single batched computation instead of iterating
batched: whether to do a single batched computation instead of iterating
update function: accepts a node attribute dictionary (including state and tag),
update function: accepts a node attribute dictionary (including state and tag),
and a
dictionary of
edge attribute dictionar
ies
and a
list of tuples (source node, target node,
edge attribute dictionar
y)
'''
'''
self
.
update_funcs
.
append
((
self
.
_nodes_or_all
(
nodes
),
update_func
,
batched
))
self
.
update_funcs
.
append
((
self
.
_nodes_or_all
(
nodes
),
update_func
,
batched
))
...
@@ -126,6 +127,10 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -126,6 +127,10 @@ class DiGraph(nx.DiGraph, NN.Module):
# FIXME: need to optimize since we are repeatedly stacking and
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
# unpacking
source
=
T
.
stack
([
self
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
source
=
T
.
stack
([
self
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
edge_tags
=
[
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
]
if
all
(
t
is
None
for
t
in
edge_tags
):
edge_tag
=
None
else
:
edge_tag
=
T
.
stack
([
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
edge_tag
=
T
.
stack
([
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
message
=
f
(
source
,
edge_tag
)
message
=
f
(
source
,
edge_tag
)
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
...
@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module):
# update state
# update state
# TODO: does it make sense to batch update the nodes?
# TODO: does it make sense to batch update the nodes?
for
v
,
f
in
self
.
update_funcs
:
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
self
.
node
[
v
][
'state'
]
=
f
(
self
.
node
[
v
],
self
[
v
])
for
v
in
vbunch
:
self
.
node
[
v
][
'state'
]
=
f
(
self
.
node
[
v
],
self
.
in_edges
(
v
,
data
=
True
))
model.py
View file @
572b289e
import
torch
as
T
import
torch
as
T
import
torch.nn
as
NN
import
torch.nn
as
NN
import
networkx
as
nx
import
torch.nn.init
as
INIT
import
torch.nn.functional
as
F
import
numpy
as
NP
import
numpy.random
as
RNG
from
util
import
*
from
glimpse
import
create_glimpse
from
zoneout
import
ZoneoutLSTMCell
from
collections
import
namedtuple
import
os
from
graph
import
DiGraph
from
graph
import
DiGraph
import
networkx
as
nx
# Feature flag read from the environment: suppress message passing when NOMSG is set.
# NOTE(review): os.getenv returns the raw *string* when the variable exists, so any
# non-empty value (even "0") is truthy; only an unset variable yields False.
no_msg = os.getenv('NOMSG', False)
def build_cnn(**config):
    """Build a small convolutional feature extractor from a config dict.

    Required keys:
        filters         -- list of output-channel counts, one per conv layer
        kernel_size     -- 2-tuple conv kernel size
        final_pool_size -- 2-tuple target size for the final AdaptiveMaxPool2d
    Optional:
        in_channels     -- channels of the input image (default 3)

    Returns an NN.Sequential of Conv2d layers with LeakyReLU between them
    (none after the last conv), ending with an adaptive max-pool so the
    spatial output size is fixed regardless of input resolution.
    """
    filters = config['filters']
    kernel_size = config['kernel_size']
    in_channels = config.get('in_channels', 3)
    final_pool_size = config['final_pool_size']
    # 'same'-style padding for odd kernel sizes; loop-invariant, so hoisted.
    padding = tuple((k - 1) // 2 for k in kernel_size)
    cnn_list = []
    for i, out_channels in enumerate(filters):
        module = NN.Conv2d(
            in_channels if i == 0 else filters[i - 1],
            out_channels,
            kernel_size,
            padding=padding,
        )
        # Use the in-place initializers (trailing underscore): the plain
        # xavier_uniform/constant are deprecated, and this file already uses
        # the underscore form elsewhere (T.nn.init.uniform_).
        INIT.xavier_uniform_(module.weight)
        INIT.constant_(module.bias, 0)
        cnn_list.append(module)
        if i < len(filters) - 1:
            cnn_list.append(NN.LeakyReLU())
    cnn_list.append(NN.AdaptiveMaxPool2d(final_pool_size))
    return NN.Sequential(*cnn_list)
class
TreeGlimpsedClassifier
(
NN
.
Module
):
class
TreeGlimpsedClassifier
(
NN
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module):
h_dims
=
128
,
h_dims
=
128
,
node_tag_dims
=
128
,
node_tag_dims
=
128
,
edge_tag_dims
=
128
,
edge_tag_dims
=
128
,
h_dims
=
128
,
n_classes
=
10
,
n_classes
=
10
,
steps
=
5
,
steps
=
5
,
filters
=
[
16
,
32
,
64
,
128
,
256
],
kernel_size
=
(
3
,
3
),
final_pool_size
=
(
2
,
2
),
glimpse_type
=
'gaussian'
,
glimpse_size
=
(
15
,
15
),
):
):
'''
'''
Basic idea:
Basic idea:
...
@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module):
self
.
edge_tag_dims
=
edge_tag_dims
self
.
edge_tag_dims
=
edge_tag_dims
self
.
h_dims
=
h_dims
self
.
h_dims
=
h_dims
self
.
n_classes
=
n_classes
self
.
n_classes
=
n_classes
self
.
glimpse
=
create_glimpse
(
glimpse_type
,
glimpse_size
)
self
.
steps
=
steps
self
.
cnn
=
build_cnn
(
filters
=
filters
,
kernel_size
=
kernel_size
,
final_pool_size
=
final_pool_size
,
)
# Create graph of latent variables
# Create graph of latent variables
G
=
nx
.
balanced_tree
(
self
.
n_children
,
self
.
n_depth
)
G
=
nx
.
balanced_tree
(
self
.
n_children
,
self
.
n_depth
)
nx
.
relabel_nodes
(
G
,
nx
.
relabel_nodes
(
G
,
{
i
:
'h%d'
%
i
for
i
in
range
(
self
.
G
.
nodes
())},
{
i
:
'h%d'
%
i
for
i
in
range
(
len
(
G
.
nodes
())
)
},
False
False
)
)
h_nodes_list
=
G
.
nodes
()
self
.
h_nodes_list
=
h_nodes_list
=
G
.
nodes
()
for
h
in
h_nodes_list
:
for
h
in
h_nodes_list
:
G
.
node
[
h
][
'type'
]
=
'h'
G
.
node
[
h
][
'type'
]
=
'h'
b_nodes_list
=
[
'b%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
b_nodes_list
=
[
'b%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
y_nodes_list
=
[
'y%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
y_nodes_list
=
[
'y%d'
%
i
for
i
in
range
(
len
(
h_nodes_list
))]
self
.
b_nodes_list
=
b_nodes_list
self
.
y_nodes_list
=
y_nodes_list
hy_edge_list
=
[(
h
,
y
)
for
h
,
y
in
zip
(
h_nodes_list
,
y_nodes_list
)]
hy_edge_list
=
[(
h
,
y
)
for
h
,
y
in
zip
(
h_nodes_list
,
y_nodes_list
)]
hb_edge_list
=
[(
h
,
b
)
for
h
,
b
in
zip
(
h_nodes_list
,
y
_nodes_list
)]
hb_edge_list
=
[(
h
,
b
)
for
h
,
b
in
zip
(
h_nodes_list
,
b
_nodes_list
)]
yh_edge_list
=
[(
y
,
h
)
for
y
,
h
in
zip
(
y_nodes_list
,
h_nodes_list
)]
yh_edge_list
=
[(
y
,
h
)
for
y
,
h
in
zip
(
y_nodes_list
,
h_nodes_list
)]
bh_edge_list
=
[(
b
,
h
)
for
b
,
h
in
zip
(
b_nodes_list
,
h_nodes_list
)]
bh_edge_list
=
[(
b
,
h
)
for
b
,
h
in
zip
(
b_nodes_list
,
h_nodes_list
)]
...
@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag_dims
,
edge_tag_dims
,
T
.
nn
.
init
.
uniform_
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
),
args
=
(
-
.
01
,
.
01
),
edges
=
hy_edge_list
+
hb_edge_list
+
yh_edge_list
,
bh_edge_list
edges
=
hy_edge_list
+
hb_edge_list
+
bh_edge_list
)
self
.
G
.
init_edge_tag_with
(
h_dims
*
n_classes
,
T
.
nn
.
init
.
uniform_
,
args
=
(
-
.
01
,
.
01
),
edges
=
yh_edge_list
)
)
# y -> h. An attention over embeddings dynamically generated through edge tags
# y -> h. An attention over embeddings dynamically generated through edge tags
self
.
yh_emb
=
NN
.
Sequential
(
NN
.
Linear
(
edge_tag_dims
,
h_dims
),
NN
.
ReLU
(),
NN
.
Linear
(
h_dims
,
n_classes
*
h_dims
),
)
self
.
G
.
register_message_func
(
self
.
_y_to_h
,
edges
=
yh_edge_list
,
batched
=
True
)
self
.
G
.
register_message_func
(
self
.
_y_to_h
,
edges
=
yh_edge_list
,
batched
=
True
)
# b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
# b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
self
.
bh_1
=
NN
.
Linear
(
self
.
glimpse
.
att_params
,
h_dims
)
self
.
bh_1
=
NN
.
Linear
(
self
.
glimpse
.
att_params
,
h_dims
)
self
.
bh_2
=
NN
.
Linear
(
edge_tag_dims
,
h_dims
)
self
.
bh_2
=
NN
.
Linear
(
edge_tag_dims
,
h_dims
)
self
.
bh_all
=
NN
.
Linear
(
3
*
h_dims
,
h_dims
)
self
.
bh_all
=
NN
.
Linear
(
2
*
h_dims
+
filters
[
-
1
]
*
NP
.
prod
(
final_pool_size
)
,
h_dims
)
self
.
G
.
register_message_func
(
self
.
_b_to_h
,
edges
=
bh_edge_list
,
batched
=
True
)
self
.
G
.
register_message_func
(
self
.
_b_to_h
,
edges
=
bh_edge_list
,
batched
=
True
)
# h -> h. Just passes h itself
# h -> h. Just passes h itself
...
@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module):
# h -> b. Concatenates h with edge tag and go through MLP.
# h -> b. Concatenates h with edge tag and go through MLP.
# Produces Δb
# Produces Δb
self
.
hb
=
NN
.
Linear
(
h
idden_layer
s
+
edge_tag_dims
,
self
.
glimpse
.
att_params
)
self
.
hb
=
NN
.
Linear
(
h
_dim
s
+
edge_tag_dims
,
self
.
glimpse
.
att_params
)
self
.
G
.
register_message_func
(
self
.
_h_to_b
,
edges
=
hb_edge_list
,
batched
=
True
)
self
.
G
.
register_message_func
(
self
.
_h_to_b
,
edges
=
hb_edge_list
,
batched
=
True
)
# h -> y. Concatenates h with edge tag and go through MLP.
# h -> y. Concatenates h with edge tag and go through MLP.
# Produces Δy
# Produces Δy
self
.
hy
=
NN
.
Linear
(
h
idden_layer
s
+
edge_tag_dims
,
self
.
n_classes
)
self
.
hy
=
NN
.
Linear
(
h
_dim
s
+
edge_tag_dims
,
self
.
n_classes
)
self
.
G
.
register_message_func
(
self
.
_h_to_y
,
edges
=
hy_edge_list
,
batched
=
True
)
self
.
G
.
register_message_func
(
self
.
_h_to_y
,
edges
=
hy_edge_list
,
batched
=
True
)
# b update: just adds the original b by Δb
# b update: just adds the original b by Δb
...
@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module):
'''
'''
n_yh_edges
,
batch_size
,
_
=
source
.
shape
n_yh_edges
,
batch_size
,
_
=
source
.
shape
if
not
self
.
_yh_emb_cached
:
w
=
edge_tag
.
reshape
(
n_yh_edges
,
1
,
self
.
n_classes
,
self
.
h_dims
)
self
.
_yh_emb_cached
=
True
self
.
_yh_emb_w
=
self
.
yh_emb
(
edge_tag
)
self
.
_yh_emb_w
=
self
.
_yh_emb_w
.
reshape
(
n_yh_edges
,
self
.
n_classes
,
self
.
h_dims
)
w
=
self
.
_yh_emb_w
[:,
None
]
w
=
w
.
expand
(
n_yh_edges
,
batch_size
,
self
.
n_classes
,
self
.
h_dims
)
w
=
w
.
expand
(
n_yh_edges
,
batch_size
,
self
.
n_classes
,
self
.
h_dims
)
source
=
source
[:,
:,
None
,
:]
source
=
source
[:,
:,
None
,
:]
return
(
F
.
softmax
(
source
)
@
w
).
reshape
(
n_yh_edges
,
batch_size
,
self
.
h_dims
)
return
(
F
.
softmax
(
source
)
@
w
).
reshape
(
n_yh_edges
,
batch_size
,
self
.
h_dims
)
...
@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag: (n_bh_edges, edge_tag_dims)
edge_tag: (n_bh_edges, edge_tag_dims)
'''
'''
n_bh_edges
,
batch_size
,
_
=
source
.
shape
n_bh_edges
,
batch_size
,
_
=
source
.
shape
# FIXME: really using self.x is a bad design here
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
_
source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
m_b
=
T
.
relu
(
self
.
bh_1
(
source
))
m_b
=
T
.
relu
(
self
.
bh_1
(
_
source
))
m_t
=
T
.
relu
(
self
.
bh_2
(
edge_tag
))
m_t
=
T
.
relu
(
self
.
bh_2
(
edge_tag
))
m_t
=
m_t
[:,
None
,
:].
expand
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
m_t
=
m_t
[:,
None
,
:].
expand
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
m_t
=
m_t
.
reshape
(
-
1
,
self
.
h_dims
)
m_t
=
m_t
.
reshape
(
-
1
,
self
.
h_dims
)
...
@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module):
# here, the dimension of @source is n_bh_edges (# of glimpses), then
# here, the dimension of @source is n_bh_edges (# of glimpses), then
# batch size, so we transpose them
# batch size, so we transpose them
g
=
self
.
glimpse
(
self
.
x
,
source
.
transpose
(
0
,
1
)).
transpose
(
0
,
1
)
g
=
self
.
glimpse
(
self
.
x
,
source
.
transpose
(
0
,
1
)).
transpose
(
0
,
1
)
g
=
g
.
reshape
(
n_bh_edges
*
batch_size
,
nchan
,
nrows
,
ncols
)
grows
,
gcols
=
g
.
size
()[
-
2
:]
g
=
g
.
reshape
(
n_bh_edges
*
batch_size
,
nchan
,
grows
,
gcols
)
phi
=
self
.
cnn
(
g
).
reshape
(
n_bh_edges
*
batch_size
,
-
1
)
phi
=
self
.
cnn
(
g
).
reshape
(
n_bh_edges
*
batch_size
,
-
1
)
# TODO: add an attribute (g) to h
m
=
self
.
bh_all
(
T
.
cat
([
m_b
,
m_t
,
phi
],
1
))
m
=
self
.
bh_all
(
T
.
cat
([
m_b
,
m_t
,
phi
],
1
))
m
=
m
.
reshape
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
m
=
m
.
reshape
(
n_bh_edges
,
batch_size
,
self
.
h_dims
)
...
@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
.
expand
(
n_hb_edges
,
batch_size
,
self
.
edge_tag_dims
)
edge_tag
=
edge_tag
.
expand
(
n_hb_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_hb_edges
*
batch_size
,
-
1
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_hb_edges
*
batch_size
,
-
1
)
b
=
self
.
hb
(
I
)
d
b
=
self
.
hb
(
I
)
return
db
return
db
.
reshape
(
n_hb_edges
,
batch_size
,
-
1
)
def
_h_to_y
(
self
,
source
,
edge_tag
):
def
_h_to_y
(
self
,
source
,
edge_tag
):
n_hy_edges
,
batch_size
,
_
=
source
.
shape
n_hy_edges
,
batch_size
,
_
=
source
.
shape
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
[:,
None
]
edge_tag
=
edge_tag
.
expand
(
n_h
b
_edges
,
batch_size
,
self
.
edge_tag_dims
)
edge_tag
=
edge_tag
.
expand
(
n_h
y
_edges
,
batch_size
,
self
.
edge_tag_dims
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_h
b
_edges
*
batch_size
,
-
1
)
I
=
T
.
cat
([
source
,
edge_tag
],
-
1
).
reshape
(
n_h
y
_edges
*
batch_size
,
-
1
)
y
=
self
.
hy
(
I
)
d
y
=
self
.
hy
(
I
)
return
dy
return
dy
.
reshape
(
n_hy_edges
,
batch_size
,
-
1
)
def
_update_b
(
self
,
b
,
b_n
):
def
_update_b
(
self
,
b
,
b_n
):
return
b
[
'state'
]
+
list
(
b_n
.
values
())[
0
][
'state'
]
return
b
[
'state'
]
+
b_n
[
0
][
2
][
'state'
]
def
_update_y
(
self
,
y
,
y_n
):
def
_update_y
(
self
,
y
,
y_n
):
return
y
[
'state'
]
+
list
(
y_n
.
values
())[
0
][
'state'
]
return
y
[
'state'
]
+
y_n
[
0
][
2
][
'state'
]
def
_update_h
(
self
,
h
,
h_n
):
def
_update_h
(
self
,
h
,
h_n
):
m
=
T
.
stack
([
e
[
'state'
]
for
e
in
h_n
]).
mean
(
0
)
m
=
T
.
stack
([
e
[
2
][
'state'
]
for
e
in
h_n
]).
mean
(
0
)
return
T
.
relu
(
h
+
m
)
return
T
.
relu
(
h
[
'state'
]
+
m
)
def forward(self, x, y=None):
    """Run ``self.steps`` rounds of message passing over the latent tree and
    return the root h node's state.

    x is stashed on ``self.x`` because message functions read it to take
    glimpses. ``y`` is accepted but unused in this body.
    Side effects: sets self.y_pre, self.v_B, self.y_logprob.
    """
    self.x = x
    batch_size = x.shape[0]
    # Reset h and y node states to zeros for the new batch.
    self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
    self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)
    # Glimpse parameters covering the whole canvas, repeated per batch element.
    full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params)
    for v in self.G.nodes():
        if self.G.node[v]['type'] == 'b':
            # Initialize bbox variables to cover the entire canvas.
            self.G.node[v]['state'] = full
    # Invalidate the cached y->h embedding weights for this forward pass.
    self._yh_emb_cached = False
    for t in range(self.steps):
        self.G.step()
        # We don't change b of the root: pin it back to the full canvas.
        self.G.node['b0']['state'] = full
    # Leaf y states, gathered from the highest node indices downward
    # (the last n_leaves indices, in reverse order).
    self.y_pre = T.stack(
        [self.G.node['y%d' % i]['state']
         for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)], 1)
    # All bbox states, stacked along dim 1.
    self.v_B = T.stack(
        [self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)], 1,
    )
    # NOTE(review): F.log_softmax is called without an explicit dim, relying
    # on the old implicit-dim behavior — confirm the intended axis.
    self.y_logprob = F.log_softmax(self.y_pre)
    return self.G.node['h0']['state']
@property
def n_nodes(self):
    """Total node count of the complete n_children-ary tree of depth n_depth.

    Geometric series: 1 + k + k^2 + ... = (k**depth - 1) // (k - 1).
    """
    k = self.n_children
    return (k ** self.n_depth - 1) // (k - 1)
@property
def n_leaves(self):
    """Number of nodes at the deepest level: n_children ** (n_depth - 1)."""
    return self.n_children ** (self.n_depth - 1)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment