Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a3ce6e2f
Unverified
Commit
a3ce6e2f
authored
May 10, 2018
by
Zheng Zhang
Committed by
GitHub
May 10, 2018
Browse files
Merge pull request #5 from BarclayII/gq-pytorch
updates to support nx 2.1
parents
572b289e
9c9ac7c9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
52 deletions
+45
-52
graph.py
graph.py
+40
-44
model.py
model.py
+5
-8
No files found.
graph.py
View file @
a3ce6e2f
...
@@ -3,37 +3,25 @@ import torch as T
...
@@ -3,37 +3,25 @@ import torch as T
import
torch.nn
as
NN
import
torch.nn
as
NN
from
util
import
*
from
util
import
*
class
DiGraph
(
nx
.
DiGraph
,
NN
.
Module
):
class
DiGraph
(
NN
.
Module
):
'''
'''
Reserved attributes:
Reserved attributes:
* state: node state vectors during message passing iterations
* state: node state vectors during message passing iterations
edges does not have "state vectors"; the "state" field is reserved for storing messages
edges does not have "state vectors"; the "state" field is reserved for storing messages
* tag: node-/edge-specific feature tensors or other data
* tag: node-/edge-specific feature tensors or other data
'''
'''
def
__init__
(
self
,
data
=
None
,
**
attr
):
def
__init__
(
self
,
graph
):
NN
.
Module
.
__init__
(
self
)
NN
.
Module
.
__init__
(
self
)
nx
.
DiGraph
.
__init__
(
self
,
data
=
data
,
**
attr
)
self
.
G
=
graph
self
.
message_funcs
=
[]
self
.
message_funcs
=
[]
self
.
update_funcs
=
[]
self
.
update_funcs
=
[]
def
add_node
(
self
,
n
,
state
=
None
,
tag
=
None
,
attr_dict
=
None
,
**
attr
):
nx
.
DiGraph
.
add_node
(
self
,
n
,
state
=
state
,
tag
=
None
,
attr_dict
=
attr_dict
,
**
attr
)
def
add_nodes_from
(
self
,
nodes
,
state
=
None
,
tag
=
None
,
**
attr
):
nx
.
DiGraph
.
add_nodes_from
(
self
,
nodes
,
state
=
state
,
tag
=
tag
,
**
attr
)
def
add_edge
(
self
,
u
,
v
,
tag
=
None
,
attr_dict
=
None
,
**
attr
):
nx
.
DiGraph
.
add_edge
(
self
,
u
,
v
,
tag
=
tag
,
attr_dict
=
attr_dict
,
**
attr
)
def
add_edges_from
(
self
,
ebunch
,
tag
=
None
,
attr_dict
=
None
,
**
attr
):
nx
.
DiGraph
.
add_edges_from
(
self
,
ebunch
,
tag
=
tag
,
attr_dict
=
attr_dict
,
**
attr
)
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
return
self
.
nodes
()
if
nodes
==
'all'
else
nodes
return
self
.
G
.
nodes
()
if
nodes
==
'all'
else
nodes
def
_edges_or_all
(
self
,
edges
=
'all'
):
def
_edges_or_all
(
self
,
edges
=
'all'
):
return
self
.
edges
()
if
edges
==
'all'
else
edges
return
self
.
G
.
edges
()
if
edges
==
'all'
else
edges
def
_node_tag_name
(
self
,
v
):
def
_node_tag_name
(
self
,
v
):
return
'(%s)'
%
v
return
'(%s)'
%
v
...
@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
self
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
))
self
.
G
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
))
def
init_node_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
nodes
=
'all'
,
args
=
()):
def
init_node_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
nodes
=
'all'
,
args
=
()):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
self
.
node
[
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
G
.
node
[
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
register_parameter
(
self
.
_node_tag_name
(
v
),
self
.
node
[
v
][
'tag'
])
self
.
register_parameter
(
self
.
_node_tag_name
(
v
),
self
.
G
.
node
[
v
][
'tag'
])
def
init_edge_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
edges
=
'all'
,
args
=
()):
def
init_edge_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
edges
=
'all'
,
args
=
()):
edges
=
self
.
_edges_or_all
(
edges
)
edges
=
self
.
_edges_or_all
(
edges
)
for
u
,
v
in
edges
:
for
u
,
v
in
edges
:
self
[
u
][
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
G
[
u
][
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
register_parameter
(
self
.
_edge_tag_name
(
u
,
v
),
self
[
u
][
v
][
'tag'
])
self
.
register_parameter
(
self
.
_edge_tag_name
(
u
,
v
),
self
.
G
[
u
][
v
][
'tag'
])
def
remove_node_tag
(
self
,
nodes
=
'all'
):
def
remove_node_tag
(
self
,
nodes
=
'all'
):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
delattr
(
self
,
self
.
_node_tag_name
(
v
))
delattr
(
self
,
self
.
_node_tag_name
(
v
))
del
self
.
node
[
v
][
'tag'
]
del
self
.
G
.
node
[
v
][
'tag'
]
def
remove_edge_tag
(
self
,
edges
=
'all'
):
def
remove_edge_tag
(
self
,
edges
=
'all'
):
edges
=
self
.
_edges_or_all
(
edges
)
edges
=
self
.
_edges_or_all
(
edges
)
for
u
,
v
in
edges
:
for
u
,
v
in
edges
:
delattr
(
self
,
self
.
_edge_tag_name
(
u
,
v
))
delattr
(
self
,
self
.
_edge_tag_name
(
u
,
v
))
del
self
[
u
][
v
][
'tag'
]
del
self
.
G
[
u
][
v
][
'tag'
]
@
property
def
node
(
self
):
return
self
.
G
.
node
@
property
def
edges
(
self
):
return
self
.
G
.
edges
def
edge_tags
(
self
):
def
edge_tags
(
self
):
for
u
,
v
in
self
.
edges
():
for
u
,
v
in
self
.
G
.
edges
():
yield
self
[
u
][
v
][
'tag'
]
yield
self
.
G
[
u
][
v
][
'tag'
]
def
node_tags
(
self
):
def
node_tags
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
self
.
node
[
v
][
'tag'
]
yield
self
.
G
.
node
[
v
][
'tag'
]
def
states
(
self
):
def
states
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
self
.
node
[
v
][
'state'
]
yield
self
.
G
.
node
[
v
][
'state'
]
def
named_edge_tags
(
self
):
def
named_edge_tags
(
self
):
for
u
,
v
in
self
.
edges
():
for
u
,
v
in
self
.
G
.
edges
():
yield
((
u
,
v
),
self
[
u
][
v
][
'tag'
])
yield
((
u
,
v
),
self
.
G
[
u
][
v
][
'tag'
])
def
named_node_tags
(
self
):
def
named_node_tags
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
(
v
,
self
.
node
[
v
][
'tag'
])
yield
(
v
,
self
.
G
.
node
[
v
][
'tag'
])
def
named_states
(
self
):
def
named_states
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
(
v
,
self
.
node
[
v
][
'state'
])
yield
(
v
,
self
.
G
.
node
[
v
][
'state'
])
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batched
=
False
):
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batched
=
False
):
'''
'''
...
@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module):
if
batched
:
if
batched
:
# FIXME: need to optimize since we are repeatedly stacking and
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
# unpacking
source
=
T
.
stack
([
self
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
source
=
T
.
stack
([
self
.
G
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
edge_tags
=
[
self
[
u
][
v
]
[
'tag'
]
for
u
,
v
in
ebunch
]
edge_tags
=
[
self
.
G
[
u
][
v
]
.
get
(
'tag'
,
None
)
for
u
,
v
in
ebunch
]
if
all
(
t
is
None
for
t
in
edge_tags
):
if
all
(
t
is
None
for
t
in
edge_tags
):
edge_tag
=
None
edge_tag
=
None
else
:
else
:
edge_tag
=
T
.
stack
([
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
edge_tag
=
T
.
stack
([
self
.
G
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
message
=
f
(
source
,
edge_tag
)
message
=
f
(
source
,
edge_tag
)
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
self
[
u
][
v
][
'state'
]
=
message
[
i
]
self
.
G
[
u
][
v
][
'state'
]
=
message
[
i
]
else
:
else
:
for
u
,
v
in
ebunch
:
for
u
,
v
in
ebunch
:
self
[
u
][
v
][
'state'
]
=
f
(
self
.
G
[
u
][
v
][
'state'
]
=
f
(
self
.
node
[
u
][
'state'
],
self
.
G
.
node
[
u
][
'state'
],
self
[
u
][
v
][
'tag'
]
self
.
G
[
u
][
v
][
'tag'
]
)
)
# update state
# update state
# TODO: does it make sense to batch update the nodes?
# TODO: does it make sense to batch update the nodes?
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
for
v
in
vbunch
:
for
v
in
vbunch
:
self
.
node
[
v
][
'state'
]
=
f
(
self
.
node
[
v
],
self
.
in_edges
(
v
,
data
=
True
))
self
.
G
.
node
[
v
][
'state'
]
=
f
(
self
.
G
.
node
[
v
],
list
(
self
.
G
.
in_edges
(
v
,
data
=
True
))
)
model.py
View file @
a3ce6e2f
...
@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module):
G
.
add_edges_from
(
hy_edge_list
)
G
.
add_edges_from
(
hy_edge_list
)
G
.
add_edges_from
(
hb_edge_list
)
G
.
add_edges_from
(
hb_edge_list
)
self
.
G
=
DiGraph
(
G
)
self
.
G
=
DiGraph
(
nx
.
DiGraph
(
G
)
)
hh_edge_list
=
[(
u
,
v
)
hh_edge_list
=
[(
u
,
v
)
for
u
,
v
in
self
.
G
.
edges
()
for
u
,
v
in
self
.
G
.
edges
()
if
self
.
G
.
node
[
u
][
'type'
]
==
self
.
G
.
node
[
v
][
'type'
]
==
'h'
]
if
self
.
G
.
node
[
u
][
'type'
]
==
self
.
G
.
node
[
v
][
'type'
]
==
'h'
]
...
@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module):
n_bh_edges
,
batch_size
,
_
=
source
.
shape
n_bh_edges
,
batch_size
,
_
=
source
.
shape
# FIXME: really using self.x is a bad design here
# FIXME: really using self.x is a bad design here
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
source
,
_
=
self
.
glimpse
.
rescale
(
source
,
False
)
_source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
_source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
m_b
=
T
.
relu
(
self
.
bh_1
(
_source
))
m_b
=
T
.
relu
(
self
.
bh_1
(
_source
))
...
@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module):
self
.
G
.
zero_node_state
((
self
.
h_dims
,),
batch_size
,
nodes
=
self
.
h_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
h_dims
,),
batch_size
,
nodes
=
self
.
h_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
n_classes
,),
batch_size
,
nodes
=
self
.
y_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
n_classes
,),
batch_size
,
nodes
=
self
.
y_nodes_list
)
full
=
self
.
glimpse
.
full
().
unsqueeze
(
0
).
expand
(
batch_size
,
self
.
glimpse
.
att_params
)
self
.
G
.
zero_node_state
((
self
.
glimpse
.
att_params
,),
batch_size
,
nodes
=
self
.
b_nodes_list
)
for
v
in
self
.
G
.
nodes
():
if
self
.
G
.
node
[
v
][
'type'
]
==
'b'
:
# Initialize bbox variables to cover the entire canvas
self
.
G
.
node
[
v
][
'state'
]
=
full
for
t
in
range
(
self
.
steps
):
for
t
in
range
(
self
.
steps
):
self
.
G
.
step
()
self
.
G
.
step
()
# We don't change b of the root
# We don't change b of the root
self
.
G
.
node
[
'b0'
][
'state'
]
=
full
self
.
G
.
node
[
'b0'
][
'state'
]
.
zero_
()
self
.
y_pre
=
T
.
stack
(
self
.
y_pre
=
T
.
stack
(
[
self
.
G
.
node
[
'y%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
-
1
,
self
.
n_nodes
-
self
.
n_leaves
-
1
,
-
1
)],
[
self
.
G
.
node
[
'y%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
-
1
,
self
.
n_nodes
-
self
.
n_leaves
-
1
,
-
1
)],
1
1
)
)
self
.
v_B
=
T
.
stack
(
self
.
v_B
=
T
.
stack
(
[
self
.
G
.
node
[
'b%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
)],
[
self
.
glimpse
.
rescale
(
self
.
G
.
node
[
'b%d'
%
i
][
'state'
]
,
False
)[
0
]
for
i
in
range
(
self
.
n_nodes
)],
1
,
1
,
)
)
self
.
y_logprob
=
F
.
log_softmax
(
self
.
y_pre
)
self
.
y_logprob
=
F
.
log_softmax
(
self
.
y_pre
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment