OpenDAS / dgl / Commits / cc372e37

Commit cc372e37 authored Oct 04, 2018 by GaiYu0

sbm mixture

parent f52fc3fc

Showing 12 changed files with 298 additions and 201 deletions (+298, -201)
examples/pytorch/line_graph/data.py        +0   -1
examples/pytorch/line_graph/dense_sbm.py   +0   -21
examples/pytorch/line_graph/gnn.py         +13  -29
examples/pytorch/line_graph/sbm.py         +0   -129
examples/pytorch/line_graph/train.py       +27  -19
examples/pytorch/line_graph/utils.py       +4   -0
python/dgl/backend/pytorch.py              +26  -2
python/dgl/data/__init__.py                +1   -0
python/dgl/data/sbm.py                     +97  -0
python/dgl/graph.py                        +42  -0
python/dgl/graph_index.py                  +59  -0
tests/test_adj_and_inc.py                  +29  -0
examples/pytorch/line_graph/data.py  (deleted, 100644 → 0, view file @ f52fc3fc)

import torch.utils as utils
examples/pytorch/line_graph/dense_sbm.py  (deleted, 100644 → 0, view file @ f52fc3fc)

import torch as th


def sbm(y, p, q):
    """
    Parameters
    ----------
    y: torch.Tensor (N, 1)
    """
    i = (y == y.t()).float()
    r = i * p + (1 - i) * q
    a = th.distributions.Bernoulli(r).sample()
    b = th.triu(a) + th.triu(a, 1).t()
    return b


if __name__ == '__main__':
    N = 10000
    y = th.ones(N, 1)
    p = 1 / N
    q = 0
    a = sbm(y, p, q)
    print(th.sum(a))
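
For context on the deleted helper: `sbm` samples a dense symmetric adjacency matrix in which pairs with equal labels in `y` are connected with probability `p` and all other pairs with probability `q` (the upper triangle is sampled and then mirrored). A minimal two-community usage sketch, assuming the `sbm` function above is in scope; the sizes and probabilities are illustrative, not from the commit:

import torch as th

# two blocks of 5 nodes, labelled 0 and 1
y = th.cat([th.zeros(5, 1), th.ones(5, 1)])
a = sbm(y, p=0.9, q=0.1)       # dense (10, 10) 0/1 adjacency matrix
assert th.equal(a, a.t())      # symmetric by construction (mirrored upper triangle)
print(th.sum(a))               # roughly p * (#same-label pairs) + q * (#cross-label pairs)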
examples/pytorch/line_graph/gnn.py  (view file @ cc372e37)

@@ -3,7 +3,6 @@ Supervised Community Detection with Hierarchical Graph Neural Networks
 https://arxiv.org/abs/1705.08415
 Deviations from paper:
 - Message passing is equivalent to `A^j \cdot X`, instead of `\min(1, A^j) \cdot X`.
 - Pm Pd
 """

@@ -53,15 +52,17 @@ class GNNModule(nn.Module):
         xy = F.embedding(eid2nid, x)

         x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
         g.set_e_repr(y)
         g.update_all(fn.copy_edge(), fn.sum(), batchable=True)
         yx = g.get_n_repr()
         x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
         x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))

         y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
-        lg.set_e_repr(xy)
-        lg.update_all(fn.copy_edge(), fn.sum(), batchable=True)
+        lg.set_n_repr(xy)
+        lg.update_all(fn.copy_src(), fn.sum(), batchable=True)
         xy = lg.get_n_repr()
         y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
         y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))

@@ -70,42 +71,25 @@ class GNNModule(nn.Module):
 class GNN(nn.Module):
-    def __init__(self, g, feats, radius, n_classes):
-        """
-        Parameters
-        ----------
-        g : networkx.DiGraph
-        """
+    def __init__(self, feats, radius, n_classes):
         super(GNN, self).__init__()
-        lg = nx.line_graph(g)
-        x = list(zip(*g.degree))[1]
-        self.x = self.normalize(th.tensor(x, dtype=th.float).unsqueeze(1))
-        y = list(zip(*lg.degree))[1]
-        self.y = self.normalize(th.tensor(y, dtype=th.float).unsqueeze(1))
-        self.eid2nid = th.tensor([int(n) for [[_, n], [_, _]] in lg.edges])
-        self.g = dgl.DGLGraph(g)
-        self.lg = dgl.DGLGraph(nx.convert_node_labels_to_integers(lg))
         self.linear = nn.Linear(feats[-1], n_classes)
         self.module_list = nn.ModuleList([GNNModule(m, n, radius) for m, n in zip(feats[:-1], feats[1:])])

-    @staticmethod
-    def normalize(x):
-        x = x - th.mean(x, 0)
-        x = x / th.sqrt(th.mean(x * x, 0))
-        return x
-
-    def cuda(self):
-        self.x = self.x.cuda()
-        self.y = self.y.cuda()
-        self.eid2nid = self.eid2nid.cuda()
-        super(GNN, self).cuda()
-
-    def forward(self):
-        x, y = self.x, self.y
+    def forward(self, g, lg, deg_g, deg_lg, eid2nid):
+        def normalize(x):
+            x = x - th.mean(x, 0)
+            x = x / th.sqrt(th.mean(x * x, 0))
+            return x
+
+        x = normalize(deg_g)
+        y = normalize(deg_lg)
         for module in self.module_list:
-            x, y = module(self.g, self.lg, x, y, self.x, self.y, self.eid2nid)
+            x, y = module(g, lg, x, y, deg_g, deg_lg, eid2nid)
         return self.linear(x)
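
The net effect of this hunk is that `GNN` no longer caches the graph, line graph, and degree features at construction time; `forward` now receives the batched graph `g`, its line graph `lg`, both degree vectors, and the edge-to-node map `eid2nid`, and normalizes the degrees on the fly. A standalone sketch of that per-forward normalization, using an illustrative degree vector rather than anything from the commit:

import torch as th

def normalize(x):
    # zero-mean, unit-RMS per column, as in the new GNN.forward
    x = x - th.mean(x, 0)
    return x / th.sqrt(th.mean(x * x, 0))

deg_g = th.tensor([[1.], [2.], [3.], [2.]])       # in-degrees of a toy 4-node graph
x0 = normalize(deg_g)                             # initial node features for the first GNNModule
print(x0.mean().item(), (x0 * x0).mean().item())  # ~0.0 and ~1.0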
examples/pytorch/line_graph/sbm.py  (deleted, 100644 → 0, view file @ f52fc3fc)

"""
By Minjie
"""
from __future__ import division
import math
import numpy as np
import scipy.sparse as sp
import networkx as nx
import matplotlib.pyplot as plt


class SSBM:
    def __init__(self, n, k, a=10.0, b=2.0, regime='constant', rng=None):
        """Symmetric Stochastic Block Model.

        n - number of nodes
        k - number of communities
        a - probability scale for intra-community edge
        b - probability scale for inter-community edge
        regime - If "logarithm", this generates SSBM(n, k, a*log(n)/n, b*log(n)/n)
                 If "constant", this generates SSBM(n, k, a/n, b/n)
                 If "mixed", this generates SSBM(n, k, a*log(n)/n, b/n)
        """
        self.n = n
        self.k = k
        if regime == 'logarithm':
            if math.sqrt(a) - math.sqrt(b) >= math.sqrt(k):
                print('SSBM model with possible exact recovery.')
            else:
                print('SSBM model with impossible exact recovery.')
            self.a = a * math.log(n) / n
            self.b = b * math.log(n) / n
        elif regime == 'constant':
            snr = (a - b) ** 2 / (k * (a + (k + 1) * b))
            if snr > 1:
                print('SSBM model with possible detection.')
            else:
                print('SSBM model that may not have detection (snr=%.5f).' % snr)
            self.a = a / n
            self.b = b / n
        elif regime == 'mixed':
            self.a = a * math.log(n) / n
            self.b = b / n
        else:
            raise ValueError('Unknown regime: %s' % regime)

        if rng is None:
            self.rng = np.random.RandomState()
        else:
            self.rng = rng

        self._graph = None

    def generate(self):
        self.generate_communities()
        print('Finished generating communities.')
        self.generate_edges()
        print('Finished generating edges.')

    def generate_communities(self):
        nodes = list(range(self.n))
        size = self.n // self.k
        self.block_size = size
        self.comm2node = [nodes[i * size:(i + 1) * size] for i in range(self.k)]
        self.node2comm = [nid // size for nid in range(self.n)]

    def generate_edges(self):
        # TODO: dedup edges
        us = []
        vs = []
        # generate intra-comm edges
        for i in range(self.k):
            sp_mat = sp.random(self.block_size, self.block_size, density=self.a,
                               random_state=self.rng, data_rvs=lambda l: np.ones(l))
            u = sp_mat.row + i * self.block_size
            v = sp_mat.col + i * self.block_size
            us.append(u)
            vs.append(v)
        # generate inter-comm edges
        for i in range(self.k):
            for j in range(self.k):
                if i == j:
                    continue
                sp_mat = sp.random(self.block_size, self.block_size, density=self.b,
                                   random_state=self.rng, data_rvs=lambda l: np.ones(l))
                u = sp_mat.row + i * self.block_size
                v = sp_mat.col + j * self.block_size
                us.append(u)
                vs.append(v)
        us = np.hstack(us)
        vs = np.hstack(vs)
        self.sp_mat = sp.coo_matrix((np.ones(us.shape[0]), (us, vs)), shape=(self.n, self.n))

    @property
    def graph(self):
        if self._graph is None:
            self._graph = nx.from_scipy_sparse_matrix(self.sp_mat, create_using=nx.DiGraph())
        return self._graph

    def plot(self):
        x = self.sp_mat.row
        y = self.sp_mat.col
        plt.scatter(x, y, s=0.5, marker='.', c='k')
        plt.savefig('ssbm-%d-%d.pdf' % (self.n, self.k))
        plt.clf()
        # plot out degree distribution
        out_degree = [d for _, d in self.graph.out_degree().items()]
        plt.hist(out_degree, 100, normed=True)
        plt.savefig('ssbm-%d-%d_out_degree.pdf' % (self.n, self.k))
        plt.clf()


if __name__ == '__main__':
    n = 1000
    k = 10
    ssbm = SSBM(n, k, regime='mixed', a=4, b=1)
    ssbm.generate()
    g = ssbm.graph
    print('#nodes:', g.number_of_nodes())
    print('#edges:', g.number_of_edges())
    #ssbm.plot()
    #lg = nx.line_graph(g)
    # plot degree distribution
    #degree = [d for _, d in lg.degree().items()]
    #plt.hist(degree, 100, normed=True)
    #plt.savefig('lg<ssbm-%d-%d>_degree.pdf' % (n, k))
    #plt.clf()
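
For reference, in the 'constant' regime the deleted class decides what to print from the ratio snr = (a - b)^2 / (k * (a + (k + 1) * b)). A quick standalone check of that expression with the class defaults a=10, b=2 and k=2 communities, purely to illustrate the arithmetic:

a, b, k = 10.0, 2.0, 2
snr = (a - b) ** 2 / (k * (a + (k + 1) * b))  # same expression as SSBM.__init__
print(snr)  # 64 / 32 = 2.0 > 1, so the constructor would report "possible detection"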
examples/pytorch/line_graph/train.py  (view file @ cc372e37)

 """
 ipython3 train.py -- --gpu -1 --n-classes 2 --n-iterations 1000 --n-layers 30 --n-nodes 1000 --n-features 2 --radius 3
 """
 from __future__ import division
 import argparse
 from itertools import permutations
 import networkx as nx
 import torch as th
 import torch.nn.functional as F
 import torch.optim as optim
+from torch.utils.data import DataLoader
 import dgl
+from dgl.data import SBMMixture
 import gnn
-import sbm
+import utils

 parser = argparse.ArgumentParser()
 parser.add_argument('--batch-size', type=int)
 parser.add_argument('--gpu', type=int)
-parser.add_argument('--n-classes', type=int)
+parser.add_argument('--n-communities', type=int)
 parser.add_argument('--n-features', type=int)
 parser.add_argument('--n-graphs', type=int)
 parser.add_argument('--n-iterations', type=int)
 parser.add_argument('--n-layers', type=int)
 parser.add_argument('--n-nodes', type=int)
 parser.add_argument('--model-path', type=str)
 parser.add_argument('--radius', type=int)
 args = parser.parse_args()

 dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)

-ssbm = sbm.SSBM(args.n_nodes, args.n_classes, 1, 1)
-gg = []
-for i in range(args.n_graphs):
-    ssbm.generate()
-    gg.append(ssbm.graph)
+dataset = SBMMixture(args.n_graphs, args.n_nodes, args.n_communities)
+loader = utils.cycle(DataLoader(dataset, args.batch_size, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True))

-assert args.n_nodes % args.n_classes == 0
-ones = th.ones(int(args.n_nodes / args.n_classes))
-yy = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(args.n_classes))]
+ones = th.ones(args.n_nodes // args.n_communities)
+y_list = [th.cat([th.cat([x * ones for x in p])] * args.batch_size).long().to(dev) for p in permutations(range(args.n_communities))]

-feats = [1] + [args.n_features] * args.n_layers + [args.n_classes]
-model = gnn.GNN(g, feats, args.radius, args.n_classes).to(dev)
+feats = [1] + [args.n_features] * args.n_layers + [args.n_communities]
+model = gnn.GNN(feats, args.radius, args.n_communities).to(dev)
 opt = optim.Adamax(model.parameters(), lr=0.04)

 for i in range(args.n_iterations):
-    y_bar = model()
-    loss = min(F.cross_entropy(y_bar, y) for y in yy)
+    g, lg, deg_g, deg_lg, eid2nid = next(loader)
+    deg_g = deg_g.to(dev)
+    deg_lg = deg_lg.to(dev)
+    eid2nid = eid2nid.to(dev)
+    y_bar = model(g, lg, deg_g, deg_lg, eid2nid)
+    loss = min(F.cross_entropy(y_bar, y) for y in y_list)
     opt.zero_grad()
     loss.backward()
     opt.step()
-    print('[iteration %d]loss %f' % (i, loss))
+    placeholder = '0' * (len(str(args.n_iterations)) - len(str(i)))
+    print('[iteration %s%d]loss %f' % (placeholder, i, loss))

 th.save(model.state_dict(), args.model_path)
examples/pytorch/line_graph/utils.py  (new file, 0 → 100644, view file @ cc372e37)

def cycle(loader):
    while True:
        for x in loader:
            yield x
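
`cycle` turns any finite iterable (in train.py, a `DataLoader`) into an endless stream so the training loop can call `next(loader)` once per iteration. A tiny standalone illustration, with a plain list standing in for the DataLoader:

def cycle(loader):
    while True:
        for x in loader:
            yield x

batches = cycle([1, 2, 3])
print([next(batches) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]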
python/dgl/backend/pytorch.py  (view file @ cc372e37)

 from __future__ import absolute_import

 import ctypes
+import scipy as sp
 import torch as th

 from .._ffi.base import _LIB, check_call, c_array

@@ -27,6 +28,9 @@ tensor = th.tensor
 sparse_tensor = th.sparse.FloatTensor
 sum = th.sum
 max = th.max
+abs = th.abs
+all = lambda x: x.byte().all()
+stack = th.stack

 def astype(a, ty):
     return a.type(ty)

@@ -37,8 +41,25 @@ def asnumpy(a):
 def from_numpy(np_data):
     return th.from_numpy(np_data)

-def pack(tensors):
-    return th.cat(tensors)
+def from_scipy_sparse(x):
+    x_coo = x.tocoo()
+    row = th.LongTensor(x_coo.row)
+    col = th.LongTensor(x_coo.col)
+    idx = th.stack([row, col])
+    dat = th.FloatTensor(x_coo.data)
+    return th.sparse.FloatTensor(idx, dat, x_coo.shape)
+
+def to_scipy_sparse(x):
+    x_cpu = x.cpu()
+    idx = x._indices()
+    row, col = idx.chunk(2, 0)
+    row = row.squeeze(0).numpy()
+    col = col.squeeze(0).numpy()
+    dat = x._values().numpy()
+    return sp.sparse.coo_matrix((dat, (row, col)), shape=x.shape)
+
+def pack(tensors, dim=0):
+    return th.cat(tensors, dim)

 def unpack(x, indices_or_sections=1):
     return th.split(x, indices_or_sections)

@@ -49,6 +70,9 @@ def shape(x):
 def dtype(x):
     return x.dtype

+def item(x):
+    return x.item()
+
 unique = th.unique

 def gather_row(data, row_index):
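
The two new conversion helpers are inverses of each other up to COO ordering. A small round-trip sketch of how they might be exercised, assuming `dgl.backend.pytorch` is importable at this commit; the matrix itself is illustrative:

import numpy as np
import scipy.sparse
from dgl.backend.pytorch import from_scipy_sparse, to_scipy_sparse

a = scipy.sparse.random(5, 5, density=0.3, format='coo', data_rvs=lambda n: np.ones(n))
t = from_scipy_sparse(a)      # torch.sparse.FloatTensor with the same (row, col, data)
b = to_scipy_sparse(t)        # back to a scipy COO matrix
assert abs(a - b).max() == 0  # identical entries after the round trip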
python/dgl/data/__init__.py  (view file @ cc372e37)

@@ -4,6 +4,7 @@ from __future__ import absolute_import
 from . import citation_graph as citegrh
 from .tree import *
 from .utils import *
+from .sbm import SBMMixture

 def register_data_args(parser):
     parser.add_argument("--dataset", type=str, required=True,
python/dgl/data/sbm.py  (new file, 0 → 100644, view file @ cc372e37)

import math
import os
import pickle

import numpy as np
import numpy.random as npr
import scipy as sp
import networkx as nx
from torch.utils.data import Dataset

from .. import backend as F
from ..batch import batch
from ..graph import DGLGraph
from ..utils import Index


def sbm(n_blocks, block_size, p, q, rng=None):
    """ (Symmetric) Stochastic Block Model

    Parameters
    ----------
    n_blocks : number of blocks
    block_size : block size
    p : probability for intra-community edge
    q : probability for inter-community edge
    """
    n = n_blocks * block_size
    p /= n
    q /= n
    rng = np.random.RandomState() if rng is None else rng
    rows = []
    cols = []
    for i in range(n_blocks):
        for j in range(i, n_blocks):
            density = p if i == j else q
            block = sp.sparse.random(block_size, block_size, density,
                                     random_state=rng, data_rvs=lambda n: np.ones(n))
            rows.append(block.row + i * block_size)
            cols.append(block.col + j * block_size)
    rows = np.hstack(rows)
    cols = np.hstack(cols)
    a = sp.sparse.coo_matrix((np.ones(rows.shape[0]), (rows, cols)), shape=(n, n))
    adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
    return adj


class SBMMixture(Dataset):
    def __init__(self, n_graphs, n_nodes, n_communities, k=2, avg_deg=3, p='Appendix C', rng=None):
        """ Symmetric Stochastic Block Model Mixture

        n_graphs : number of graphs
        n_nodes : number of nodes
        n_communities : number of communities
        k : multiplier, optional
        avg_deg : average degree, optional
        p : random density generator, optional
        rng : random number generator, optional
        """
        super(SBMMixture, self).__init__()
        self._n_nodes = n_nodes
        assert n_nodes % n_communities == 0
        block_size = n_nodes // n_communities
        if type(p) is str:
            p = {'Appendix C': self._appendix_c}[p]
        self._k = k
        self._avg_deg = avg_deg
        self._gs = [DGLGraph() for i in range(n_graphs)]
        adjs = [sbm(n_communities, block_size, *p()) for i in range(n_graphs)]
        for g, adj in zip(self._gs, adjs):
            g.from_scipy_sparse_matrix(adj)
        self._lgs = [g.line_graph() for g in self._gs]
        in_degrees = lambda g: g.in_degrees(Index(F.arange(g.number_of_nodes(), dtype=F.int64))).unsqueeze(1).float()
        self._g_degs = [in_degrees(g) for g in self._gs]
        self._lg_degs = [in_degrees(lg) for lg in self._lgs]
        self._eid2nids = list(zip(*[g.edges(sorted=True) for g in self._gs]))[0]

    def __len__(self):
        return len(self._gs)

    def __getitem__(self, idx):
        return self._gs[idx], self._lgs[idx], \
               self._g_degs[idx], self._lg_degs[idx], self._eid2nids[idx]

    def _appendix_c(self):
        q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
        p = self._k * self._avg_deg - q
        return p, q

    def collate_fn(self, x):
        g, lg, deg_g, deg_lg, eid2nid = zip(*x)
        g_batch = batch(g)
        lg_batch = batch(lg)
        degg_batch = F.pack(deg_g)
        deglg_batch = F.pack(deg_lg)
        eid2nid_batch = F.pack([x + i * self._n_nodes for i, x in enumerate(eid2nid)])
        return g_batch, lg_batch, degg_batch, deglg_batch, eid2nid_batch
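
A sketch of how the new dataset is meant to be consumed, mirroring the updated train.py above; the concrete sizes here are illustrative:

from torch.utils.data import DataLoader
from dgl.data import SBMMixture

dataset = SBMMixture(n_graphs=16, n_nodes=1000, n_communities=2)
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=dataset.collate_fn, drop_last=True)
# each batch: (batched graph, batched line graph, node degrees, line-graph degrees,
# edge-id -> node-id map); collate_fn offsets eid2nid by i * n_nodes per graph
g, lg, deg_g, deg_lg, eid2nid = next(iter(loader))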
python/dgl/graph.py  (view file @ cc372e37)

@@ -3,6 +3,7 @@
 from __future__ import absolute_import

 import networkx as nx
+import scipy as sp
 import dgl
 from .base import ALL, is_all, __MSG__, __REPR__

@@ -470,6 +471,17 @@ class DGLGraph(object):
         for attr in edge_attrs:
             self._edge_frame[attr] = _batcher(attr_dict[attr])

+    def from_scipy_sparse_matrix(self, a):
+        """Convert from scipy sparse matrix.
+
+        Parameters
+        ----------
+        a : scipy sparse matrix
+        """
+        self.clear()
+        self._graph.from_scipy_sparse_matrix(a)
+        self._msg_graph.add_nodes(self._graph.number_of_nodes())
+
     def node_attr_schemes(self):
         """Return the node attribute schemes.

@@ -1351,6 +1363,36 @@ class DGLGraph(object):
                                  self._edge_frame.num_rows,
                                  reduce_func)

+    def adjacency_matrix(self):
+        """Return the adjacency matrix representation of this graph.
+
+        Returns
+        -------
+        utils.CtxCachedObject
+            An object that returns tensor given context.
+        """
+        return self._graph.adjacency_matrix()
+
+    def incidence_matrix(self):
+        """Return the incidence matrix representation of this graph.
+
+        Returns
+        -------
+        utils.CtxCachedObject
+            An object that returns tensor given context.
+        """
+        return self._graph.incidence_matrix()
+
+    def line_graph(self):
+        """Return the line graph of this graph.
+
+        Returns
+        -------
+        DGLGraph
+            The line graph of this graph.
+        """
+        return DGLGraph(self._graph.line_graph())
+
 def _get_repr(attr_dict):
     if len(attr_dict) == 1 and __REPR__ in attr_dict:
         return attr_dict[__REPR__]
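
A sketch of the DGLGraph surface added in this hunk, following the same pattern as the new test at the bottom of the commit; the graph size and density are illustrative:

import numpy as np
import scipy.sparse as sps
import dgl

# random symmetric 0/1 adjacency
a = sps.random(5, 5, density=0.4, data_rvs=lambda n: np.ones(n))
adj = sps.triu(a) + sps.triu(a, 1).transpose()

g = dgl.DGLGraph()
g.from_scipy_sparse_matrix(adj)
lg = g.line_graph()  # a DGLGraph whose nodes are the edges of g
print(g.number_of_edges(), lg.number_of_nodes())  # equal by construction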
python/dgl/graph_index.py  (view file @ cc372e37)

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 import ctypes
 import numpy as np
 import networkx as nx
+import scipy as sp

 from ._ffi.base import c_array
 from ._ffi.function import _init_api

@@ -407,6 +408,34 @@ class GraphIndex(object):
             self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
         return self._cache['adj']

+    def incidence_matrix(self):
+        """Return the incidence matrix representation of this graph.
+
+        Returns
+        -------
+        utils.CtxCachedObject
+            An object that returns tensor given context.
+        """
+        # TODO(gaiyu): DiGraph
+        if not 'inc' in self._cache:
+            src, dst, _ = self.edges(sorted=True)
+            src = src.tousertensor()
+            dst = dst.tousertensor()
+            m = self.number_of_edges()
+            eid = F.arange(m, dtype=F.int64)
+            row = F.pack([src, dst])
+            col = F.pack([eid, eid])
+            idx = F.stack([row, col])
+            x = F.ones((m,))
+            x[src == dst] = 0
+            dat = F.pack([x, x])
+            n = self.number_of_nodes()
+            mat = F.sparse_tensor(idx, dat, [n, m])
+            self._cache['inc'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
+        return self._cache['inc']
+
     def to_networkx(self):
         """Convert to networkx graph.

@@ -459,6 +488,36 @@ class GraphIndex(object):
         dst = utils.toindex(dst)
         self.add_edges(src, dst)

+    def from_scipy_sparse_matrix(self, adj):
+        """Convert from scipy sparse matrix.
+
+        Parameters
+        ----------
+        adj : scipy sparse matrix
+        """
+        self.clear()
+        self.add_nodes(adj.shape[0])
+        adj_coo = adj.tocoo()
+        src = utils.toindex(adj_coo.row)
+        dst = utils.toindex(adj_coo.col)
+        self.add_edges(src, dst)
+
+    def line_graph(self):
+        """Return the line graph of this graph.
+
+        Returns
+        -------
+        GraphIndex
+            The line graph of this graph.
+        """
+        m = self.number_of_edges()
+        ctx = F.get_context(F.ones(1))
+        # TODO(gaiyu):
+        inc = F.to_scipy_sparse(self.incidence_matrix().get(ctx))
+        adj = inc.transpose().dot(inc) - 2 * sp.sparse.eye(m)
+        lg = create_graph_index()
+        lg.from_scipy_sparse_matrix(adj)
+        return lg
+
 def disjoint_union(graphs):
     """Return a disjoint union of the input graphs.
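
The line-graph construction leans on the unsigned incidence matrix B built just above (B[u, e] = B[v, e] = 1 for edge e = (u, v), zeroed on self-loops): B^T B counts shared endpoints between pairs of edges, so B^T B - 2I is the adjacency matrix of the line graph. A small standalone check of that identity with scipy, using a 3-edge path graph as the illustrative input:

import numpy as np
import scipy.sparse as sps

# unsigned incidence matrix of the path 0-1-2-3 with edges e0=(0,1), e1=(1,2), e2=(2,3)
B = sps.coo_matrix(np.array([[1, 0, 0],
                             [1, 1, 0],
                             [0, 1, 1],
                             [0, 0, 1]], dtype=float))
m = B.shape[1]
lg_adj = B.transpose().dot(B) - 2 * sps.eye(m)  # same formula as GraphIndex.line_graph
print(lg_adj.toarray())
# [[0. 1. 0.]
#  [1. 0. 1.]
#  [0. 1. 0.]]  -> e0~e1 and e1~e2: the line graph of a path is again a path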
tests/test_adj_and_inc.py  (new file, 0 → 100644, view file @ cc372e37)

import dgl
import dgl.backend as F
import networkx as nx
import numpy as np
import scipy as sp

N = 5
a = sp.sparse.random(N, N, 1 / N, data_rvs=lambda n: np.ones(n))
b = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()

g_nx = nx.from_scipy_sparse_matrix(b, create_using=nx.DiGraph())
g_dgl = dgl.DGLGraph()
g_dgl.from_scipy_sparse_matrix(b)

h_nx = g_dgl.to_networkx()
g_nodes = set(g_nx.nodes)
h_nodes = set(h_nx.nodes)
assert h_nodes.issubset(g_nodes)
assert all(g_nx.in_degree(x) == g_nx.out_degree(x) == 0 for x in g_nodes.difference(h_nodes))
assert g_nx.edges == h_nx.edges

nx_adj = nx.adjacency_matrix(g_nx)
nx_inc = nx.incidence_matrix(g_nx, edgelist=sorted(g_nx.edges()))

ctx = F.get_context(F.ones((1,)))
dgl_adj = F.to_scipy_sparse(g_dgl.adjacency_matrix().get(ctx)).transpose()
dgl_inc = F.to_scipy_sparse(g_dgl.incidence_matrix().get(ctx))

assert abs(nx_adj - dgl_adj).max() == 0
assert abs(nx_inc - dgl_inc).max() == 0