Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
5cda368d
Commit
5cda368d
authored
Nov 28, 2018
by
HQ
Committed by
Minjie Wang
Nov 27, 2018
Browse files
[Model] SBM hotfix (#137)
* [Model]SBM hotfix * [Model] remove backend in data
parent
02eb463a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
12 deletions
+26
-12
examples/pytorch/line_graph/gnn.py
examples/pytorch/line_graph/gnn.py
+1
-0
examples/pytorch/line_graph/train.py
examples/pytorch/line_graph/train.py
+21
-7
python/dgl/data/sbm.py
python/dgl/data/sbm.py
+4
-5
No files found.
examples/pytorch/line_graph/gnn.py
View file @
5cda368d
...
@@ -6,6 +6,7 @@ import networkx as nx
...
@@ -6,6 +6,7 @@ import networkx as nx
import
torch
as
th
import
torch
as
th
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
numpy
as
np
class
GNNModule
(
nn
.
Module
):
class
GNNModule
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
out_feats
,
radius
):
def
__init__
(
self
,
in_feats
,
out_feats
,
radius
):
...
...
examples/pytorch/line_graph/train.py
View file @
5cda368d
...
@@ -11,6 +11,7 @@ import time
...
@@ -11,6 +11,7 @@ import time
import
argparse
import
argparse
from
itertools
import
permutations
from
itertools
import
permutations
import
numpy
as
np
import
torch
as
th
import
torch
as
th
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.optim
as
optim
...
@@ -57,8 +58,18 @@ def compute_overlap(z_list):
...
@@ -57,8 +58,18 @@ def compute_overlap(z_list):
overlap_list
.
append
(
overlap
)
overlap_list
.
append
(
overlap
)
return
sum
(
overlap_list
)
/
len
(
overlap_list
)
return
sum
(
overlap_list
)
/
len
(
overlap_list
)
def
from_np
(
f
,
*
args
):
def
wrap
(
*
args
):
new
=
[
th
.
from_numpy
(
x
)
if
isinstance
(
x
,
np
.
ndarray
)
else
x
for
x
in
args
]
return
f
(
*
new
)
return
wrap
@
from_np
def
step
(
i
,
j
,
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
):
def
step
(
i
,
j
,
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
):
""" One step of training. """
""" One step of training. """
deg_g
=
deg_g
.
to
(
dev
)
deg_lg
=
deg_lg
.
to
(
dev
)
pm_pd
=
pm_pd
.
to
(
dev
)
t0
=
time
.
time
()
t0
=
time
.
time
()
z
=
model
(
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
z
=
model
(
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
t_forward
=
time
.
time
()
-
t0
t_forward
=
time
.
time
()
-
t0
...
@@ -75,6 +86,15 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
...
@@ -75,6 +86,15 @@ def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
return
loss
,
overlap
,
t_forward
,
t_backward
return
loss
,
overlap
,
t_forward
,
t_backward
@
from_np
def
inference
(
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
):
deg_g
=
deg_g
.
to
(
dev
)
deg_lg
=
deg_lg
.
to
(
dev
)
pm_pd
=
pm_pd
.
to
(
dev
)
z
=
model
(
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
return
z
def
test
():
def
test
():
p_list
=
[
6
,
5.5
,
5
,
4.5
,
1.5
,
1
,
0.5
,
0
]
p_list
=
[
6
,
5.5
,
5
,
4.5
,
1.5
,
1
,
0.5
,
0
]
q_list
=
[
0
,
0.5
,
1
,
1.5
,
4.5
,
5
,
5.5
,
6
]
q_list
=
[
0
,
0.5
,
1
,
1.5
,
4.5
,
5
,
5.5
,
6
]
...
@@ -84,10 +104,7 @@ def test():
...
@@ -84,10 +104,7 @@ def test():
dataset
=
SBMMixture
(
N
,
args
.
n_nodes
,
K
,
pq
=
[[
p
,
q
]]
*
N
)
dataset
=
SBMMixture
(
N
,
args
.
n_nodes
,
K
,
pq
=
[[
p
,
q
]]
*
N
)
loader
=
DataLoader
(
dataset
,
N
,
collate_fn
=
dataset
.
collate_fn
)
loader
=
DataLoader
(
dataset
,
N
,
collate_fn
=
dataset
.
collate_fn
)
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
=
next
(
iter
(
loader
))
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
=
next
(
iter
(
loader
))
deg_g
=
deg_g
.
to
(
dev
)
z
=
inference
(
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
deg_lg
=
deg_lg
.
to
(
dev
)
pm_pd
=
pm_pd
.
to
(
dev
)
z
=
model
(
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
overlap_list
.
append
(
compute_overlap
(
th
.
chunk
(
z
,
N
,
0
)))
overlap_list
.
append
(
compute_overlap
(
th
.
chunk
(
z
,
N
,
0
)))
return
overlap_list
return
overlap_list
...
@@ -95,9 +112,6 @@ n_iterations = args.n_graphs // args.batch_size
...
@@ -95,9 +112,6 @@ n_iterations = args.n_graphs // args.batch_size
for
i
in
range
(
args
.
n_epochs
):
for
i
in
range
(
args
.
n_epochs
):
total_loss
,
total_overlap
,
s_forward
,
s_backward
=
0
,
0
,
0
,
0
total_loss
,
total_overlap
,
s_forward
,
s_backward
=
0
,
0
,
0
,
0
for
j
,
[
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
]
in
enumerate
(
training_loader
):
for
j
,
[
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
]
in
enumerate
(
training_loader
):
deg_g
=
deg_g
.
to
(
dev
)
deg_lg
=
deg_lg
.
to
(
dev
)
pm_pd
=
pm_pd
.
to
(
dev
)
loss
,
overlap
,
t_forward
,
t_backward
=
step
(
i
,
j
,
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
loss
,
overlap
,
t_forward
,
t_backward
=
step
(
i
,
j
,
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
)
total_loss
+=
loss
total_loss
+=
loss
...
...
python/dgl/data/sbm.py
View file @
5cda368d
...
@@ -8,7 +8,6 @@ import numpy.random as npr
...
@@ -8,7 +8,6 @@ import numpy.random as npr
import
scipy
as
sp
import
scipy
as
sp
import
networkx
as
nx
import
networkx
as
nx
from
..
import
backend
as
F
from
..batched_graph
import
batch
from
..batched_graph
import
batch
from
..graph
import
DGLGraph
from
..graph
import
DGLGraph
from
..utils
import
Index
from
..utils
import
Index
...
@@ -94,7 +93,7 @@ class SBMMixture:
...
@@ -94,7 +93,7 @@ class SBMMixture:
g
.
from_scipy_sparse_matrix
(
adj
)
g
.
from_scipy_sparse_matrix
(
adj
)
self
.
_lgs
=
[
g
.
line_graph
(
backtracking
=
False
)
for
g
in
self
.
_gs
]
self
.
_lgs
=
[
g
.
line_graph
(
backtracking
=
False
)
for
g
in
self
.
_gs
]
in_degrees
=
lambda
g
:
g
.
in_degrees
(
in_degrees
=
lambda
g
:
g
.
in_degrees
(
Index
(
F
.
arange
(
0
,
g
.
number_of_nodes
()))).
unsqueeze
(
1
).
float
()
Index
(
np
.
arange
(
0
,
g
.
number_of_nodes
()))).
unsqueeze
(
1
).
float
()
self
.
_g_degs
=
[
in_degrees
(
g
)
for
g
in
self
.
_gs
]
self
.
_g_degs
=
[
in_degrees
(
g
)
for
g
in
self
.
_gs
]
self
.
_lg_degs
=
[
in_degrees
(
lg
)
for
lg
in
self
.
_lgs
]
self
.
_lg_degs
=
[
in_degrees
(
lg
)
for
lg
in
self
.
_lgs
]
self
.
_pm_pds
=
list
(
zip
(
*
[
g
.
edges
()
for
g
in
self
.
_gs
]))[
0
]
self
.
_pm_pds
=
list
(
zip
(
*
[
g
.
edges
()
for
g
in
self
.
_gs
]))[
0
]
...
@@ -118,7 +117,7 @@ class SBMMixture:
...
@@ -118,7 +117,7 @@ class SBMMixture:
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
=
zip
(
*
x
)
g
,
lg
,
deg_g
,
deg_lg
,
pm_pd
=
zip
(
*
x
)
g_batch
=
batch
(
g
)
g_batch
=
batch
(
g
)
lg_batch
=
batch
(
lg
)
lg_batch
=
batch
(
lg
)
degg_batch
=
F
.
pack
(
deg_g
)
degg_batch
=
np
.
concatenate
(
deg_g
,
axis
=
0
)
deglg_batch
=
F
.
pack
(
deg_lg
)
deglg_batch
=
np
.
concatenate
(
deg_lg
,
axis
=
0
)
pm_pd_batch
=
F
.
pack
([
x
+
i
*
self
.
_n_nodes
for
i
,
x
in
enumerate
(
pm_pd
)])
pm_pd_batch
=
np
.
concatenate
([
x
+
i
*
self
.
_n_nodes
for
i
,
x
in
enumerate
(
pm_pd
)]
,
axis
=
0
)
return
g_batch
,
lg_batch
,
degg_batch
,
deglg_batch
,
pm_pd_batch
return
g_batch
,
lg_batch
,
degg_batch
,
deglg_batch
,
pm_pd_batch
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment