OpenDAS / dgl / Commits / e3bac70b

Commit e3bac70b (unverified), authored Aug 12, 2018 by Minjie Wang; committed via GitHub on Aug 12, 2018.
Parent: ee241699

Spmv partial (#43)

* partial spmv impl and test
* some fix for update edge
Showing 4 changed files with 145 additions and 51 deletions (+145, -51):
  python/dgl/backend/pytorch.py          +5   -4
  python/dgl/graph.py                    +86  -40
  python/dgl/utils.py                    +41  -0
  tests/pytorch/test_specialization.py   +13  -7
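For orientation (this note and the snippet below are not part of the commit): a 'from_src' message function combined with a 'sum' reduce is exactly a sparse matrix product, i.e. multiplying a (dst x src) adjacency matrix with the node-feature matrix sums every destination's incoming source features. That identity is what both the update_all and the new partial update_by_edge fast paths below rely on. A minimal PyTorch sketch with made-up edges and features:

import torch as th

# Toy graph: edges u -> v, one feature row per node (values are made up).
u = th.tensor([0, 0, 1, 2])                        # source ids
v = th.tensor([1, 2, 2, 0])                        # destination ids
x = th.arange(12, dtype=th.float32).reshape(3, 4)  # node features

# (num_dst x num_src) adjacency with a 1 for every edge u -> v.
adj = th.sparse.FloatTensor(th.stack([v, u]), th.ones(len(u)), th.Size([3, 3]))

# 'from_src' message + 'sum' reduce, done in one sparse matmul.
spmv = th.spmm(adj, x)

# Reference: an explicit send/recv loop.
ref = th.zeros_like(x)
for uu, vv in zip(u.tolist(), v.tolist()):
    ref[vv] += x[uu]

assert th.allclose(spmv, ref)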
python/dgl/backend/pytorch.py
@@ -23,6 +23,9 @@ sparse_tensor = th.sparse.FloatTensor
 sum = th.sum
 max = th.max

+def astype(a, ty):
+    return a.type(ty)
+
 def asnumpy(a):
     return a.cpu().numpy()
@@ -50,16 +53,14 @@ def broadcast_to(x, to_array):
     return x + th.zeros_like(to_array)

 nonzero = th.nonzero
-def eq_scalar(x, val):
-    return th.eq(x, float(val))
-
 squeeze = th.squeeze
 unsqueeze = th.unsqueeze
 reshape = th.reshape
+zeros = th.zeros
 ones = th.ones
 spmm = th.spmm
 sort = th.sort
+arange = th.arange

 def to_context(x, ctx):
     if ctx is None:
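The torch aliases this change adds to the backend (astype, arange, and zeros, as reconstructed above) are what the new relabeling helper in python/dgl/utils.py needs. A quick, illustrative check of their semantics in plain torch (not part of the commit):

import torch as th

ids = th.arange(5)                                      # arange
buf = th.zeros(10, dtype=th.int64)                      # zeros
buf[th.tensor([0, 2, 4, 6, 8])] = ids.type(th.int64)    # astype(a, ty) is a.type(ty)
print(buf)  # tensor([0, 0, 1, 0, 2, 0, 3, 0, 4, 0])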
python/dgl/graph.py
@@ -436,20 +436,28 @@ class DGLGraph(DiGraph):
     def _nonbatch_sendto(self, u, v, message_func):
         f_msg = _get_message_func(message_func)
+        if is_all(u) and is_all(v):
+            u, v = self.cached_graph.edges()
         for uu, vv in utils.edge_iter(u, v):
             ret = f_msg(_get_repr(self.nodes[uu]),
                         _get_repr(self.edges[uu, vv]))
             self.edges[uu, vv][__MSG__] = ret

     def _batch_sendto(self, u, v, message_func):
+        f_msg = _get_message_func(message_func)
         if is_all(u) and is_all(v):
             u, v = self.cached_graph.edges()
-        u = utils.convert_to_id_tensor(u)
-        v = utils.convert_to_id_tensor(v)
-        eid = self.cached_graph.get_edge_id(u, v)
-        self.msg_graph.add_edges(u, v)
-        if len(u) != len(v) and len(u) == 1:
-            u = F.broadcast_to(u, v)
-        # call UDF
-        src_reprs = self.get_n_repr(u)
-        edge_reprs = self.get_e_repr_by_id(eid)
+            self.msg_graph.add_edges(u, v)
+            # call UDF
+            src_reprs = self.get_n_repr(u)
+            edge_reprs = self.get_e_repr()
+            msgs = message_func(src_reprs, edge_reprs)
+        else:
+            u = utils.convert_to_id_tensor(u)
+            v = utils.convert_to_id_tensor(v)
+            u, v = utils.edge_broadcasting(u, v)
+            eid = self.cached_graph.get_edge_id(u, v)
+            self.msg_graph.add_edges(u, v)
+            # call UDF
+            src_reprs = self.get_n_repr(u)
+            edge_reprs = self.get_e_repr_by_id(eid)
@@ -490,6 +498,8 @@ class DGLGraph(DiGraph):
             self._nonbatch_update_edge(u, v, edge_func)

     def _nonbatch_update_edge(self, u, v, edge_func):
+        if is_all(u) and is_all(v):
+            u, v = self.cached_graph.edges()
         for uu, vv in utils.edge_iter(u, v):
             ret = edge_func(_get_repr(self.nodes[uu]),
                             _get_repr(self.nodes[vv]),
@@ -497,13 +507,19 @@ class DGLGraph(DiGraph):
             _set_repr(self.edges[uu, vv], ret)

     def _batch_update_edge(self, u, v, edge_func):
-        u = utils.convert_to_id_tensor(u)
-        v = utils.convert_to_id_tensor(v)
-        eid = self.cached_graph.get_edge_id(u, v)
-        if len(u) != len(v) and len(u) == 1:
-            u = F.broadcast_to(u, v)
-        elif len(u) != len(v) and len(v) == 1:
-            v = F.broadcast_to(v, u)
-        # call the UDF
-        src_reprs = self.get_n_repr(u)
-        dst_reprs = self.get_n_repr(v)
+        if is_all(u) and is_all(v):
+            u, v = self.cached_graph.edges()
+            # call the UDF
+            src_reprs = self.get_n_repr(u)
+            dst_reprs = self.get_n_repr(v)
+            edge_reprs = self.get_e_repr()
+            new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
+            self.set_e_repr(new_edge_reprs)
+        else:
+            u = utils.convert_to_id_tensor(u)
+            v = utils.convert_to_id_tensor(v)
+            u, v = utils.edge_broadcasting(u, v)
+            eid = self.cached_graph.get_edge_id(u, v)
+            # call the UDF
+            src_reprs = self.get_n_repr(u)
+            dst_reprs = self.get_n_repr(v)
@@ -566,6 +582,8 @@ class DGLGraph(DiGraph):
     def _nonbatch_recv(self, u, reduce_func, update_func):
         f_reduce = _get_reduce_func(reduce_func)
         f_update = update_func
+        if is_all(u):
+            u = list(range(0, self.number_of_nodes()))
         for i, uu in enumerate(utils.node_iter(u)):
             # reduce phase
             msgs_batch = [self.edges[vv, uu].pop(__MSG__)
@@ -702,6 +720,8 @@ class DGLGraph(DiGraph):
                                  message_func,
                                  reduce_func,
                                  update_func):
+        if is_all(u) and is_all(v):
+            u, v = self.cached_graph.edges()
         self._nonbatch_sendto(u, v, message_func)
         dst = set()
         for uu, vv in utils.edge_iter(u, v):
@@ -714,22 +734,35 @@ class DGLGraph(DiGraph):
                               message_func,
                               reduce_func,
                               update_func):
-        if message_func == 'from_src' and reduce_func == 'sum' \
-                and is_all(u) and is_all(v):
-            # TODO(minjie): use lazy dict for reduced_msgs
-            adjmat = self.cached_graph.adjmat(self.context)
+        if is_all(u) and is_all(v):
+            self.update_all(message_func, reduce_func, update_func, True)
+        # TODO(minjie): SPMV is only supported for updating all nodes right now.
+        elif message_func == 'from_src' and reduce_func == 'sum':
+            # TODO(minjie): check the validity of edges u->v
+            u = utils.convert_to_id_tensor(u)
+            v = utils.convert_to_id_tensor(v)
+            # TODO(minjie): broadcasting is optional for many-one input.
+            u, v = utils.edge_broadcasting(u, v)
+            # relabel destination nodes.
+            new2old, old2new = utils.build_relabel_map(v)
+            # TODO(minjie): should not directly use []
+            new_v = old2new[v]
+            # create adj mat
+            idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
+            dat = F.ones((len(u),))
+            n = self.number_of_nodes()
+            m = len(new2old)
+            adjmat = F.sparse_tensor(idx, dat, [m, n])
+            adjmat = F.to_context(adjmat, self.context)
             reduced_msgs = {}
             for key in self._node_frame.schemes:
                 col = self._node_frame[key]
                 reduced_msgs[key] = F.spmm(adjmat, col)
-            node_repr = self.get_n_repr()
             if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
                 reduced_msgs = reduced_msgs[__REPR__]
-            self.set_n_repr(update_func(node_repr, reduced_msgs))
-        else:
-            if is_all(u) and is_all(v):
-                self._batch_sendto(u, v, message_func)
-                self._batch_recv(v, reduce_func, update_func)
-            else:
-                self._batch_sendto(u, v, message_func)
-                unique_v = F.unique(v)
+            node_repr = self.get_n_repr(new2old)
+            new_node_repr = update_func(node_repr, reduced_msgs)
+            self.set_n_repr(new_node_repr, new2old)
+        else:
+            self._batch_sendto(u, v, message_func)
+            unique_v = F.unique(v)
@@ -845,11 +878,24 @@ class DGLGraph(DiGraph):
         assert reduce_func is not None
         assert update_func is not None
         if batchable:
-            self._batch_update_by_edge(ALL, ALL,
-                                       message_func, reduce_func, update_func)
+            if message_func == 'from_src' and reduce_func == 'sum':
+                # TODO(minjie): use lazy dict for reduced_msgs
+                adjmat = self.cached_graph.adjmat(self.context)
+                reduced_msgs = {}
+                for key in self._node_frame.schemes:
+                    col = self._node_frame[key]
+                    reduced_msgs[key] = F.spmm(adjmat, col)
+                if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
+                    reduced_msgs = reduced_msgs[__REPR__]
+                node_repr = self.get_n_repr()
+                self.set_n_repr(update_func(node_repr, reduced_msgs))
+            else:
+                self._batch_sendto(ALL, ALL, message_func)
+                self._batch_recv(ALL, reduce_func, update_func)
         else:
-            u = [uu for uu, _ in self.edges]
-            v = [vv for _, vv in self.edges]
+            u, v = zip(*self.edges)
+            u = list(u)
+            v = list(v)
             self._nonbatch_sendto(u, v, message_func)
             self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
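To make the new partial-update branch of _batch_update_by_edge easier to follow, here is a self-contained PyTorch re-derivation of the same trick (a sketch only: the variable names mirror the diff, but nothing here calls the dgl API):

import torch as th

# Reduce 'from_src'/'sum' messages over an arbitrary edge list u -> v by building a
# (touched destinations x all nodes) adjacency and doing a single sparse matmul.
n = 10
x = th.randn(n, 4)                                  # node features
u = th.tensor([0, 0, 0, 3, 4, 9])                   # edge sources
v = th.tensor([1, 2, 3, 9, 9, 0])                   # edge destinations

# Relabel the touched destinations to 0..m-1 (what utils.build_relabel_map does).
new2old = th.unique(v, sorted=True)                 # new id -> old id
old2new = th.zeros(int(new2old.max()) + 1, dtype=th.int64)
old2new[new2old] = th.arange(len(new2old))
new_v = old2new[v]

# (m x n) adjacency over just these edges, then one spmm.
adj = th.sparse.FloatTensor(th.stack([new_v, u]),
                            th.ones(len(u)),
                            th.Size([len(new2old), n]))
reduced = th.spmm(adj, x)                           # row i sums x[u] over edges into new2old[i]

# Write the result back only for the destinations that were touched.
out = x.clone()
out[new2old] = reduced

Relabeling keeps the sparse matrix at (touched destinations) x (all nodes) instead of N x N, which is what build_relabel_map in python/dgl/utils.py is for.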
python/dgl/utils.py
@@ -6,17 +6,21 @@ import dgl.backend as F
 from dgl.backend import Tensor, SparseTensor

 def is_id_tensor(u):
+    """Return whether the input is a supported id tensor."""
     return isinstance(u, Tensor) and F.isinteger(u) and len(F.shape(u)) == 1

 def is_id_container(u):
+    """Return whether the input is a supported id container."""
     return isinstance(u, list)

 def node_iter(n):
+    """Return an iterator that loops over the given nodes."""
     n = convert_to_id_container(n)
     for nn in n:
         yield nn

 def edge_iter(u, v):
+    """Return an iterator that loops over the given edges."""
     u = convert_to_id_container(u)
     v = convert_to_id_container(v)
     if len(u) == len(v):
@@ -35,6 +39,7 @@ def edge_iter(u, v):
         raise ValueError('Error edges:', u, v)

 def convert_to_id_container(x):
+    """Convert the input to id container."""
     if is_id_container(x):
         return x
     elif is_id_tensor(x):
@@ -47,6 +52,7 @@ def convert_to_id_container(x):
         return None

 def convert_to_id_tensor(x, ctx=None):
+    """Convert the input to id tensor."""
     if is_id_container(x):
         ret = F.tensor(x, dtype=F.int64)
     elif is_id_tensor(x):
@@ -81,3 +87,38 @@ class LazyDict(Mapping):
     def __len__(self):
         return len(self._keys)
+
+def build_relabel_map(x):
+    """Relabel the input ids to continuous ids that starts from zero.
+
+    Parameters
+    ----------
+    x : int, tensor or container
+        The input ids.
+
+    Returns
+    -------
+    new_to_old : tensor
+        The mapping from new id to old id.
+    old_to_new : tensor
+        The mapping from old id to new id. It is a vector of length MAX(x).
+        One can use advanced indexing to convert an old id tensor to a
+        new id tensor: new_id = old_to_new[old_id]
+    """
+    x = convert_to_id_tensor(x)
+    unique_x, _ = F.sort(F.unique(x))
+    map_len = int(F.max(unique_x)) + 1
+    old_to_new = F.zeros(map_len, dtype=F.int64)
+    # TODO(minjie): should not directly use []
+    old_to_new[unique_x] = F.astype(F.arange(len(unique_x)), F.int64)
+    return unique_x, old_to_new
+
+def edge_broadcasting(u, v):
+    """Convert one-many and many-one edges to many-many."""
+    if len(u) != len(v) and len(u) == 1:
+        u = F.broadcast_to(u, v)
+    elif len(u) != len(v) and len(v) == 1:
+        v = F.broadcast_to(v, u)
+    else:
+        assert len(u) == len(v)
+    return u, v
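The two new helpers can be checked in isolation with plain torch; the snippet below mirrors their bodies and works the values out for the same ids the test file uses (illustrative, not a call into dgl):

import torch as th

# build_relabel_map([1, 2, 3, 9, 9, 0]), re-done with plain torch
x = th.tensor([1, 2, 3, 9, 9, 0])
unique_x, _ = th.sort(th.unique(x))
old_to_new = th.zeros(int(th.max(unique_x)) + 1, dtype=th.int64)
old_to_new[unique_x] = th.arange(len(unique_x))

assert unique_x.tolist() == [0, 1, 2, 3, 9]                   # new_to_old
assert old_to_new.tolist() == [0, 1, 2, 3, 0, 0, 0, 0, 0, 4]  # ids absent from x keep the default 0
assert old_to_new[x].tolist() == [1, 2, 3, 4, 4, 0]           # new_id = old_to_new[old_id]

# edge_broadcasting: a one-to-many pair (len(u) == 1) becomes many-to-many
u, v = th.tensor([0]), th.tensor([1, 2, 3])
u = u + th.zeros_like(v)     # what F.broadcast_to does in the torch backend
assert u.tolist() == [0, 0, 0]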
tests/pytorch/test_specialization.py
@@ -34,16 +34,22 @@ def generate_graph():

 def test_spmv_specialize():
     g = generate_graph()
-    g.register_message_func('from_src', batchable=True)
-    g.register_reduce_func('sum', batchable=True)
-    g.register_update_func(update_func, batchable=True)
+    # update all
     v1 = g.get_n_repr()
-    g.update_all()
+    g.update_all('from_src', 'sum', update_func, batchable=True)
     v2 = g.get_n_repr()
     g.set_n_repr(v1)
-    g.register_message_func(message_func, batchable=True)
-    g.register_reduce_func(reduce_func, batchable=True)
-    g.update_all()
+    g.update_all(message_func, reduce_func, update_func, batchable=True)
+    v3 = g.get_n_repr()
+    check_eq(v2, v3)
+    # partial update
+    u = th.tensor([0, 0, 0, 3, 4, 9])
+    v = th.tensor([1, 2, 3, 9, 9, 0])
+    v1 = g.get_n_repr()
+    g.update_by_edge(u, v, 'from_src', 'sum', update_func, batchable=True)
+    v2 = g.get_n_repr()
+    g.set_n_repr(v1)
+    g.update_by_edge(u, v, message_func, reduce_func, update_func, batchable=True)
     v3 = g.get_n_repr()
     check_eq(v2, v3)
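For the partial-update case above, the expected behaviour of the 'from_src'/'sum' pair can be spelled out with an index_add_ reference (illustrative only; message_func, reduce_func, update_func and check_eq are helpers defined elsewhere in the test file and are not reproduced here):

import torch as th

# Edges used by the partial update in the test above.
u = th.tensor([0, 0, 0, 3, 4, 9])    # sources
v = th.tensor([1, 2, 3, 9, 9, 0])    # destinations

x = th.randn(10, 4)                  # stand-in node features

# 'from_src' + 'sum': every touched destination receives the sum of its sources' features.
reduced = th.zeros_like(x)
reduced.index_add_(0, v, x[u])

assert th.allclose(reduced[9], x[3] + x[4])   # node 9 has incoming edges from 3 and 4
assert th.allclose(reduced[5], th.zeros(4))   # node 5 is not a destination here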