Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
eafcb7e7
Commit
eafcb7e7
authored
Dec 05, 2018
by
Minjie Wang
Committed by
Da Zheng
Dec 05, 2018
Browse files
[Bugfix][MXNet] Fix edge order and builtin max bug in mx (#247)
* Fix edge order and builtin max bug in mx * fix as requested
parent
71fa26ac
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
223 additions
and
87 deletions
+223
-87
python/dgl/backend/backend.py
python/dgl/backend/backend.py
+3
-0
python/dgl/backend/mxnet/tensor.py
python/dgl/backend/mxnet/tensor.py
+14
-3
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+3
-1
python/dgl/function/message.py
python/dgl/function/message.py
+9
-7
python/dgl/graph.py
python/dgl/graph.py
+2
-2
python/dgl/graph_index.py
python/dgl/graph_index.py
+14
-28
python/dgl/immutable_graph_index.py
python/dgl/immutable_graph_index.py
+14
-4
python/dgl/runtime/ir/executor.py
python/dgl/runtime/ir/executor.py
+2
-1
python/dgl/runtime/scheduler.py
python/dgl/runtime/scheduler.py
+2
-2
python/dgl/runtime/spmv.py
python/dgl/runtime/spmv.py
+38
-15
python/dgl/utils.py
python/dgl/utils.py
+12
-22
tests/mxnet/test_graph_index.py
tests/mxnet/test_graph_index.py
+2
-2
tests/mxnet/test_specialization.py
tests/mxnet/test_specialization.py
+108
-0
No files found.
python/dgl/backend/backend.py
View file @
eafcb7e7
...
...
@@ -105,6 +105,9 @@ def sparse_matrix(data, index, shape, force_format=False):
SparseMatrix
The framework-specific sparse matrix. It can be stored in any format
unless force_format is True.
Tensor
The data convert index due to sparse format change.
None if no conversion is needed.
"""
pass
...
...
python/dgl/backend/mxnet/tensor.py
View file @
eafcb7e7
...
...
@@ -27,13 +27,24 @@ def sparse_matrix(data, index, shape, force_format=False):
raise
TypeError
(
'MXNet backend only supports CSR format,'
' but COO format is forced.'
)
coord
=
index
[
1
]
return
nd
.
sparse
.
csr_matrix
((
data
,
(
coord
[
0
],
coord
[
1
])),
# generate convert idx
# FIXME: cannot use int64
tmp_data
=
nd
.
arange
(
len
(
coord
[
0
]),
dtype
=
data
.
dtype
,
ctx
=
coord
[
0
].
context
)
tmp_spmat
=
nd
.
sparse
.
csr_matrix
((
tmp_data
,
(
coord
[
0
],
coord
[
1
])),
tuple
(
shape
),
ctx
=
data
.
context
)
convert_idx
=
nd
.
cast
(
tmp_spmat
.
data
,
dtype
=
'int64'
)
# shuffle the data
data
=
data
[
convert_idx
]
spmat
=
nd
.
sparse
.
csr_matrix
((
data
,
tmp_spmat
.
indices
,
tmp_spmat
.
indptr
),
tuple
(
shape
),
ctx
=
data
.
context
)
return
spmat
,
convert_idx
elif
fmt
==
'csr'
:
indices
=
index
[
1
]
indptr
=
index
[
2
]
return
nd
.
sparse
.
csr_matrix
((
data
,
indices
,
indptr
),
spmat
=
nd
.
sparse
.
csr_matrix
((
data
,
indices
,
indptr
),
tuple
(
shape
),
ctx
=
data
.
context
)
# No conversion is required.
return
spmat
,
None
else
:
raise
TypeError
(
'Invalid format: %s.'
%
fmt
)
...
...
@@ -73,7 +84,7 @@ def mean(input, dim):
return
nd
.
mean
(
input
,
axis
=
dim
)
def
max
(
input
,
dim
):
return
nd
.
max
(
input
,
axis
=
dim
)
.
asnumpy
()[
0
]
return
nd
.
max
(
input
,
axis
=
dim
)
def
cat
(
seq
,
dim
):
return
nd
.
concat
(
*
seq
,
dim
=
dim
)
...
...
python/dgl/backend/pytorch/tensor.py
View file @
eafcb7e7
...
...
@@ -24,7 +24,9 @@ def sparse_matrix(data, index, shape, force_format=False):
if
fmt
!=
'coo'
:
raise
TypeError
(
'Pytorch backend only supports COO format. But got %s.'
%
fmt
)
# NOTE: use _sparse_coo_tensor_unsafe to avoid unnecessary boundary check
return
th
.
_sparse_coo_tensor_unsafe
(
index
[
1
],
data
,
shape
)
spmat
=
th
.
_sparse_coo_tensor_unsafe
(
index
[
1
],
data
,
shape
)
# No conversion is required.
return
spmat
,
None
def
sparse_matrix_indices
(
spmat
):
return
(
'coo'
,
spmat
.
_indices
())
...
...
python/dgl/function/message.py
View file @
eafcb7e7
...
...
@@ -53,14 +53,16 @@ class SrcMulEdgeMessageFunction(MessageFunction):
return
_is_spmv_supported_edge_feat
(
g
,
self
.
edge_field
)
def
__call__
(
self
,
edges
):
s
rc_
data
=
edges
.
src
[
self
.
src_field
]
sdata
=
edges
.
src
[
self
.
src_field
]
edata
=
edges
.
data
[
self
.
edge_field
]
if
F
.
ndim
(
edata
)
==
1
:
# edge feature is a scalar, unsqueeze dims of len 1
src_dim
=
F
.
ndim
(
src_data
)
new_eshape
=
(
F
.
shape
(
edata
)[
0
],)
+
(
1
,)
*
(
src_dim
-
1
)
edata
=
F
.
reshape
(
edata
,
new_eshape
)
ret
=
self
.
mul_op
(
src_data
,
edata
)
# Due to the different broadcasting semantics of different backends,
# we need to broadcast the sdata and edata to be of the same rank.
rank
=
max
(
F
.
ndim
(
sdata
),
F
.
ndim
(
edata
))
sshape
=
F
.
shape
(
sdata
)
eshape
=
F
.
shape
(
edata
)
sdata
=
F
.
reshape
(
sdata
,
sshape
+
(
1
,)
*
(
rank
-
F
.
ndim
(
sdata
)))
edata
=
F
.
reshape
(
edata
,
eshape
+
(
1
,)
*
(
rank
-
F
.
ndim
(
edata
)))
ret
=
self
.
mul_op
(
sdata
,
edata
)
return
{
self
.
out_field
:
ret
}
@
property
...
...
python/dgl/graph.py
View file @
eafcb7e7
...
...
@@ -2703,7 +2703,7 @@ class DGLGraph(object):
SparseTensor
The adjacency matrix.
"""
return
self
.
_graph
.
adjacency_matrix
(
transpose
,
ctx
)
return
self
.
_graph
.
adjacency_matrix
(
transpose
,
ctx
)
[
0
]
def
incidence_matrix
(
self
,
type
,
ctx
=
F
.
cpu
()):
"""Return the incidence matrix representation of this graph.
...
...
@@ -2745,7 +2745,7 @@ class DGLGraph(object):
SparseTensor
The incidence matrix.
"""
return
self
.
_graph
.
incidence_matrix
(
type
,
ctx
)
return
self
.
_graph
.
incidence_matrix
(
type
,
ctx
)
[
0
]
def
line_graph
(
self
,
backtracking
=
True
,
shared
=
False
):
"""Return the line graph of this graph.
...
...
python/dgl/graph_index.py
View file @
eafcb7e7
...
...
@@ -484,28 +484,6 @@ class GraphIndex(object):
induced_nodes
=
utils
.
toindex
(
rst
(
1
))
return
SubgraphIndex
(
rst
(
0
),
self
,
induced_nodes
,
e
)
def
adjacency_matrix_indices_and_shape
(
self
,
transpose
=
False
):
"""Return the indices and dense shape of adjacency matrix representation of
this graph.
utils.CtxCachedObject
An object that returns indices tensor given context.
tuple
Dense shape of the adjacency matrix
"""
if
not
'adj_ind_shape'
in
self
.
_cache
:
src
,
dst
,
_
=
self
.
edges
(
sorted
=
False
)
src
=
F
.
unsqueeze
(
src
.
tousertensor
(),
0
)
dst
=
F
.
unsqueeze
(
dst
.
tousertensor
(),
0
)
n
=
self
.
number_of_nodes
()
if
transpose
:
idx
=
F
.
cat
([
src
,
dst
],
dim
=
0
)
else
:
idx
=
F
.
cat
([
dst
,
src
],
dim
=
0
)
cached_idx
=
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
idx
,
ctx
))
self
.
_cache
[
'adj_ind_shape'
]
=
(
cached_idx
,
(
n
,
n
))
return
self
.
_cache
[
'adj_ind_shape'
]
def
adjacency_matrix
(
self
,
transpose
,
ctx
):
"""Return the adjacency matrix representation of this graph.
...
...
@@ -526,6 +504,9 @@ class GraphIndex(object):
-------
SparseTensor
The adjacency matrix.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
if
not
isinstance
(
transpose
,
bool
):
raise
DGLError
(
'Expect bool value for "transpose" arg,'
...
...
@@ -543,8 +524,9 @@ class GraphIndex(object):
m
=
self
.
number_of_edges
()
# FIXME(minjie): data type
dat
=
F
.
ones
((
m
,),
dtype
=
F
.
float32
,
ctx
=
ctx
)
adj
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
n
))
return
adj
adj
,
shuffle_idx
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
n
))
shuffle_idx
=
utils
.
toindex
(
shuffle_idx
)
if
shuffle_idx
is
not
None
else
None
return
adj
,
shuffle_idx
def
incidence_matrix
(
self
,
type
,
ctx
):
"""Return the incidence matrix representation of this graph.
...
...
@@ -577,6 +559,9 @@ class GraphIndex(object):
-------
SparseTensor
The incidence matrix.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
src
,
dst
,
eid
=
self
.
edges
(
sorted
=
False
)
src
=
src
.
tousertensor
(
ctx
)
# the index of the ctx will be cached
...
...
@@ -590,14 +575,14 @@ class GraphIndex(object):
idx
=
F
.
cat
([
row
,
col
],
dim
=
0
)
# FIXME(minjie): data type
dat
=
F
.
ones
((
m
,),
dtype
=
F
.
float32
,
ctx
=
ctx
)
inc
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
inc
,
shuffle_idx
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
elif
type
==
'out'
:
row
=
F
.
unsqueeze
(
src
,
0
)
col
=
F
.
unsqueeze
(
eid
,
0
)
idx
=
F
.
cat
([
row
,
col
],
dim
=
0
)
# FIXME(minjie): data type
dat
=
F
.
ones
((
m
,),
dtype
=
F
.
float32
,
ctx
=
ctx
)
inc
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
inc
,
shuffle_idx
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
elif
type
==
'both'
:
# create index
row
=
F
.
unsqueeze
(
F
.
cat
([
src
,
dst
],
dim
=
0
),
0
)
...
...
@@ -611,10 +596,11 @@ class GraphIndex(object):
x
[
diagonal
]
=
0
y
[
diagonal
]
=
0
dat
=
F
.
cat
([
x
,
y
],
dim
=
0
)
inc
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
inc
,
shuffle_idx
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
else
:
raise
DGLError
(
'Invalid incidence matrix type: %s'
%
str
(
type
))
return
inc
shuffle_idx
=
utils
.
toindex
(
shuffle_idx
)
if
shuffle_idx
is
not
None
else
None
return
inc
,
shuffle_idx
def
to_networkx
(
self
):
"""Convert to networkx graph.
...
...
python/dgl/immutable_graph_index.py
View file @
eafcb7e7
...
...
@@ -8,7 +8,7 @@ import scipy.sparse as sp
from
._ffi.function
import
_init_api
from
.
import
backend
as
F
from
.
import
utils
from
.base
import
ALL
,
is_all
from
.base
import
ALL
,
is_all
,
dgl_warning
class
ImmutableGraphIndex
(
object
):
"""Graph index object on immutable graphs.
...
...
@@ -473,11 +473,16 @@ class ImmutableGraphIndex(object):
-------
utils.CtxCachedObject
An object that returns tensor given context.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
def
get_adj
(
ctx
):
new_mat
=
self
.
_sparse
.
adjacency_matrix
(
transpose
)
return
F
.
copy_to
(
new_mat
,
ctx
)
return
self
.
_sparse
.
adjacency_matrix
(
transpose
,
ctx
)
# FIXME(minjie): calculate the shuffle index
dgl_warning
(
'Shuffle index is not correctly computed. SPMV with edge feature might fail!!'
)
return
self
.
_sparse
.
adjacency_matrix
(
transpose
,
ctx
),
None
def
incidence_matrix
(
self
,
type
,
ctx
):
"""Return the incidence matrix representation of this graph.
...
...
@@ -510,6 +515,9 @@ class ImmutableGraphIndex(object):
-------
SparseTensor
The incidence matrix.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
raise
Exception
(
'immutable graph doesn
\'
t support incidence_matrix for now.'
)
...
...
@@ -540,9 +548,11 @@ class ImmutableGraphIndex(object):
nx_graph : networkx.DiGraph
The nx graph
"""
assert
isinstance
(
nx_graph
,
nx
.
DiGraph
),
"The input graph has to be a NetworkX DiGraph."
if
not
isinstance
(
nx_graph
,
nx
.
DiGraph
):
nx_graph
=
nx
.
DiGraph
(
nx_graph
)
# We store edge Ids as an edge attribute.
out_mat
=
nx
.
convert_matrix
.
to_scipy_sparse_matrix
(
nx_graph
,
format
=
'coo'
)
nodelist
=
list
(
range
(
nx_graph
.
number_of_nodes
()))
out_mat
=
nx
.
convert_matrix
.
to_scipy_sparse_matrix
(
nx_graph
,
nodelist
=
nodelist
,
format
=
'coo'
)
self
.
_sparse
.
from_coo_matrix
(
out_mat
)
def
from_scipy_sparse_matrix
(
self
,
adj
):
...
...
python/dgl/runtime/ir/executor.py
View file @
eafcb7e7
...
...
@@ -307,7 +307,8 @@ class SPMVWithDataExecutor(Executor):
spA
=
spA_ctxobj
.
get
(
ctx
)
spidx
=
F
.
sparse_matrix_indices
(
spA
)
shape
=
F
.
shape
(
spA
)
spA
=
F
.
sparse_matrix
(
A_data
,
spidx
,
shape
)
# shuffle index is not used
spA
,
_
=
F
.
sparse_matrix
(
A_data
,
spidx
,
shape
)
if
F
.
ndim
(
B
)
==
1
:
# B is a vector, append a (1,) dim at the end
...
...
python/dgl/runtime/scheduler.py
View file @
eafcb7e7
...
...
@@ -489,9 +489,9 @@ def _gen_send_reduce(
uv_getter : callable
A function that returns a pair of var.IDX (u, v) for the triggered edges.
adj_creator : callable
A function that returns
var.SPMAT that represents the adjmat
.
A function that returns
the adjmat and the shuffle index
.
inc_creator : callable
A function that returns
var.SPMAT that represents the incmat
.
A function that returns
the incmat and the shuffle index
.
Returns
-------
...
...
python/dgl/runtime/spmv.py
View file @
eafcb7e7
...
...
@@ -80,9 +80,9 @@ def analyze_e2v_spmv(graph, rfunc):
rfunc_left
.
append
(
rfn
)
return
spmv_rfunc
,
rfunc_left
def
gen_v2v_spmv_schedule
(
adj
mat
,
spmv_pairs
,
nf
,
ef
,
eid
,
out
):
def
gen_v2v_spmv_schedule
(
adj
,
spmv_pairs
,
nf
,
ef
,
eid
,
out
):
"""
adj
mat
: sparse matrix
adj :
tuple (
sparse matrix
, utils.Index)
spmv_pairs : list of pair
nf : var.Var
input node features
...
...
@@ -93,9 +93,12 @@ def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out):
out : var.Var
output node features
"""
adjmat
,
shuffle_idx
=
adj
adj_var
=
var
.
SPMAT
(
adjmat
)
if
shuffle_idx
is
not
None
:
new_eid
=
utils
.
reorder_index
(
eid
.
data
,
shuffle_idx
)
eid
=
var
.
IDX
(
new_eid
)
for
mfn
,
rfn
in
spmv_pairs
:
#print('v2v mfn=%s rfn=%s' % (mfn.name, rfn.name))
if
mfn
.
use_edge_feature
:
ftedge
=
ir
.
READ
(
ef
,
eid
,
var
.
STR
(
mfn
.
edge_field
))
ftsrc
=
ir
.
READ_COL
(
nf
,
var
.
STR
(
mfn
.
src_field
))
...
...
@@ -108,15 +111,15 @@ def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out):
def
gen_e2v_spmv_schedule
(
inc
,
spmv_rfunc
,
mf
,
out
):
"""
inc : sparse matrix
The incidence matrix
inc : tuple (sparse matrix, utils.Index)
spmv_rfunc : list of builtin reducers
mf : var.Var
Variable for message frame.
out : var.Var
Variable for output reduced features.
"""
inc_var
=
var
.
SPMAT
(
inc
)
incmat
,
_
=
inc
inc_var
=
var
.
SPMAT
(
incmat
)
for
rfn
in
spmv_rfunc
:
ftmsg
=
ir
.
READ_COL
(
mf
,
var
.
STR
(
rfn
.
msg_field
))
ftdst
=
ir
.
SPMV
(
inc_var
,
ftmsg
)
...
...
@@ -134,10 +137,14 @@ def build_adj_matrix_graph(graph):
-------
utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
return
utils
.
CtxCachedObject
(
lambda
ctx
:
graph
.
adjacency_matrix
(
ctx
=
ctx
))
adjmat
,
shuffle_idx
=
graph
.
_graph
.
adjacency_matrix
(
transpose
=
False
,
ctx
=
F
.
cpu
())
return
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
adjmat
,
ctx
)),
shuffle_idx
def
build_adj_matrix_index_uv
(
graph
,
edges
,
reduce_nodes
):
def
_
build_adj_matrix_index_uv
(
graph
,
edges
,
reduce_nodes
):
"""Build adj matrix index and shape using the given (u, v) edges.
The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
...
...
@@ -198,15 +205,19 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
Returns
-------
utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx.
Get be used to get adjacency matrix and on the provided ctx.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
sp_idx
,
shape
=
build_adj_matrix_index_uv
(
graph
,
edges
,
reduce_nodes
)
sp_idx
,
shape
=
_
build_adj_matrix_index_uv
(
graph
,
edges
,
reduce_nodes
)
u
,
v
=
edges
nnz
=
len
(
u
)
# FIXME(minjie): data type
dat
=
F
.
ones
((
nnz
,),
dtype
=
F
.
float32
,
ctx
=
F
.
cpu
())
mat
=
F
.
sparse_matrix
(
dat
,
sp_idx
,
shape
)
return
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
mat
,
ctx
))
mat
,
shuffle_idx
=
F
.
sparse_matrix
(
dat
,
sp_idx
,
shape
)
shuffle_idx
=
utils
.
toindex
(
shuffle_idx
)
if
shuffle_idx
is
not
None
else
None
return
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
mat
,
ctx
)),
shuffle_idx
def
build_inc_matrix_graph
(
graph
):
"""Build incidence matrix.
...
...
@@ -220,8 +231,13 @@ def build_inc_matrix_graph(graph):
-------
utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
return
utils
.
CtxCachedObject
(
lambda
ctx
:
graph
.
incidence_matrix
(
type
=
'in'
,
ctx
=
ctx
))
incmat
,
_
=
graph
.
_graph
.
incidence_matrix
(
type
=
'in'
,
ctx
=
F
.
cpu
())
# inc mat will not use data tensor so conversion index is not needed
return
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
incmat
,
ctx
)),
None
def
build_inc_matrix_eid
(
m
,
eid
,
dst
,
reduce_nodes
):
"""Build incidence matrix using edge id and edge dst nodes.
...
...
@@ -269,6 +285,9 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
-------
utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
new2old
,
old2new
=
utils
.
build_relabel_map
(
reduce_nodes
,
sorted
=
True
)
dst
=
dst
.
tousertensor
()
...
...
@@ -283,8 +302,9 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
# create dat tensor
nnz
=
len
(
eid
)
dat
=
F
.
ones
((
nnz
,),
dtype
=
F
.
float32
,
ctx
=
F
.
cpu
())
mat
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
return
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
mat
,
ctx
))
mat
,
_
=
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
(
n
,
m
))
# inc mat will not use data tensor so conversion index is not needed
return
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
copy_to
(
mat
,
ctx
)),
None
def
build_inc_matrix_dst
(
dst
,
reduce_nodes
):
"""Build incidence matrix using only edge destinations.
...
...
@@ -318,6 +338,9 @@ def build_inc_matrix_dst(dst, reduce_nodes):
-------
utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx.
utils.Index
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
eid
=
utils
.
toindex
(
F
.
arange
(
0
,
len
(
dst
)))
return
build_inc_matrix_eid
(
len
(
eid
),
eid
,
dst
,
reduce_nodes
)
python/dgl/utils.py
View file @
eafcb7e7
...
...
@@ -237,8 +237,8 @@ def build_relabel_map(x, sorted=False):
unique_x
,
_
=
F
.
sort_1d
(
F
.
unique
(
x
))
else
:
unique_x
=
x
map_len
=
int
(
F
.
max
(
unique_x
,
dim
=
0
))
+
1
old_to_new
=
F
.
zeros
(
map_len
,
dtype
=
F
.
int64
,
ctx
=
F
.
cpu
())
map_len
=
int
(
F
.
asnumpy
(
F
.
max
(
unique_x
,
dim
=
0
))
)
+
1
old_to_new
=
F
.
zeros
(
(
map_len
,
),
dtype
=
F
.
int64
,
ctx
=
F
.
cpu
())
F
.
scatter_row_inplace
(
old_to_new
,
unique_x
,
F
.
arange
(
0
,
len
(
unique_x
)))
return
unique_x
,
old_to_new
...
...
@@ -334,30 +334,20 @@ def reorder(dict_like, index):
new_dict
[
key
]
=
F
.
gather_row
(
val
,
idx_ctx
)
return
new_dict
def
build_coo_sparse_matrix
(
dat
,
row
,
col
,
dense_shape
):
"""
Build coo sparse matrix
def
reorder_index
(
idx
,
order
):
"""
Reorder the idx according to the given order
Parameters
----------
dat: Tensor
Data.
row: Tensor
Row index.
col: Tensor
Column index.
dense_shape: list or tuple of two integer
Dense shape of the sparse matrix
Returns
-------
SparseTensor
The sparse matrix.
idx : utils.Index
The index to be reordered.
order : utils.Index
The order to follow.
"""
nnz
=
len
(
row
)
row
=
F
.
unsqueeze
(
row
,
0
)
col
=
F
.
unsqueeze
(
col
,
0
)
idx
=
F
.
cat
([
row
,
col
],
dim
=
0
)
return
F
.
sparse_matrix
(
dat
,
(
'coo'
,
idx
),
dense_shape
)
idx
=
idx
.
tousertensor
()
order
=
order
.
tousertensor
()
new_idx
=
F
.
gather_row
(
idx
,
order
)
return
toindex
(
new_idx
)
def
is_iterable
(
obj
):
"""Return true if the object is an iterable."""
...
...
tests/mxnet/test_graph_index.py
View file @
eafcb7e7
...
...
@@ -14,8 +14,8 @@ def generate_rand_graph(n):
return
g
,
ig
def
check_graph_equal
(
g1
,
g2
):
adj1
=
g1
.
adjacency_matrix
(
transpose
=
False
,
ctx
=
mx
.
cpu
())
!=
0
adj2
=
g2
.
adjacency_matrix
(
transpose
=
False
,
ctx
=
mx
.
cpu
())
!=
0
adj1
=
g1
.
adjacency_matrix
(
transpose
=
False
,
ctx
=
mx
.
cpu
())
[
0
]
!=
0
adj2
=
g2
.
adjacency_matrix
(
transpose
=
False
,
ctx
=
mx
.
cpu
())
[
0
]
!=
0
assert
mx
.
nd
.
sum
(
adj1
-
adj2
).
asnumpy
()
==
0
def
test_graph_gen
():
...
...
tests/mxnet/test_specialization.py
View file @
eafcb7e7
...
...
@@ -26,6 +26,7 @@ def generate_graph2(n):
arr
=
(
sp
.
sparse
.
random
(
n
,
n
,
density
=
0.1
,
format
=
'coo'
)
!=
0
).
astype
(
np
.
int64
)
g1
=
dgl
.
DGLGraph
(
arr
,
readonly
=
True
)
g2
=
dgl
.
DGLGraph
(
arr
,
readonly
=
True
)
num_nodes
=
g1
.
number_of_nodes
()
g1
.
set_n_repr
({
'f1'
:
mx
.
nd
.
random
.
normal
(
shape
=
(
num_nodes
,)),
'f2'
:
mx
.
nd
.
random
.
normal
(
shape
=
(
num_nodes
,
D
))})
...
...
@@ -308,9 +309,116 @@ def test_send_and_recv_multi_fn():
v2
=
g
.
ndata
[
'v2'
]
assert
np
.
allclose
(
v1
.
asnumpy
(),
v2
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
############################ Copy from torch
D
=
5
def
simple_graph
():
g
=
dgl
.
DGLGraph
()
g
.
add_nodes
(
10
)
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
g
.
set_n_repr
({
'f1'
:
mx
.
nd
.
random
.
normal
(
shape
=
(
10
,)),
'f2'
:
mx
.
nd
.
random
.
normal
(
shape
=
(
10
,
D
))})
weights
=
mx
.
nd
.
random
.
normal
(
shape
=
(
17
,))
g
.
set_e_repr
({
'e1'
:
weights
,
'e2'
:
mx
.
nd
.
expand_dims
(
weights
,
1
)})
return
g
def
test_v2v_update_all_sum
():
def
_test
(
fld
):
def
message_func
(
edges
):
return
{
'm'
:
edges
.
src
[
fld
]}
def
message_func_edge
(
edges
):
if
len
(
edges
.
src
[
fld
].
shape
)
==
1
:
return
{
'm'
:
edges
.
src
[
fld
]
*
edges
.
data
[
'e1'
]}
else
:
return
{
'm'
:
edges
.
src
[
fld
]
*
edges
.
data
[
'e2'
]}
def
reduce_func
(
nodes
):
return
{
fld
:
mx
.
nd
.
sum
(
nodes
.
mailbox
[
'm'
],
axis
=
1
)}
def
apply_func
(
nodes
):
return
{
fld
:
2
*
nodes
.
data
[
fld
]}
g
=
simple_graph
()
# update all
v1
=
g
.
ndata
[
fld
]
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
)
v3
=
g
.
ndata
[
fld
]
assert
np
.
allclose
(
v2
.
asnumpy
(),
v3
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
# update all with edge weights
v1
=
g
.
ndata
[
fld
]
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v3
=
g
.
ndata
[
fld
].
squeeze
()
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
)
v4
=
g
.
ndata
[
fld
]
assert
np
.
allclose
(
v2
.
asnumpy
(),
v3
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
assert
np
.
allclose
(
v3
.
asnumpy
(),
v4
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
# test 1d node features
_test
(
'f1'
)
# test 2d node features
_test
(
'f2'
)
def
test_v2v_update_all_max
():
def
_test
(
fld
):
def
message_func
(
edges
):
return
{
'm'
:
edges
.
src
[
fld
]}
def
message_func_edge
(
edges
):
if
len
(
edges
.
src
[
fld
].
shape
)
==
1
:
return
{
'm'
:
edges
.
src
[
fld
]
*
edges
.
data
[
'e1'
]}
else
:
return
{
'm'
:
edges
.
src
[
fld
]
*
edges
.
data
[
'e2'
]}
def
reduce_func
(
nodes
):
return
{
fld
:
mx
.
nd
.
max
(
nodes
.
mailbox
[
'm'
],
axis
=
1
)}
def
apply_func
(
nodes
):
return
{
fld
:
2
*
nodes
.
data
[
fld
]}
g
=
simple_graph
()
# update all
v1
=
g
.
ndata
[
fld
]
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
fn
.
max
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
)
v3
=
g
.
ndata
[
fld
]
assert
np
.
allclose
(
v2
.
asnumpy
(),
v3
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
# update all with edge weights
v1
=
g
.
ndata
[
fld
]
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
fn
.
max
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
ndata
[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm'
),
fn
.
max
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v3
=
g
.
ndata
[
fld
].
squeeze
()
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
)
v4
=
g
.
ndata
[
fld
]
assert
np
.
allclose
(
v2
.
asnumpy
(),
v3
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
assert
np
.
allclose
(
v3
.
asnumpy
(),
v4
.
asnumpy
(),
rtol
=
1e-05
,
atol
=
1e-05
)
# test 1d node features
_test
(
'f1'
)
# test 2d node features
_test
(
'f2'
)
############################ Copy from torch
if
__name__
==
'__main__'
:
test_update_all
()
test_pull
()
test_send_and_recv
()
test_update_all_multi_fn
()
test_send_and_recv_multi_fn
()
test_v2v_update_all_sum
()
test_v2v_update_all_max
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment