Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9c9ac7c9
Commit
9c9ac7c9
authored
May 10, 2018
by
Gan Quan
Browse files
updates to support nx 2.1
parent
51391012
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
52 deletions
+45
-52
graph.py
graph.py
+40
-44
model.py
model.py
+5
-8
No files found.
graph.py
View file @
9c9ac7c9
...
@@ -3,37 +3,25 @@ import torch as T
...
@@ -3,37 +3,25 @@ import torch as T
import
torch.nn
as
NN
import
torch.nn
as
NN
from
util
import
*
from
util
import
*
class
DiGraph
(
nx
.
DiGraph
,
NN
.
Module
):
class
DiGraph
(
NN
.
Module
):
'''
'''
Reserved attributes:
Reserved attributes:
* state: node state vectors during message passing iterations
* state: node state vectors during message passing iterations
edges does not have "state vectors"; the "state" field is reserved for storing messages
edges does not have "state vectors"; the "state" field is reserved for storing messages
* tag: node-/edge-specific feature tensors or other data
* tag: node-/edge-specific feature tensors or other data
'''
'''
def
__init__
(
self
,
data
=
None
,
**
attr
):
def
__init__
(
self
,
graph
):
NN
.
Module
.
__init__
(
self
)
NN
.
Module
.
__init__
(
self
)
nx
.
DiGraph
.
__init__
(
self
,
data
=
data
,
**
attr
)
self
.
G
=
graph
self
.
message_funcs
=
[]
self
.
message_funcs
=
[]
self
.
update_funcs
=
[]
self
.
update_funcs
=
[]
def
add_node
(
self
,
n
,
state
=
None
,
tag
=
None
,
attr_dict
=
None
,
**
attr
):
nx
.
DiGraph
.
add_node
(
self
,
n
,
state
=
state
,
tag
=
None
,
attr_dict
=
attr_dict
,
**
attr
)
def
add_nodes_from
(
self
,
nodes
,
state
=
None
,
tag
=
None
,
**
attr
):
nx
.
DiGraph
.
add_nodes_from
(
self
,
nodes
,
state
=
state
,
tag
=
tag
,
**
attr
)
def
add_edge
(
self
,
u
,
v
,
tag
=
None
,
attr_dict
=
None
,
**
attr
):
nx
.
DiGraph
.
add_edge
(
self
,
u
,
v
,
tag
=
tag
,
attr_dict
=
attr_dict
,
**
attr
)
def
add_edges_from
(
self
,
ebunch
,
tag
=
None
,
attr_dict
=
None
,
**
attr
):
nx
.
DiGraph
.
add_edges_from
(
self
,
ebunch
,
tag
=
tag
,
attr_dict
=
attr_dict
,
**
attr
)
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
def
_nodes_or_all
(
self
,
nodes
=
'all'
):
return
self
.
nodes
()
if
nodes
==
'all'
else
nodes
return
self
.
G
.
nodes
()
if
nodes
==
'all'
else
nodes
def
_edges_or_all
(
self
,
edges
=
'all'
):
def
_edges_or_all
(
self
,
edges
=
'all'
):
return
self
.
edges
()
if
edges
==
'all'
else
edges
return
self
.
G
.
edges
()
if
edges
==
'all'
else
edges
def
_node_tag_name
(
self
,
v
):
def
_node_tag_name
(
self
,
v
):
return
'(%s)'
%
v
return
'(%s)'
%
v
...
@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
self
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
))
self
.
G
.
node
[
v
][
'state'
]
=
tovar
(
T
.
zeros
(
shape
))
def
init_node_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
nodes
=
'all'
,
args
=
()):
def
init_node_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
nodes
=
'all'
,
args
=
()):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
self
.
node
[
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
G
.
node
[
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
register_parameter
(
self
.
_node_tag_name
(
v
),
self
.
node
[
v
][
'tag'
])
self
.
register_parameter
(
self
.
_node_tag_name
(
v
),
self
.
G
.
node
[
v
][
'tag'
])
def
init_edge_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
edges
=
'all'
,
args
=
()):
def
init_edge_tag_with
(
self
,
shape
,
init_func
,
dtype
=
T
.
float32
,
edges
=
'all'
,
args
=
()):
edges
=
self
.
_edges_or_all
(
edges
)
edges
=
self
.
_edges_or_all
(
edges
)
for
u
,
v
in
edges
:
for
u
,
v
in
edges
:
self
[
u
][
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
G
[
u
][
v
][
'tag'
]
=
init_func
(
NN
.
Parameter
(
tovar
(
T
.
zeros
(
shape
,
dtype
=
dtype
))),
*
args
)
self
.
register_parameter
(
self
.
_edge_tag_name
(
u
,
v
),
self
[
u
][
v
][
'tag'
])
self
.
register_parameter
(
self
.
_edge_tag_name
(
u
,
v
),
self
.
G
[
u
][
v
][
'tag'
])
def
remove_node_tag
(
self
,
nodes
=
'all'
):
def
remove_node_tag
(
self
,
nodes
=
'all'
):
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
for
v
in
nodes
:
for
v
in
nodes
:
delattr
(
self
,
self
.
_node_tag_name
(
v
))
delattr
(
self
,
self
.
_node_tag_name
(
v
))
del
self
.
node
[
v
][
'tag'
]
del
self
.
G
.
node
[
v
][
'tag'
]
def
remove_edge_tag
(
self
,
edges
=
'all'
):
def
remove_edge_tag
(
self
,
edges
=
'all'
):
edges
=
self
.
_edges_or_all
(
edges
)
edges
=
self
.
_edges_or_all
(
edges
)
for
u
,
v
in
edges
:
for
u
,
v
in
edges
:
delattr
(
self
,
self
.
_edge_tag_name
(
u
,
v
))
delattr
(
self
,
self
.
_edge_tag_name
(
u
,
v
))
del
self
[
u
][
v
][
'tag'
]
del
self
.
G
[
u
][
v
][
'tag'
]
@
property
def
node
(
self
):
return
self
.
G
.
node
@
property
def
edges
(
self
):
return
self
.
G
.
edges
def
edge_tags
(
self
):
def
edge_tags
(
self
):
for
u
,
v
in
self
.
edges
():
for
u
,
v
in
self
.
G
.
edges
():
yield
self
[
u
][
v
][
'tag'
]
yield
self
.
G
[
u
][
v
][
'tag'
]
def
node_tags
(
self
):
def
node_tags
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
self
.
node
[
v
][
'tag'
]
yield
self
.
G
.
node
[
v
][
'tag'
]
def
states
(
self
):
def
states
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
self
.
node
[
v
][
'state'
]
yield
self
.
G
.
node
[
v
][
'state'
]
def
named_edge_tags
(
self
):
def
named_edge_tags
(
self
):
for
u
,
v
in
self
.
edges
():
for
u
,
v
in
self
.
G
.
edges
():
yield
((
u
,
v
),
self
[
u
][
v
][
'tag'
])
yield
((
u
,
v
),
self
.
G
[
u
][
v
][
'tag'
])
def
named_node_tags
(
self
):
def
named_node_tags
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
(
v
,
self
.
node
[
v
][
'tag'
])
yield
(
v
,
self
.
G
.
node
[
v
][
'tag'
])
def
named_states
(
self
):
def
named_states
(
self
):
for
v
in
self
.
nodes
():
for
v
in
self
.
G
.
nodes
():
yield
(
v
,
self
.
node
[
v
][
'state'
])
yield
(
v
,
self
.
G
.
node
[
v
][
'state'
])
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batched
=
False
):
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batched
=
False
):
'''
'''
...
@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module):
...
@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module):
if
batched
:
if
batched
:
# FIXME: need to optimize since we are repeatedly stacking and
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
# unpacking
source
=
T
.
stack
([
self
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
source
=
T
.
stack
([
self
.
G
.
node
[
u
][
'state'
]
for
u
,
_
in
ebunch
])
edge_tags
=
[
self
[
u
][
v
]
[
'tag'
]
for
u
,
v
in
ebunch
]
edge_tags
=
[
self
.
G
[
u
][
v
]
.
get
(
'tag'
,
None
)
for
u
,
v
in
ebunch
]
if
all
(
t
is
None
for
t
in
edge_tags
):
if
all
(
t
is
None
for
t
in
edge_tags
):
edge_tag
=
None
edge_tag
=
None
else
:
else
:
edge_tag
=
T
.
stack
([
self
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
edge_tag
=
T
.
stack
([
self
.
G
[
u
][
v
][
'tag'
]
for
u
,
v
in
ebunch
])
message
=
f
(
source
,
edge_tag
)
message
=
f
(
source
,
edge_tag
)
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
for
i
,
(
u
,
v
)
in
enumerate
(
ebunch
):
self
[
u
][
v
][
'state'
]
=
message
[
i
]
self
.
G
[
u
][
v
][
'state'
]
=
message
[
i
]
else
:
else
:
for
u
,
v
in
ebunch
:
for
u
,
v
in
ebunch
:
self
[
u
][
v
][
'state'
]
=
f
(
self
.
G
[
u
][
v
][
'state'
]
=
f
(
self
.
node
[
u
][
'state'
],
self
.
G
.
node
[
u
][
'state'
],
self
[
u
][
v
][
'tag'
]
self
.
G
[
u
][
v
][
'tag'
]
)
)
# update state
# update state
# TODO: does it make sense to batch update the nodes?
# TODO: does it make sense to batch update the nodes?
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
for
vbunch
,
f
,
batched
in
self
.
update_funcs
:
for
v
in
vbunch
:
for
v
in
vbunch
:
self
.
node
[
v
][
'state'
]
=
f
(
self
.
node
[
v
],
self
.
in_edges
(
v
,
data
=
True
))
self
.
G
.
node
[
v
][
'state'
]
=
f
(
self
.
G
.
node
[
v
],
list
(
self
.
G
.
in_edges
(
v
,
data
=
True
))
)
model.py
View file @
9c9ac7c9
...
@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module):
G
.
add_edges_from
(
hy_edge_list
)
G
.
add_edges_from
(
hy_edge_list
)
G
.
add_edges_from
(
hb_edge_list
)
G
.
add_edges_from
(
hb_edge_list
)
self
.
G
=
DiGraph
(
G
)
self
.
G
=
DiGraph
(
nx
.
DiGraph
(
G
)
)
hh_edge_list
=
[(
u
,
v
)
hh_edge_list
=
[(
u
,
v
)
for
u
,
v
in
self
.
G
.
edges
()
for
u
,
v
in
self
.
G
.
edges
()
if
self
.
G
.
node
[
u
][
'type'
]
==
self
.
G
.
node
[
v
][
'type'
]
==
'h'
]
if
self
.
G
.
node
[
u
][
'type'
]
==
self
.
G
.
node
[
v
][
'type'
]
==
'h'
]
...
@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module):
n_bh_edges
,
batch_size
,
_
=
source
.
shape
n_bh_edges
,
batch_size
,
_
=
source
.
shape
# FIXME: really using self.x is a bad design here
# FIXME: really using self.x is a bad design here
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
_
,
nchan
,
nrows
,
ncols
=
self
.
x
.
size
()
source
,
_
=
self
.
glimpse
.
rescale
(
source
,
False
)
_source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
_source
=
source
.
reshape
(
-
1
,
self
.
glimpse
.
att_params
)
m_b
=
T
.
relu
(
self
.
bh_1
(
_source
))
m_b
=
T
.
relu
(
self
.
bh_1
(
_source
))
...
@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module):
...
@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module):
self
.
G
.
zero_node_state
((
self
.
h_dims
,),
batch_size
,
nodes
=
self
.
h_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
h_dims
,),
batch_size
,
nodes
=
self
.
h_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
n_classes
,),
batch_size
,
nodes
=
self
.
y_nodes_list
)
self
.
G
.
zero_node_state
((
self
.
n_classes
,),
batch_size
,
nodes
=
self
.
y_nodes_list
)
full
=
self
.
glimpse
.
full
().
unsqueeze
(
0
).
expand
(
batch_size
,
self
.
glimpse
.
att_params
)
self
.
G
.
zero_node_state
((
self
.
glimpse
.
att_params
,),
batch_size
,
nodes
=
self
.
b_nodes_list
)
for
v
in
self
.
G
.
nodes
():
if
self
.
G
.
node
[
v
][
'type'
]
==
'b'
:
# Initialize bbox variables to cover the entire canvas
self
.
G
.
node
[
v
][
'state'
]
=
full
for
t
in
range
(
self
.
steps
):
for
t
in
range
(
self
.
steps
):
self
.
G
.
step
()
self
.
G
.
step
()
# We don't change b of the root
# We don't change b of the root
self
.
G
.
node
[
'b0'
][
'state'
]
=
full
self
.
G
.
node
[
'b0'
][
'state'
]
.
zero_
()
self
.
y_pre
=
T
.
stack
(
self
.
y_pre
=
T
.
stack
(
[
self
.
G
.
node
[
'y%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
-
1
,
self
.
n_nodes
-
self
.
n_leaves
-
1
,
-
1
)],
[
self
.
G
.
node
[
'y%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
-
1
,
self
.
n_nodes
-
self
.
n_leaves
-
1
,
-
1
)],
1
1
)
)
self
.
v_B
=
T
.
stack
(
self
.
v_B
=
T
.
stack
(
[
self
.
G
.
node
[
'b%d'
%
i
][
'state'
]
for
i
in
range
(
self
.
n_nodes
)],
[
self
.
glimpse
.
rescale
(
self
.
G
.
node
[
'b%d'
%
i
][
'state'
]
,
False
)[
0
]
for
i
in
range
(
self
.
n_nodes
)],
1
,
1
,
)
)
self
.
y_logprob
=
F
.
log_softmax
(
self
.
y_pre
)
self
.
y_logprob
=
F
.
log_softmax
(
self
.
y_pre
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment