OpenDAS / dgl · Commits

Commit 10ec6e8a, authored May 02, 2018 by Gan Quan
    graph draft
Parent: 7a50be64

Showing 2 changed files with 338 additions and 0 deletions:
    graph.py  +143  -0
    model.py  +195  -0
graph.py  (new file, 0 → 100644)
import networkx as nx
import torch as T
import torch.nn as NN


class DiGraph(nx.DiGraph, NN.Module):
    '''
    Reserved attributes:

    * state: node state vectors during message passing iterations.
      Edges do not have state vectors; on an edge, the "state" field is
      reserved for storing messages.
    * tag: node-/edge-specific feature tensors or other data.
    '''
    def __init__(self, data=None, **attr):
        NN.Module.__init__(self)
        nx.DiGraph.__init__(self, data=data, **attr)
        self.message_funcs = []
        self.update_funcs = []
    def add_node(self, n, state=None, tag=None, attr_dict=None, **attr):
        nx.DiGraph.add_node(self, n, state=state, tag=tag, attr_dict=attr_dict, **attr)
    def add_nodes_from(self, nodes, state=None, tag=None, **attr):
        nx.DiGraph.add_nodes_from(self, nodes, state=state, tag=tag, **attr)

    def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
        nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)
    def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
        nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)
    def _nodes_or_all(self, nodes='all'):
        return self.nodes() if nodes == 'all' else nodes

    def _edges_or_all(self, edges='all'):
        return self.edges() if edges == 'all' else edges

    def _node_tag_name(self, v):
        return '(%s)' % v

    def _edge_tag_name(self, u, v):
        return '(%s, %s)' % (min(u, v), max(u, v))
    def zero_node_state(self, state_dims, batch_size=None, nodes='all'):
        # state_dims may be a single int or a sequence of dims
        state_dims = [state_dims] if isinstance(state_dims, int) else list(state_dims)
        shape = [batch_size] + state_dims if batch_size is not None else state_dims
        nodes = self._nodes_or_all(nodes)
        for v in nodes:
            self.node[v]['state'] = T.zeros(shape)
    def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
        nodes = self._nodes_or_all(nodes)
        for v in nodes:
            self.node[v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
            self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])

    def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
        edges = self._edges_or_all(edges)
        for u, v in edges:
            self[u][v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
            self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])

    def remove_node_tag(self, nodes='all'):
        nodes = self._nodes_or_all(nodes)
        for v in nodes:
            delattr(self, self._node_tag_name(v))
            del self.node[v]['tag']

    def remove_edge_tag(self, edges='all'):
        edges = self._edges_or_all(edges)
        for u, v in edges:
            delattr(self, self._edge_tag_name(u, v))
            del self[u][v]['tag']

    def edge_tags(self):
        for u, v in self.edges():
            yield self[u][v]['tag']

    def node_tags(self):
        for v in self.nodes():
            yield self.node[v]['tag']

    def states(self):
        for v in self.nodes():
            yield self.node[v]['state']

    def named_edge_tags(self):
        for u, v in self.edges():
            yield ((u, v), self[u][v]['tag'])

    def named_node_tags(self):
        for v in self.nodes():
            yield (v, self.node[v]['tag'])

    def named_states(self):
        for v in self.nodes():
            yield (v, self.node[v]['state'])

    def register_message_func(self, message_func, edges='all', batched=False):
        '''
        batched: whether to do a single batched computation instead of iterating
        message function: accepts source state tensor and edge tag tensor, and
        returns a message tensor
        '''
        self.message_funcs.append((self._edges_or_all(edges), message_func, batched))

    def register_update_func(self, update_func, nodes='all', batched=False):
        '''
        batched: whether to do a single batched computation instead of iterating
        update function: accepts a node attribute dictionary (including state and tag),
        and a dictionary of edge attribute dictionaries
        '''
        self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))
    def step(self):
        # update messages
        for ebunch, f, batched in self.message_funcs:
            if batched:
                # FIXME: need to optimize since we are repeatedly stacking and
                # unpacking
                source = T.stack([self.node[u]['state'] for u, _ in ebunch])
                edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
                message = f(source, edge_tag)
                # each edge takes its own slice of the batched message
                for i, (u, v) in enumerate(ebunch):
                    self[u][v]['state'] = message[i]
            else:
                for u, v in ebunch:
                    self[u][v]['state'] = f(self.node[u]['state'], self[u][v]['tag'])
        # update states
        # TODO: does it make sense to batch update the nodes?
        for nbunch, f, batched in self.update_funcs:
            for v in nbunch:
                # the update function sees the node's attribute dict and the
                # attribute dicts of its incoming edges, which hold the messages
                self.node[v]['state'] = f(self.node[v], self.pred[v])
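
Not part of this commit: a minimal sketch of how the API above is meant to be driven, assuming the networkx 1.x idioms (G.node, attr_dict) that graph.py itself uses. The toy graph and the message/update functions here are made up.

import torch as T
from graph import DiGraph

g = DiGraph()
g.add_nodes_from(['a', 'b'])
g.add_edge('a', 'b')
g.zero_node_state((4,))                         # zero state of shape (4,) on every node
g.init_edge_tag_with(4, T.nn.init.uniform_, args=(-.01, .01))

# message on a -> b: source state plus the learnable edge tag
g.register_message_func(lambda src, tag: src + tag, edges=[('a', 'b')])
# update of b: sum the messages stored on its incoming edges
g.register_update_func(
    lambda node, in_edges: sum(e['state'] for e in in_edges.values()),
    nodes=['b'])
g.step()                                        # g.node['b']['state'] now holds the new state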
model.py  (new file, 0 → 100644)
import torch as T
import torch.nn as NN
import torch.nn.functional as F
import networkx as nx

from graph import DiGraph
class TreeGlimpsedClassifier(NN.Module):
    def __init__(self,
                 n_children=2,
                 n_depth=3,
                 h_dims=128,
                 node_tag_dims=128,
                 edge_tag_dims=128,
                 n_classes=10,
                 steps=5,
                 ):
        '''
        Basic idea:

        * We detect objects through an undirected graphical model.
        * The graphical model consists of a balanced tree of latent variables h.
        * Each h is then connected to a bbox variable b and a class variable y.
        * b of the root is fixed to cover the entire canvas.
        * All other h, b and y are updated through message passing.
        * The loss function should be either (not completed yet)
          * multiset loss, or
          * maximum bipartite matching (like the Order Matters paper).
        '''
        NN.Module.__init__(self)
        self.n_children = n_children
        self.n_depth = n_depth
        self.h_dims = h_dims
        self.node_tag_dims = node_tag_dims
        self.edge_tag_dims = edge_tag_dims
        self.n_classes = n_classes
        self.steps = steps
        self._yh_emb_cached = False
        # Create graph of latent variables
        G = nx.balanced_tree(self.n_children, self.n_depth)
        nx.relabel_nodes(G, {i: 'h%d' % i for i in range(G.number_of_nodes())}, False)
        h_nodes_list = self.h_nodes_list = G.nodes()    # kept on self for forward()
        for h in h_nodes_list:
            G.node[h]['type'] = 'h'
        b_nodes_list = ['b%d' % i for i in range(len(h_nodes_list))]
        y_nodes_list = self.y_nodes_list = ['y%d' % i for i in range(len(h_nodes_list))]
        hy_edge_list = [(h, y) for h, y in zip(h_nodes_list, y_nodes_list)]
        hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, b_nodes_list)]
        yh_edge_list = [(y, h) for y, h in zip(y_nodes_list, h_nodes_list)]
        bh_edge_list = [(b, h) for b, h in zip(b_nodes_list, h_nodes_list)]
        G.add_nodes_from(b_nodes_list, type='b')
        G.add_nodes_from(y_nodes_list, type='y')
        G.add_edges_from(hy_edge_list)
        G.add_edges_from(hb_edge_list)
        # G is undirected, so the DiGraph gets both directions of every edge
        self.G = DiGraph(G)
        hh_edge_list = [(u, v) for u, v in self.G.edges()
                        if self.G.node[u]['type'] == self.G.node[v]['type'] == 'h']
        self.G.init_node_tag_with(node_tag_dims, T.nn.init.uniform_, args=(-.01, .01))
        self.G.init_edge_tag_with(
            edge_tag_dims, T.nn.init.uniform_, args=(-.01, .01),
            edges=hy_edge_list + hb_edge_list + yh_edge_list + bh_edge_list)
        # y -> h: an attention over embeddings dynamically generated from edge tags
        self.yh_emb = NN.Sequential(
            NN.Linear(edge_tag_dims, h_dims),
            NN.ReLU(),
            NN.Linear(h_dims, n_classes * h_dims),
        )
        self.G.register_message_func(self._y_to_h, edges=yh_edge_list, batched=True)
        # b -> h: projects b and the edge tag to the same dimension, then
        # concatenates and projects to h.
        # NOTE: self.glimpse (an attention module exposing att_params, full()
        # and __call__) and self.cnn are referenced below but are not defined
        # anywhere in this draft yet.
        self.bh_1 = NN.Linear(self.glimpse.att_params, h_dims)
        self.bh_2 = NN.Linear(edge_tag_dims, h_dims)
        self.bh_all = NN.Linear(3 * h_dims, h_dims)
        self.G.register_message_func(self._b_to_h, edges=bh_edge_list, batched=True)
        # h -> h: just passes h itself
        self.G.register_message_func(self._h_to_h, edges=hh_edge_list, batched=True)
        # h -> b: concatenates h with the edge tag and goes through an MLP.
        # Produces Δb
        self.hb = NN.Linear(h_dims + edge_tag_dims, self.glimpse.att_params)
        self.G.register_message_func(self._h_to_b, edges=hb_edge_list, batched=True)
        # h -> y: concatenates h with the edge tag and goes through an MLP.
        # Produces Δy
        self.hy = NN.Linear(h_dims + edge_tag_dims, self.n_classes)
        self.G.register_message_func(self._h_to_y, edges=hy_edge_list, batched=True)
        # b update: adds Δb to the original b
        self.G.register_update_func(self._update_b, nodes=b_nodes_list, batched=False)
        # y update: likewise adds Δy to y
        self.G.register_update_func(self._update_y, nodes=y_nodes_list, batched=False)
        # h update: adds the average incoming message to h, then applies ReLU
        self.G.register_update_func(self._update_h, nodes=h_nodes_list, batched=False)
    def _y_to_h(self, source, edge_tag):
        '''
        source: (n_yh_edges, batch_size, n_classes) logits
        edge_tag: (n_yh_edges, edge_tag_dims)
        '''
        n_yh_edges, batch_size, _ = source.shape
        if not self._yh_emb_cached:
            # the class embeddings depend only on the (static) edge tags,
            # so compute them once per forward pass
            self._yh_emb_cached = True
            self._yh_emb_w = self.yh_emb(edge_tag)
            self._yh_emb_w = self._yh_emb_w.reshape(n_yh_edges, self.n_classes, self.h_dims)
        w = self._yh_emb_w[:, None]
        w = w.expand(n_yh_edges, batch_size, self.n_classes, self.h_dims)
        source = source[:, :, None, :]
        # attention over class embeddings, weighted by the class probabilities
        return (F.softmax(source, -1) @ w).reshape(n_yh_edges, batch_size, self.h_dims)
    def _b_to_h(self, source, edge_tag):
        '''
        source: (n_bh_edges, batch_size, 6) bboxes
        edge_tag: (n_bh_edges, edge_tag_dims)
        '''
        n_bh_edges, batch_size, _ = source.shape
        _, nchan, nrows, ncols = self.x.size()
        source_flat = source.reshape(-1, self.glimpse.att_params)
        m_b = T.relu(self.bh_1(source_flat))
        m_t = T.relu(self.bh_2(edge_tag))
        m_t = m_t[:, None, :].expand(n_bh_edges, batch_size, self.h_dims)
        m_t = m_t.reshape(-1, self.h_dims)
        # glimpse takes batch dimension first, glimpse dimension second.
        # here, the leading dimension of @source is n_bh_edges (# of glimpses),
        # then batch size, so we transpose them
        g = self.glimpse(self.x, source.transpose(0, 1)).transpose(0, 1)
        g = g.reshape(n_bh_edges * batch_size, nchan, nrows, ncols)
        phi = self.cnn(g).reshape(n_bh_edges * batch_size, -1)
        m = self.bh_all(T.cat([m_b, m_t, phi], 1))
        m = m.reshape(n_bh_edges, batch_size, self.h_dims)
        return m
    def _h_to_h(self, source, edge_tag):
        return source
    def _h_to_b(self, source, edge_tag):
        n_hb_edges, batch_size, _ = source.shape
        edge_tag = edge_tag[:, None]
        edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
        I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
        db = self.hb(I)
        return db.reshape(n_hb_edges, batch_size, -1)
    def _h_to_y(self, source, edge_tag):
        n_hy_edges, batch_size, _ = source.shape
        edge_tag = edge_tag[:, None]
        edge_tag = edge_tag.expand(n_hy_edges, batch_size, self.edge_tag_dims)
        I = T.cat([source, edge_tag], -1).reshape(n_hy_edges * batch_size, -1)
        dy = self.hy(I)
        return dy.reshape(n_hy_edges, batch_size, -1)
    def _update_b(self, b, b_n):
        # b has a single incoming edge (from its h), carrying Δb
        return b['state'] + list(b_n.values())[0]['state']

    def _update_y(self, y, y_n):
        # likewise, y's single incoming message carries Δy
        return y['state'] + list(y_n.values())[0]['state']
    def _update_h(self, h, h_n):
        # average the messages on all incoming edges
        m = T.stack([e['state'] for e in h_n.values()]).mean(0)
        return T.relu(h['state'] + m)
    def forward(self, x):
        batch_size = x.shape[0]
        self.x = x                      # _b_to_h reads the canvas from self.x
        self.G.zero_node_state(self.h_dims, batch_size, nodes=self.h_nodes_list)
        # y states are logits; zero them too so the first y -> h message is defined
        self.G.zero_node_state(self.n_classes, batch_size, nodes=self.y_nodes_list)
        full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params)
        for v in self.G.nodes():
            if self.G.node[v]['type'] == 'b':
                # Initialize bbox variables to cover the entire canvas
                self.G.node[v]['state'] = full
        self._yh_emb_cached = False
        for t in range(self.steps):
            self.G.step()
            # We don't change b of the root
            self.G.node['b0']['state'] = full
        # (draft) nothing is returned yet; after the steps, the y node states
        # hold the class logits and the b node states hold the bboxes
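
Also not part of this commit: a small sketch, in plain networkx, of the graph this constructor builds for n_children=2, n_depth=1 (three h nodes). It shows how the undirected tree plus the b/y attachments turns into the directed message-passing graph.

import networkx as nx

G = nx.balanced_tree(2, 1)                     # h-tree: h0 is the root
nx.relabel_nodes(G, {i: 'h%d' % i for i in range(G.number_of_nodes())}, False)
h = list(G.nodes())
b = ['b%d' % i for i in range(len(h))]         # one bbox variable per h
y = ['y%d' % i for i in range(len(h))]         # one class variable per h
G.add_edges_from(zip(h, b))
G.add_edges_from(zip(h, y))
# The undirected edges double when converted to a DiGraph, so messages can
# flow h <-> b, h <-> y and h <-> h during step():
D = nx.DiGraph(G)
print(sorted(D.edges()))
# [('b0', 'h0'), ..., ('h0', 'b0'), ('h0', 'h1'), ('h0', 'h2'), ...]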