Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b2e4bdc0
Commit
b2e4bdc0
authored
Sep 20, 2018
by
Minjie Wang
Browse files
passed basic test batching
parent
44db98c4
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
61 additions
and
63 deletions
+61
-63
python/dgl/frame.py
python/dgl/frame.py
+6
-6
python/dgl/graph.py
python/dgl/graph.py
+36
-24
python/dgl/ndarray.py
python/dgl/ndarray.py
+2
-1
python/dgl/scheduler.py
python/dgl/scheduler.py
+3
-3
python/dgl/utils.py
python/dgl/utils.py
+4
-4
src/graph/graph.cc
src/graph/graph.cc
+3
-3
tests/pytorch/test_batching.py
tests/pytorch/test_batching.py
+3
-15
tests/pytorch/test_batching_anonymous.py
tests/pytorch/test_batching_anonymous.py
+1
-2
tests/pytorch/test_frame.py
tests/pytorch/test_frame.py
+1
-1
tests/pytorch/test_function.py
tests/pytorch/test_function.py
+2
-4
No files found.
python/dgl/frame.py
View file @
b2e4bdc0
...
@@ -123,7 +123,7 @@ class FrameRef(MutableMapping):
...
@@ -123,7 +123,7 @@ class FrameRef(MutableMapping):
def
select_rows
(
self
,
query
):
def
select_rows
(
self
,
query
):
rowids
=
self
.
_getrowid
(
query
)
rowids
=
self
.
_getrowid
(
query
)
def
_lazy_select
(
key
):
def
_lazy_select
(
key
):
idx
=
rowids
.
totensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
idx
=
rowids
.
to
user
tensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
return
F
.
gather_row
(
self
.
_frame
[
key
],
idx
)
return
F
.
gather_row
(
self
.
_frame
[
key
],
idx
)
return
utils
.
LazyDict
(
_lazy_select
,
keys
=
self
.
schemes
)
return
utils
.
LazyDict
(
_lazy_select
,
keys
=
self
.
schemes
)
...
@@ -132,7 +132,7 @@ class FrameRef(MutableMapping):
...
@@ -132,7 +132,7 @@ class FrameRef(MutableMapping):
if
self
.
is_span_whole_column
():
if
self
.
is_span_whole_column
():
return
col
return
col
else
:
else
:
idx
=
self
.
index
().
totensor
(
F
.
get_context
(
col
))
idx
=
self
.
index
().
to
user
tensor
(
F
.
get_context
(
col
))
return
F
.
gather_row
(
col
,
idx
)
return
F
.
gather_row
(
col
,
idx
)
def
__setitem__
(
self
,
key
,
val
):
def
__setitem__
(
self
,
key
,
val
):
...
@@ -156,7 +156,7 @@ class FrameRef(MutableMapping):
...
@@ -156,7 +156,7 @@ class FrameRef(MutableMapping):
else
:
else
:
fcol
=
F
.
zeros
((
self
.
_frame
.
num_rows
,)
+
shp
[
1
:])
fcol
=
F
.
zeros
((
self
.
_frame
.
num_rows
,)
+
shp
[
1
:])
fcol
=
F
.
to_context
(
fcol
,
colctx
)
fcol
=
F
.
to_context
(
fcol
,
colctx
)
idx
=
self
.
index
().
totensor
(
colctx
)
idx
=
self
.
index
().
to
user
tensor
(
colctx
)
newfcol
=
F
.
scatter_row
(
fcol
,
idx
,
col
)
newfcol
=
F
.
scatter_row
(
fcol
,
idx
,
col
)
self
.
_frame
[
name
]
=
newfcol
self
.
_frame
[
name
]
=
newfcol
...
@@ -167,7 +167,7 @@ class FrameRef(MutableMapping):
...
@@ -167,7 +167,7 @@ class FrameRef(MutableMapping):
# add new column
# add new column
tmpref
=
FrameRef
(
self
.
_frame
,
rowids
)
tmpref
=
FrameRef
(
self
.
_frame
,
rowids
)
tmpref
.
add_column
(
key
,
col
)
tmpref
.
add_column
(
key
,
col
)
idx
=
rowids
.
totensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
idx
=
rowids
.
to
user
tensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
self
.
_frame
[
key
]
=
F
.
scatter_row
(
self
.
_frame
[
key
],
idx
,
col
)
self
.
_frame
[
key
]
=
F
.
scatter_row
(
self
.
_frame
[
key
],
idx
,
col
)
def
__delitem__
(
self
,
key
):
def
__delitem__
(
self
,
key
):
...
@@ -223,8 +223,8 @@ class FrameRef(MutableMapping):
...
@@ -223,8 +223,8 @@ class FrameRef(MutableMapping):
# shortcut for identical mapping
# shortcut for identical mapping
return
query
return
query
else
:
else
:
idxtensor
=
self
.
index
().
totensor
()
idxtensor
=
self
.
index
().
to
user
tensor
()
return
utils
.
toindex
(
F
.
gather_row
(
idxtensor
,
query
.
totensor
()))
return
utils
.
toindex
(
F
.
gather_row
(
idxtensor
,
query
.
to
user
tensor
()))
def
index
(
self
):
def
index
(
self
):
if
self
.
_index
is
None
:
if
self
.
_index
is
None
:
...
...
python/dgl/graph.py
View file @
b2e4bdc0
...
@@ -8,12 +8,14 @@ import dgl
...
@@ -8,12 +8,14 @@ import dgl
from
.base
import
ALL
,
is_all
,
__MSG__
,
__REPR__
from
.base
import
ALL
,
is_all
,
__MSG__
,
__REPR__
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.backend
import
Tensor
from
.backend
import
Tensor
from
.graph_index
import
GraphIndex
from
.frame
import
FrameRef
,
merge_frames
from
.frame
import
FrameRef
,
merge_frames
from
.
import
scheduler
from
.
import
utils
from
.function.message
import
BundledMessageFunction
from
.function.message
import
BundledMessageFunction
from
.function.reducer
import
BundledReduceFunction
from
.function.reducer
import
BundledReduceFunction
from
.graph_index
import
GraphIndex
from
.
import
scheduler
from
.
import
utils
__all__
=
[
'DLGraph'
]
class
DGLGraph
(
object
):
class
DGLGraph
(
object
):
"""Base graph class specialized for neural networks on graphs.
"""Base graph class specialized for neural networks on graphs.
...
@@ -63,9 +65,11 @@ class DGLGraph(object):
...
@@ -63,9 +65,11 @@ class DGLGraph(object):
Optional node representations.
Optional node representations.
"""
"""
self
.
_graph
.
add_nodes
(
num
)
self
.
_graph
.
add_nodes
(
num
)
self
.
_msg_graph
.
add_nodes
(
num
)
#TODO(minjie): change frames
#TODO(minjie): change frames
assert
reprs
is
None
def
add_edge
(
self
,
u
,
v
,
repr
=
None
):
def
add_edge
(
self
,
u
,
v
,
repr
s
=
None
):
"""Add one edge.
"""Add one edge.
Parameters
Parameters
...
@@ -74,11 +78,12 @@ class DGLGraph(object):
...
@@ -74,11 +78,12 @@ class DGLGraph(object):
The src node.
The src node.
v : int
v : int
The dst node.
The dst node.
repr : dict
repr
s
: dict
Optional edge representation.
Optional edge representation.
"""
"""
self
.
_graph
.
add_edge
(
u
,
v
)
self
.
_graph
.
add_edge
(
u
,
v
)
#TODO(minjie): change frames
#TODO(minjie): change frames
assert
reprs
is
None
def
add_edges
(
self
,
u
,
v
,
reprs
=
None
):
def
add_edges
(
self
,
u
,
v
,
reprs
=
None
):
"""Add many edges.
"""Add many edges.
...
@@ -96,6 +101,7 @@ class DGLGraph(object):
...
@@ -96,6 +101,7 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
self
.
_graph
.
add_edges
(
u
,
v
)
self
.
_graph
.
add_edges
(
u
,
v
)
#TODO(minjie): change frames
#TODO(minjie): change frames
assert
reprs
is
None
def
clear
(
self
):
def
clear
(
self
):
"""Clear the graph and its storage."""
"""Clear the graph and its storage."""
...
@@ -483,6 +489,8 @@ class DGLGraph(object):
...
@@ -483,6 +489,8 @@ class DGLGraph(object):
dict
dict
Representation dict
Representation dict
"""
"""
if
len
(
self
.
node_attr_schemes
())
==
0
:
return
dict
()
if
is_all
(
u
):
if
is_all
(
u
):
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
return
self
.
_node_frame
[
__REPR__
]
return
self
.
_node_frame
[
__REPR__
]
...
@@ -535,7 +543,7 @@ class DGLGraph(object):
...
@@ -535,7 +543,7 @@ class DGLGraph(object):
v_is_all
=
is_all
(
v
)
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
assert
u_is_all
==
v_is_all
if
u_is_all
:
if
u_is_all
:
num_edges
=
self
.
cached_graph
.
num
_edges
()
num_edges
=
self
.
number_of
_edges
()
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
...
@@ -553,7 +561,7 @@ class DGLGraph(object):
...
@@ -553,7 +561,7 @@ class DGLGraph(object):
else
:
else
:
self
.
_edge_frame
[
__REPR__
]
=
h_uv
self
.
_edge_frame
[
__REPR__
]
=
h_uv
else
:
else
:
eid
=
self
.
cached
_graph
.
get_
edge_id
(
u
,
v
)
eid
=
self
.
_graph
.
edge_id
s
(
u
,
v
)
if
utils
.
is_dict_like
(
h_uv
):
if
utils
.
is_dict_like
(
h_uv
):
self
.
_edge_frame
[
eid
]
=
h_uv
self
.
_edge_frame
[
eid
]
=
h_uv
else
:
else
:
...
@@ -571,7 +579,7 @@ class DGLGraph(object):
...
@@ -571,7 +579,7 @@ class DGLGraph(object):
"""
"""
# sanity check
# sanity check
if
is_all
(
eid
):
if
is_all
(
eid
):
num_edges
=
self
.
cached_graph
.
num
_edges
()
num_edges
=
self
.
number_of
_edges
()
else
:
else
:
eid
=
utils
.
toindex
(
eid
)
eid
=
utils
.
toindex
(
eid
)
num_edges
=
len
(
eid
)
num_edges
=
len
(
eid
)
...
@@ -611,6 +619,8 @@ class DGLGraph(object):
...
@@ -611,6 +619,8 @@ class DGLGraph(object):
u_is_all
=
is_all
(
u
)
u_is_all
=
is_all
(
u
)
v_is_all
=
is_all
(
v
)
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
assert
u_is_all
==
v_is_all
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
if
u_is_all
:
if
u_is_all
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
[
__REPR__
]
return
self
.
_edge_frame
[
__REPR__
]
...
@@ -619,7 +629,7 @@ class DGLGraph(object):
...
@@ -619,7 +629,7 @@ class DGLGraph(object):
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
eid
=
self
.
cached
_graph
.
get_
edge_id
(
u
,
v
)
eid
=
self
.
_graph
.
edge_id
s
(
u
,
v
)
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
.
select_rows
(
eid
)[
__REPR__
]
return
self
.
_edge_frame
.
select_rows
(
eid
)[
__REPR__
]
else
:
else
:
...
@@ -653,6 +663,8 @@ class DGLGraph(object):
...
@@ -653,6 +663,8 @@ class DGLGraph(object):
dict
dict
Representation dict
Representation dict
"""
"""
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
if
is_all
(
eid
):
if
is_all
(
eid
):
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
[
__REPR__
]
return
self
.
_edge_frame
[
__REPR__
]
...
@@ -843,8 +855,8 @@ class DGLGraph(object):
...
@@ -843,8 +855,8 @@ class DGLGraph(object):
def
_batch_send
(
self
,
u
,
v
,
message_func
):
def
_batch_send
(
self
,
u
,
v
,
message_func
):
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached
_graph
.
edges
()
u
,
v
,
_
=
self
.
_graph
.
edges
(
sorted
=
True
)
self
.
msg_graph
.
add_edges
(
u
,
v
)
self
.
_
msg_graph
.
add_edges
(
u
,
v
)
# call UDF
# call UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
edge_reprs
=
self
.
get_e_repr
()
edge_reprs
=
self
.
get_e_repr
()
...
@@ -853,11 +865,10 @@ class DGLGraph(object):
...
@@ -853,11 +865,10 @@ class DGLGraph(object):
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
eid
=
self
.
cached_graph
.
get_edge_id
(
u
,
v
)
self
.
_msg_graph
.
add_edges
(
u
,
v
)
self
.
msg_graph
.
add_edges
(
u
,
v
)
# call UDF
# call UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
edge_reprs
=
self
.
get_e_repr
_by_id
(
eid
)
edge_reprs
=
self
.
get_e_repr
(
u
,
v
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
if
utils
.
is_dict_like
(
msgs
):
if
utils
.
is_dict_like
(
msgs
):
self
.
_msg_frame
.
append
(
msgs
)
self
.
_msg_frame
.
append
(
msgs
)
...
@@ -909,7 +920,7 @@ class DGLGraph(object):
...
@@ -909,7 +920,7 @@ class DGLGraph(object):
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached
_graph
.
edges
()
u
,
v
=
self
.
_graph
.
edges
(
sorted
=
True
)
# call the UDF
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
@@ -920,7 +931,7 @@ class DGLGraph(object):
...
@@ -920,7 +931,7 @@ class DGLGraph(object):
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
eid
=
self
.
cached
_graph
.
get_
edge_id
(
u
,
v
)
eid
=
self
.
_graph
.
edge_id
s
(
u
,
v
)
# call the UDF
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
@@ -1005,7 +1016,7 @@ class DGLGraph(object):
...
@@ -1005,7 +1016,7 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
# degree bucketing
# degree bucketing
degrees
,
v_buckets
=
scheduler
.
degree_bucketing
(
self
.
msg_graph
,
v
)
degrees
,
v_buckets
=
scheduler
.
degree_bucketing
(
self
.
_
msg_graph
,
v
)
if
degrees
==
[
0
]:
if
degrees
==
[
0
]:
# no message has been sent to the specified node
# no message has been sent to the specified node
return
return
...
@@ -1020,8 +1031,7 @@ class DGLGraph(object):
...
@@ -1020,8 +1031,7 @@ class DGLGraph(object):
continue
continue
bkt_len
=
len
(
v_bkt
)
bkt_len
=
len
(
v_bkt
)
dst_reprs
=
self
.
get_n_repr
(
v_bkt
)
dst_reprs
=
self
.
get_n_repr
(
v_bkt
)
uu
,
vv
,
_
=
self
.
msg_graph
.
in_edges
(
v_bkt
)
uu
,
vv
,
in_msg_ids
=
self
.
_msg_graph
.
in_edges
(
v_bkt
)
in_msg_ids
=
self
.
msg_graph
.
get_edge_id
(
uu
,
vv
)
in_msgs
=
self
.
_msg_frame
.
select_rows
(
in_msg_ids
)
in_msgs
=
self
.
_msg_frame
.
select_rows
(
in_msg_ids
)
# Reshape the column tensor to (B, Deg, ...).
# Reshape the column tensor to (B, Deg, ...).
def
_reshape_fn
(
msg
):
def
_reshape_fn
(
msg
):
...
@@ -1033,7 +1043,7 @@ class DGLGraph(object):
...
@@ -1033,7 +1043,7 @@ class DGLGraph(object):
else
:
else
:
reshaped_in_msgs
=
utils
.
LazyDict
(
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
reordered_v
.
append
(
v_bkt
.
totensor
())
reordered_v
.
append
(
v_bkt
.
to
user
tensor
())
new_reprs
.
append
(
reduce_func
(
dst_reprs
,
reshaped_in_msgs
))
new_reprs
.
append
(
reduce_func
(
dst_reprs
,
reshaped_in_msgs
))
# TODO: clear partial messages
# TODO: clear partial messages
...
@@ -1087,7 +1097,7 @@ class DGLGraph(object):
...
@@ -1087,7 +1097,7 @@ class DGLGraph(object):
# no edges to be triggered
# no edges to be triggered
assert
len
(
v
)
==
0
assert
len
(
v
)
==
0
return
return
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
totensor
()))
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
to
user
tensor
()))
# TODO(minjie): better way to figure out `batchable` flag
# TODO(minjie): better way to figure out `batchable` flag
if
message_func
==
"default"
:
if
message_func
==
"default"
:
...
@@ -1135,10 +1145,10 @@ class DGLGraph(object):
...
@@ -1135,10 +1145,10 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
if
len
(
v
)
==
0
:
if
len
(
v
)
==
0
:
return
return
uu
,
vv
,
_
=
self
.
cached
_graph
.
in_edges
(
v
)
uu
,
vv
,
_
=
self
.
_graph
.
in_edges
(
v
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
batchable
=
batchable
)
apply_node_func
=
None
,
batchable
=
batchable
)
unique_v
=
F
.
unique
(
v
.
totensor
())
unique_v
=
F
.
unique
(
v
.
to
user
tensor
())
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
def
push
(
self
,
def
push
(
self
,
...
@@ -1165,7 +1175,7 @@ class DGLGraph(object):
...
@@ -1165,7 +1175,7 @@ class DGLGraph(object):
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
if
len
(
u
)
==
0
:
if
len
(
u
)
==
0
:
return
return
uu
,
vv
,
_
=
self
.
cached
_graph
.
out_edges
(
u
)
uu
,
vv
,
_
=
self
.
_graph
.
out_edges
(
u
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
,
batchable
=
batchable
)
reduce_func
,
apply_node_func
,
batchable
=
batchable
)
...
@@ -1309,8 +1319,10 @@ class DGLGraph(object):
...
@@ -1309,8 +1319,10 @@ class DGLGraph(object):
reduce_func
)
reduce_func
)
def
clear_messages
(
self
):
def
clear_messages
(
self
):
"""Clear all messages."""
self
.
_msg_graph
.
clear
()
self
.
_msg_graph
.
clear
()
self
.
_msg_frame
.
clear
()
self
.
_msg_frame
.
clear
()
self
.
_msg_graph
.
add_nodes
(
self
.
number_of_nodes
())
def
_get_repr
(
attr_dict
):
def
_get_repr
(
attr_dict
):
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
...
...
python/dgl/ndarray.py
View file @
b2e4bdc0
...
@@ -7,6 +7,7 @@ used with C++ library.
...
@@ -7,6 +7,7 @@ used with C++ library.
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
ctypes
import
ctypes
import
functools
import
operator
import
operator
import
numpy
as
_np
import
numpy
as
_np
...
@@ -18,7 +19,7 @@ from . import backend as F
...
@@ -18,7 +19,7 @@ from . import backend as F
class
NDArray
(
NDArrayBase
):
class
NDArray
(
NDArrayBase
):
"""Lightweight NDArray class for DGL framework."""
"""Lightweight NDArray class for DGL framework."""
def
__len__
(
self
):
def
__len__
(
self
):
return
reduce
(
operator
.
mul
,
self
.
shape
,
1
)
return
functools
.
reduce
(
operator
.
mul
,
self
.
shape
,
1
)
def
cpu
(
dev_id
=
0
):
def
cpu
(
dev_id
=
0
):
"""Construct a CPU device
"""Construct a CPU device
...
...
python/dgl/scheduler.py
View file @
b2e4bdc0
...
@@ -11,12 +11,12 @@ from . import utils
...
@@ -11,12 +11,12 @@ from . import utils
__all__
=
[
"degree_bucketing"
,
"get_executor"
]
__all__
=
[
"degree_bucketing"
,
"get_executor"
]
def
degree_bucketing
(
cached_
graph
,
v
):
def
degree_bucketing
(
graph
,
v
):
"""Create degree bucketing scheduling policy.
"""Create degree bucketing scheduling policy.
Parameters
Parameters
----------
----------
cached_
graph : dgl.
cached_graph.CachedGraph
graph : dgl.
graph_index.GraphIndex
the graph
the graph
v : dgl.utils.Index
v : dgl.utils.Index
the nodes to gather messages
the nodes to gather messages
...
@@ -29,7 +29,7 @@ def degree_bucketing(cached_graph, v):
...
@@ -29,7 +29,7 @@ def degree_bucketing(cached_graph, v):
list of node id buckets; nodes belong to the same bucket have
list of node id buckets; nodes belong to the same bucket have
the same degree
the same degree
"""
"""
degrees
=
F
.
asnumpy
(
cached_
graph
.
in_degrees
(
v
).
to
tensor
())
degrees
=
np
.
array
(
graph
.
in_degrees
(
v
).
to
list
())
unique_degrees
=
list
(
np
.
unique
(
degrees
))
unique_degrees
=
list
(
np
.
unique
(
degrees
))
v_np
=
np
.
array
(
v
.
tolist
())
v_np
=
np
.
array
(
v
.
tolist
())
v_bkt
=
[]
v_bkt
=
[]
...
...
python/dgl/utils.py
View file @
b2e4bdc0
...
@@ -141,9 +141,9 @@ def edge_broadcasting(u, v):
...
@@ -141,9 +141,9 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting
The dst id(s) after broadcasting
"""
"""
if
len
(
u
)
!=
len
(
v
)
and
len
(
u
)
==
1
:
if
len
(
u
)
!=
len
(
v
)
and
len
(
u
)
==
1
:
u
=
toindex
(
F
.
broadcast_to
(
u
.
totensor
(),
v
.
totensor
()))
u
=
toindex
(
F
.
broadcast_to
(
u
.
to
user
tensor
(),
v
.
to
user
tensor
()))
elif
len
(
u
)
!=
len
(
v
)
and
len
(
v
)
==
1
:
elif
len
(
u
)
!=
len
(
v
)
and
len
(
v
)
==
1
:
v
=
toindex
(
F
.
broadcast_to
(
v
.
totensor
(),
u
.
totensor
()))
v
=
toindex
(
F
.
broadcast_to
(
v
.
to
user
tensor
(),
u
.
to
user
tensor
()))
else
:
else
:
assert
len
(
u
)
==
len
(
v
)
assert
len
(
u
)
==
len
(
v
)
return
u
,
v
return
u
,
v
...
@@ -205,7 +205,7 @@ def build_relabel_map(x):
...
@@ -205,7 +205,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a
One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id]
new id tensor: new_id = old_to_new[old_id]
"""
"""
x
=
x
.
totensor
()
x
=
x
.
to
user
tensor
()
unique_x
,
_
=
F
.
sort
(
F
.
unique
(
x
))
unique_x
,
_
=
F
.
sort
(
F
.
unique
(
x
))
map_len
=
int
(
F
.
max
(
unique_x
))
+
1
map_len
=
int
(
F
.
max
(
unique_x
))
+
1
old_to_new
=
F
.
zeros
(
map_len
,
dtype
=
F
.
int64
)
old_to_new
=
F
.
zeros
(
map_len
,
dtype
=
F
.
int64
)
...
@@ -312,6 +312,6 @@ def reorder(dict_like, index):
...
@@ -312,6 +312,6 @@ def reorder(dict_like, index):
"""
"""
new_dict
=
{}
new_dict
=
{}
for
key
,
val
in
dict_like
.
items
():
for
key
,
val
in
dict_like
.
items
():
idx_ctx
=
index
.
totensor
(
F
.
get_context
(
val
))
idx_ctx
=
index
.
to
user
tensor
(
F
.
get_context
(
val
))
new_dict
[
key
]
=
F
.
gather_row
(
val
,
idx_ctx
)
new_dict
[
key
]
=
F
.
gather_row
(
val
,
idx_ctx
)
return
new_dict
return
new_dict
src/graph/graph.cc
View file @
b2e4bdc0
...
@@ -31,7 +31,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
...
@@ -31,7 +31,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src
_ids
->
shape
[
0
];
const
auto
dstlen
=
dst
_ids
->
shape
[
0
];
const
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src_ids
->
data
);
const
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src_ids
->
data
);
const
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst_ids
->
data
);
const
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst_ids
->
data
);
if
(
srclen
==
1
)
{
if
(
srclen
==
1
)
{
...
@@ -78,7 +78,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
...
@@ -78,7 +78,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src
_ids
->
shape
[
0
];
const
auto
dstlen
=
dst
_ids
->
shape
[
0
];
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
BoolArray
rst
=
BoolArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
BoolArray
rst
=
BoolArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
...
@@ -150,7 +150,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
...
@@ -150,7 +150,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src
_ids
->
shape
[
0
];
const
auto
dstlen
=
dst
_ids
->
shape
[
0
];
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
IdArray
rst
=
IdArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
IdArray
rst
=
IdArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
...
...
tests/pytorch/test_batching.py
View file @
b2e4bdc0
...
@@ -27,8 +27,7 @@ def apply_node_func(node):
...
@@ -27,8 +27,7 @@ def apply_node_func(node):
def
generate_graph
(
grad
=
False
):
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_nodes
(
10
)
# 10 nodes.
g
.
add_node
(
i
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
0
,
i
)
...
@@ -198,7 +197,7 @@ def test_update_routines():
...
@@ -198,7 +197,7 @@ def test_update_routines():
def
test_reduce_0deg
():
def
test_reduce_0deg
():
g
=
DGLGraph
()
g
=
DGLGraph
()
g
.
add_nodes
_from
([
0
,
1
,
2
,
3
,
4
]
)
g
.
add_nodes
(
5
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
add_edge
(
3
,
0
)
g
.
add_edge
(
3
,
0
)
...
@@ -218,7 +217,7 @@ def test_reduce_0deg():
...
@@ -218,7 +217,7 @@ def test_reduce_0deg():
def
test_pull_0deg
():
def
test_pull_0deg
():
g
=
DGLGraph
()
g
=
DGLGraph
()
g
.
add_nodes
_from
([
0
,
1
]
)
g
.
add_nodes
(
2
)
g
.
add_edge
(
0
,
1
)
g
.
add_edge
(
0
,
1
)
def
_message
(
src
,
edge
):
def
_message
(
src
,
edge
):
return
src
return
src
...
@@ -243,16 +242,6 @@ def test_pull_0deg():
...
@@ -243,16 +242,6 @@ def test_pull_0deg():
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
def
_test_delete
():
g
=
generate_graph
()
ecol
=
Variable
(
th
.
randn
(
17
,
D
),
requires_grad
=
grad
)
g
.
set_e_repr
({
'e'
:
ecol
})
assert
g
.
get_n_repr
()[
'h'
].
shape
[
0
]
==
10
assert
g
.
get_e_repr
()[
'e'
].
shape
[
0
]
==
17
g
.
remove_node
(
0
)
assert
g
.
get_n_repr
()[
'h'
].
shape
[
0
]
==
9
assert
g
.
get_e_repr
()[
'e'
].
shape
[
0
]
==
8
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_batch_setter_getter
()
test_batch_setter_getter
()
test_batch_setter_autograd
()
test_batch_setter_autograd
()
...
@@ -261,4 +250,3 @@ if __name__ == '__main__':
...
@@ -261,4 +250,3 @@ if __name__ == '__main__':
test_update_routines
()
test_update_routines
()
test_reduce_0deg
()
test_reduce_0deg
()
test_pull_0deg
()
test_pull_0deg
()
#test_delete()
tests/pytorch/test_batching_anonymous.py
View file @
b2e4bdc0
...
@@ -23,8 +23,7 @@ def reduce_func(hv, msgs):
...
@@ -23,8 +23,7 @@ def reduce_func(hv, msgs):
def
generate_graph
(
grad
=
False
):
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_nodes
(
10
)
g
.
add_node
(
i
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
0
,
i
)
...
...
tests/pytorch/test_frame.py
View file @
b2e4bdc0
...
@@ -113,7 +113,7 @@ def test_append2():
...
@@ -113,7 +113,7 @@ def test_append2():
assert
not
f
.
is_span_whole_column
()
assert
not
f
.
is_span_whole_column
()
assert
f
.
num_rows
==
3
*
N
assert
f
.
num_rows
==
3
*
N
new_idx
=
list
(
range
(
N
))
+
list
(
range
(
2
*
N
,
4
*
N
))
new_idx
=
list
(
range
(
N
))
+
list
(
range
(
2
*
N
,
4
*
N
))
assert
check_eq
(
f
.
index
().
totensor
(),
th
.
tensor
(
new_idx
))
assert
check_eq
(
f
.
index
().
to
user
tensor
(),
th
.
tensor
(
new_idx
))
assert
data
.
num_rows
==
4
*
N
assert
data
.
num_rows
==
4
*
N
def
test_row1
():
def
test_row1
():
...
...
tests/pytorch/test_function.py
View file @
b2e4bdc0
...
@@ -5,8 +5,7 @@ from dgl.graph import __REPR__
...
@@ -5,8 +5,7 @@ from dgl.graph import __REPR__
def
generate_graph
():
def
generate_graph
():
g
=
dgl
.
DGLGraph
()
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_nodes
(
10
)
# 10 nodes.
g
.
add_node
(
i
)
# 10 nodes.
h
=
th
.
arange
(
1
,
11
)
h
=
th
.
arange
(
1
,
11
)
g
.
set_n_repr
({
'h'
:
h
})
g
.
set_n_repr
({
'h'
:
h
})
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
...
@@ -23,8 +22,7 @@ def generate_graph():
...
@@ -23,8 +22,7 @@ def generate_graph():
def
generate_graph1
():
def
generate_graph1
():
"""graph with anonymous repr"""
"""graph with anonymous repr"""
g
=
dgl
.
DGLGraph
()
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_nodes
(
10
)
# 10 nodes.
g
.
add_node
(
i
)
# 10 nodes.
h
=
th
.
arange
(
1
,
11
)
h
=
th
.
arange
(
1
,
11
)
g
.
set_n_repr
(
h
)
g
.
set_n_repr
(
h
)
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment