Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2b092811
Commit
2b092811
authored
Sep 10, 2018
by
GaiYu0
Browse files
line graph
parent
d772d390
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
78 additions
and
212 deletions
+78
-212
examples/pytorch/line_graph/gnn.py
examples/pytorch/line_graph/gnn.py
+69
-194
examples/pytorch/line_graph/test.py
examples/pytorch/line_graph/test.py
+9
-18
No files found.
examples/pytorch/line_graph/gnn.py
View file @
2b092811
...
...
@@ -9,231 +9,106 @@ Deviations from paper:
# TODO self-loop?
# TODO in-place edit of node_reprs/edge_reprs in message_func/update_func?
# TODO batch-norm
import
copy
import
itertools
import
dgl.graph
as
G
import
dgl
import
dgl.function
as
fn
import
networkx
as
nx
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn.functional
as
F
class
GLGModule
(
nn
.
Module
):
__SHADOW__
=
'shadow'
class
GNNModule
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
out_feats
,
radius
):
super
().
__init__
()
self
.
out_feats
=
out_feats
self
.
radius
=
radius
new_linear
=
lambda
:
nn
.
Linear
(
in_feats
,
out_feats
)
new_linear
=
lambda
:
nn
.
Linear
(
in_feats
,
out_feats
*
2
)
new_module_list
=
lambda
:
nn
.
ModuleList
([
new_linear
()
for
i
in
range
(
radius
)])
self
.
theta_x
,
self
.
theta_y
,
self
.
theta_deg
,
self
.
theta_
global
=
\
new_linear
(),
new_linear
(),
new_linear
()
,
new_linear
()
self
.
theta_x
,
self
.
theta_deg
,
self
.
theta_
y
=
\
new_linear
(),
new_linear
(),
new_linear
()
self
.
theta_list
=
new_module_list
()
self
.
gamma_x
,
self
.
gamma_y
,
self
.
gamma_deg
,
self
.
gamma_
global
=
\
new_linear
(),
new_linear
(),
new_linear
()
,
new_linear
()
self
.
gamma_y
,
self
.
gamma_deg
,
self
.
gamma_
x
=
\
new_linear
(),
new_linear
(),
new_linear
()
self
.
gamma_list
=
new_module_list
()
@
staticmethod
def
copy
(
which
):
if
which
==
'src'
:
return
lambda
src
,
trg
,
_
:
src
.
copy
()
elif
which
==
'trg'
:
return
lambda
src
,
trg
,
_
:
trg
.
copy
()
@
staticmethod
def
aggregate
(
msg_fld
,
trg_fld
,
normalize
=
False
):
def
a
(
node_reprs
,
edge_reprs
):
node_reprs
=
node_reprs
.
copy
()
node_reprs
[
trg_fld
]
=
sum
(
msg
[
msg_fld
]
for
msg
in
edge_reprs
)
if
normalize
:
node_reprs
[
trg_fld
]
/=
len
(
edge_reprs
)
return
node_reprs
return
a
@
staticmethod
def
pull
(
msg_fld
,
trg_fld
):
def
p
(
node_reprs
,
edge_reprs
):
node_reprs
=
node_reprs
.
copy
()
node_reprs
[
trg_fld
]
=
edge_reprs
[
0
][
msg_fld
]
return
node_reprs
return
p
def
local_aggregate
(
self
,
g
):
def
step
():
g
.
register_message_func
(
self
.
copy
(
'src'
),
g
.
edges
)
g
.
register_update_func
(
self
.
aggregate
(
'x'
,
'x'
),
g
.
nodes
)
g
.
update_all
()
step
()
for
reprs
in
g
.
nodes
.
values
():
reprs
[
0
]
=
reprs
[
'x'
]
for
i
in
range
(
1
,
self
.
radius
):
for
j
in
range
(
2
**
(
i
-
1
)):
step
()
for
reprs
in
g
.
nodes
.
values
():
reprs
[
i
]
=
reprs
[
'x'
]
@
staticmethod
def
global_aggregate
(
g
):
shadow
=
GLGModule
.
__SHADOW__
copy
,
aggregate
,
pull
=
GLGModule
.
copy
,
GLGModule
.
aggregate
,
GLGModule
.
pull
node_list
=
list
(
g
.
nodes
)
uv_list
=
[(
node
,
shadow
)
for
node
in
g
.
nodes
]
vu_list
=
[(
shadow
,
node
)
for
node
in
g
.
nodes
]
g
.
add_node
(
shadow
)
# TODO context manager
tuple
(
itertools
.
starmap
(
g
.
add_edge
,
uv_list
))
g
.
register_message_func
(
copy
(
'src'
),
uv_list
)
g
.
register_update_func
(
aggregate
(
'x'
,
'global'
,
normalize
=
True
),
(
shadow
,))
g
.
update_to
(
shadow
)
tuple
(
itertools
.
starmap
(
g
.
add_edge
,
vu_list
))
g
.
register_message_func
(
copy
(
'src'
),
vu_list
)
g
.
register_update_func
(
pull
(
'global'
,
'global'
),
node_list
)
g
.
update_from
(
shadow
)
self
.
bn_x
=
nn
.
BatchNorm1d
(
out_feats
)
self
.
bn_y
=
nn
.
BatchNorm1d
(
out_feats
)
g
.
remove_node
(
shadow
)
def
aggregate
(
self
,
g
,
z
):
z_list
=
[]
g
.
set_n_repr
(
z
)
g
.
update_all
(
fn
.
copy_src
(),
fn
.
sum
(),
batchable
=
True
)
z_list
.
append
(
g
.
get_n_repr
())
for
i
in
range
(
self
.
radius
-
1
):
for
j
in
range
(
2
**
i
):
g
.
update_all
(
fn
.
copy_src
(),
fn
.
sum
(),
batchable
=
True
)
z_list
.
append
(
g
.
get_n_repr
())
return
z_list
@
staticmethod
def
multiply_by_degree
(
g
):
g
.
register_message_func
(
lambda
*
args
:
None
,
g
.
edges
)
def
update_func
(
node_reprs
,
_
):
node_reprs
=
node_reprs
.
copy
()
node_reprs
[
'deg'
]
=
node_reprs
[
'x'
]
*
node_reprs
[
'degree'
]
return
node_reprs
g
.
register_update_func
(
update_func
,
g
.
nodes
)
g
.
update_all
()
def
forward
(
self
,
g
,
lg
,
x
,
y
,
deg_g
,
deg_lg
,
eid2nid
):
xy
=
F
.
embedding
(
eid2nid
,
x
)
@
staticmethod
def
message_func
(
src
,
trg
,
_
):
return
{
'y'
:
src
[
'x'
]}
def
update_func
(
self
,
which
):
if
which
==
'node'
:
linear_x
,
linear_y
,
linear_deg
,
linear_global
=
\
self
.
theta_x
,
self
.
theta_y
,
self
.
theta_deg
,
self
.
theta_global
linear_list
=
self
.
theta_list
elif
which
==
'edge'
:
linear_x
,
linear_y
,
linear_deg
,
linear_global
=
\
self
.
gamma_x
,
self
.
gamma_y
,
self
.
gamma_deg
,
self
.
gamma_global
linear_list
=
self
.
gamma_list
def
u
(
node_reprs
,
edge_reprs
):
edge_reprs
=
filter
(
lambda
x
:
x
is
not
None
,
edge_reprs
)
y
=
sum
(
x
[
'y'
]
for
x
in
edge_reprs
)
node_reprs
=
node_reprs
.
copy
()
node_reprs
[
'x'
]
=
linear_x
(
node_reprs
[
'x'
])
\
+
linear_y
(
y
)
\
+
linear_deg
(
node_reprs
[
'deg'
])
\
+
linear_global
(
node_reprs
[
'global'
])
\
+
sum
(
linear
(
node_reprs
[
i
])
\
for
i
,
linear
in
enumerate
(
linear_list
))
return
node_reprs
return
u
def
forward
(
self
,
g
,
lg
,
glg
):
self
.
local_aggregate
(
g
)
self
.
local_aggregate
(
lg
)
self
.
global_aggregate
(
g
)
self
.
global_aggregate
(
lg
)
self
.
multiply_by_degree
(
g
)
self
.
multiply_by_degree
(
lg
)
# TODO efficiency
for
node
,
reprs
in
g
.
nodes
.
items
():
glg
.
nodes
[
node
].
update
(
reprs
)
for
node
,
reprs
in
lg
.
nodes
.
items
():
glg
.
nodes
[
node
].
update
(
reprs
)
glg
.
register_message_func
(
self
.
message_func
,
glg
.
edges
)
glg
.
register_update_func
(
self
.
update_func
(
'node'
),
g
.
nodes
)
glg
.
register_update_func
(
self
.
update_func
(
'edge'
),
lg
.
nodes
)
glg
.
update_all
()
# TODO efficiency
for
node
,
reprs
in
g
.
nodes
.
items
():
reprs
.
update
(
glg
.
nodes
[
node
])
for
node
,
reprs
in
lg
.
nodes
.
items
():
reprs
.
update
(
glg
.
nodes
[
node
])
x_list
=
[
theta
(
z
)
for
theta
,
z
in
zip
(
self
.
theta_list
,
self
.
aggregate
(
g
,
x
))]
g
.
set_e_repr
(
y
)
g
.
update_all
(
fn
.
copy_edge
(),
fn
.
sum
(),
batchable
=
True
)
yx
=
g
.
get_n_repr
()
x
=
self
.
theta_x
(
x
)
+
self
.
theta_deg
(
deg_g
*
x
)
+
sum
(
x_list
)
+
self
.
theta_y
(
yx
)
x
=
self
.
bn_x
(
x
[:,
:
self
.
out_feats
]
+
F
.
relu
(
x
[:,
self
.
out_feats
:]))
y_list
=
[
gamma
(
z
)
for
gamma
,
z
in
zip
(
self
.
gamma_list
,
self
.
aggregate
(
lg
,
y
))]
lg
.
set_e_repr
(
xy
)
lg
.
update_all
(
fn
.
copy_edge
(),
fn
.
sum
(),
batchable
=
True
)
xy
=
lg
.
get_n_repr
()
y
=
self
.
gamma_y
(
y
)
+
self
.
gamma_deg
(
deg_lg
*
y
)
+
sum
(
y_list
)
+
self
.
gamma_x
(
xy
)
y
=
self
.
bn_y
(
y
[:,
:
self
.
out_feats
]
+
F
.
relu
(
y
[:,
self
.
out_feats
:]))
class
GNNModule
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
out_feats
,
order
,
radius
):
super
().
__init__
()
self
.
module_list
=
nn
.
ModuleList
([
GLGModule
(
in_feats
,
out_feats
,
radius
)
for
i
in
range
(
order
)])
def
forward
(
self
,
pairs
,
fusions
):
for
module
,
(
g
,
lg
),
glg
in
zip
(
self
.
module_list
,
pairs
,
fusions
):
module
(
g
,
lg
,
glg
)
for
lhs
,
rhs
in
zip
(
pairs
[:
-
1
],
pairs
[
1
:]):
for
node
,
reprs
in
lhs
[
1
].
nodes
.
items
():
x_rhs
=
reprs
[
'x'
]
reprs
[
'x'
]
=
x_rhs
+
rhs
[
0
].
nodes
[
node
][
'x'
]
rhs
[
0
].
nodes
[
node
][
'x'
]
+=
x_rhs
return
x
,
y
class
GNN
(
nn
.
Module
):
def
__init__
(
self
,
feats
,
order
,
radius
,
n_classes
):
super
().
__init__
()
self
.
order
=
order
self
.
linear
=
nn
.
Linear
(
feats
[
-
1
],
n_classes
)
self
.
module_list
=
nn
.
ModuleList
([
GNNModule
(
in_feats
,
out_feats
,
order
,
radius
)
for
in_feats
,
out_feats
in
zip
(
feats
[:
-
1
],
feats
[
1
:])])
@
staticmethod
def
line_graph
(
g
):
lg
=
nx
.
line_graph
(
g
)
glg
=
nx
.
DiGraph
()
glg
.
add_nodes_from
(
g
.
nodes
)
glg
.
add_nodes_from
(
lg
.
nodes
)
for
u
,
v
in
g
.
edges
:
glg
.
add_edge
(
u
,
(
u
,
v
))
glg
.
add_edge
((
u
,
v
),
u
)
glg
.
add_edge
(
v
,
(
u
,
v
))
glg
.
add_edge
((
u
,
v
),
v
)
return
lg
,
glg
@
staticmethod
def
nx2dgl
(
g
):
deg_dict
=
dict
(
nx
.
degree
(
g
))
z
=
sum
(
deg_dict
.
values
())
dgl_g
=
G
.
DGLGraph
(
g
)
for
node
,
reprs
in
dgl_g
.
nodes
.
items
():
reprs
[
'degree'
]
=
deg_dict
[
node
]
reprs
[
'x'
]
=
th
.
full
((
1
,
1
),
reprs
[
'degree'
]
/
z
)
reprs
.
update
(
g
.
nodes
[
node
])
return
dgl_g
def
forward
(
self
,
g
):
def
__init__
(
self
,
g
,
feats
,
radius
,
n_classes
):
"""
Parameters
----------
g : networkx.DiGraph
"""
pair_list
,
glg_list
=
[],
[]
dgl_g
=
self
.
nx2dgl
(
g
)
origin
=
dgl_g
for
i
in
range
(
self
.
order
):
lg
,
glg
=
self
.
line_graph
(
g
)
dgl_lg
=
self
.
nx2dgl
(
lg
)
pair_list
.
append
((
dgl_g
,
copy
.
deepcopy
(
dgl_lg
)))
glg_list
.
append
(
G
.
DGLGraph
(
glg
))
g
=
lg
dgl_g
=
dgl_lg
super
(
GNN
,
self
).
__init__
()
for
module
in
self
.
module_list
:
module
(
pair_list
,
glg_list
)
lg
=
nx
.
line_graph
(
g
)
x
=
list
(
zip
(
*
g
.
degree
))[
1
]
self
.
x
=
self
.
normalize
(
th
.
tensor
(
x
,
dtype
=
th
.
float
).
unsqueeze
(
1
))
y
=
list
(
zip
(
*
lg
.
degree
))[
1
]
self
.
y
=
self
.
normalize
(
th
.
tensor
(
y
,
dtype
=
th
.
float
).
unsqueeze
(
1
))
self
.
eid2nid
=
th
.
tensor
([
n
for
[[
_
,
n
],
_
]
in
lg
.
edges
])
return
self
.
linear
(
th
.
cat
([
reprs
[
'x'
]
for
reprs
in
origin
.
nodes
.
values
()],
0
))
self
.
g
=
dgl
.
DGLGraph
(
g
)
self
.
lg
=
dgl
.
DGLGraph
(
nx
.
convert_node_labels_to_integers
(
lg
))
self
.
linear
=
nn
.
Linear
(
feats
[
-
1
],
n_classes
)
self
.
module_list
=
nn
.
ModuleList
([
GNNModule
(
m
,
n
,
radius
)
for
m
,
n
in
zip
(
feats
[:
-
1
],
feats
[
1
:])])
@
staticmethod
def
normalize
(
x
):
x
=
x
-
th
.
mean
(
x
,
0
)
x
=
x
/
th
.
sqrt
(
th
.
mean
(
x
*
x
,
0
))
return
x
def
cuda
(
self
):
self
.
x
=
self
.
x
.
cuda
()
self
.
y
=
self
.
y
.
cuda
()
self
.
eid2nid
=
self
.
eid2nid
.
cuda
()
super
(
GNN
,
self
).
cuda
()
def
forward
(
self
):
x
,
y
=
self
.
x
,
self
.
y
for
module
in
self
.
module_list
:
x
,
y
=
module
(
self
.
g
,
self
.
lg
,
x
,
y
,
self
.
x
,
self
.
y
,
self
.
eid2nid
)
return
self
.
linear
(
x
)
examples/pytorch/line_graph/test.py
View file @
2b092811
"""
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10
--order 3
--radius 3
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --radius 3
"""
import
argparse
import
networkx
as
nx
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn
.functional
as
F
import
torch.optim
as
optim
import
gnn
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--features'
,
nargs
=
'+'
,
type
=
int
)
parser
.
add_argument
(
'--gpu'
,
type
=
int
)
parser
.
add_argument
(
'--n-classes'
,
type
=
int
)
parser
.
add_argument
(
'--n-iterations'
,
type
=
int
)
parser
.
add_argument
(
'--n-nodes'
,
type
=
int
)
parser
.
add_argument
(
'--order'
,
type
=
int
)
parser
.
add_argument
(
'--radius'
,
type
=
int
)
args
=
parser
.
parse_args
()
if
args
.
gpu
<
0
:
cuda
=
False
else
:
cuda
=
True
th
.
cuda
.
set_device
(
args
.
gpu
)
g
=
nx
.
barabasi_albert_graph
(
args
.
n_nodes
,
1
).
to_directed
()
# TODO SBM
y
=
th
.
multinomial
(
th
.
ones
(
args
.
n_classes
),
args
.
n_nodes
,
replacement
=
True
)
network
=
gnn
.
GNN
(
args
.
features
,
args
.
order
,
args
.
radius
,
args
.
n_classes
)
model
=
gnn
.
GNN
(
g
,
args
.
features
,
args
.
radius
,
args
.
n_classes
)
if
cuda
:
network
.
cuda
()
ce
=
nn
.
CrossEntropyLoss
()
adam
=
optim
.
Adam
(
network
.
parameters
())
model
.
cuda
()
opt
=
optim
.
Adam
(
model
.
parameters
())
for
i
in
range
(
args
.
n_iterations
):
y_bar
=
network
(
g
)
loss
=
ce
(
y_bar
,
y
)
adam
.
zero_grad
()
y_bar
=
model
(
)
loss
=
F
.
cross_entropy
(
y_bar
,
y
)
opt
.
zero_grad
()
loss
.
backward
()
adam
.
step
()
opt
.
step
()
print
(
'[iteration %d]loss %f'
%
(
i
,
loss
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment