Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b2e4bdc0
"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "91fe0c90690d9a7078b0b03dc059088a6f310777"
Commit
b2e4bdc0
authored
Sep 20, 2018
by
Minjie Wang
Browse files
passed basic test batching
parent
44db98c4
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
61 additions
and
63 deletions
+61
-63
python/dgl/frame.py
python/dgl/frame.py
+6
-6
python/dgl/graph.py
python/dgl/graph.py
+36
-24
python/dgl/ndarray.py
python/dgl/ndarray.py
+2
-1
python/dgl/scheduler.py
python/dgl/scheduler.py
+3
-3
python/dgl/utils.py
python/dgl/utils.py
+4
-4
src/graph/graph.cc
src/graph/graph.cc
+3
-3
tests/pytorch/test_batching.py
tests/pytorch/test_batching.py
+3
-15
tests/pytorch/test_batching_anonymous.py
tests/pytorch/test_batching_anonymous.py
+1
-2
tests/pytorch/test_frame.py
tests/pytorch/test_frame.py
+1
-1
tests/pytorch/test_function.py
tests/pytorch/test_function.py
+2
-4
No files found.
python/dgl/frame.py
View file @
b2e4bdc0
...
...
@@ -123,7 +123,7 @@ class FrameRef(MutableMapping):
def
select_rows
(
self
,
query
):
rowids
=
self
.
_getrowid
(
query
)
def
_lazy_select
(
key
):
idx
=
rowids
.
totensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
idx
=
rowids
.
to
user
tensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
return
F
.
gather_row
(
self
.
_frame
[
key
],
idx
)
return
utils
.
LazyDict
(
_lazy_select
,
keys
=
self
.
schemes
)
...
...
@@ -132,7 +132,7 @@ class FrameRef(MutableMapping):
if
self
.
is_span_whole_column
():
return
col
else
:
idx
=
self
.
index
().
totensor
(
F
.
get_context
(
col
))
idx
=
self
.
index
().
to
user
tensor
(
F
.
get_context
(
col
))
return
F
.
gather_row
(
col
,
idx
)
def
__setitem__
(
self
,
key
,
val
):
...
...
@@ -156,7 +156,7 @@ class FrameRef(MutableMapping):
else
:
fcol
=
F
.
zeros
((
self
.
_frame
.
num_rows
,)
+
shp
[
1
:])
fcol
=
F
.
to_context
(
fcol
,
colctx
)
idx
=
self
.
index
().
totensor
(
colctx
)
idx
=
self
.
index
().
to
user
tensor
(
colctx
)
newfcol
=
F
.
scatter_row
(
fcol
,
idx
,
col
)
self
.
_frame
[
name
]
=
newfcol
...
...
@@ -167,7 +167,7 @@ class FrameRef(MutableMapping):
# add new column
tmpref
=
FrameRef
(
self
.
_frame
,
rowids
)
tmpref
.
add_column
(
key
,
col
)
idx
=
rowids
.
totensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
idx
=
rowids
.
to
user
tensor
(
F
.
get_context
(
self
.
_frame
[
key
]))
self
.
_frame
[
key
]
=
F
.
scatter_row
(
self
.
_frame
[
key
],
idx
,
col
)
def
__delitem__
(
self
,
key
):
...
...
@@ -223,8 +223,8 @@ class FrameRef(MutableMapping):
# shortcut for identical mapping
return
query
else
:
idxtensor
=
self
.
index
().
totensor
()
return
utils
.
toindex
(
F
.
gather_row
(
idxtensor
,
query
.
totensor
()))
idxtensor
=
self
.
index
().
to
user
tensor
()
return
utils
.
toindex
(
F
.
gather_row
(
idxtensor
,
query
.
to
user
tensor
()))
def
index
(
self
):
if
self
.
_index
is
None
:
...
...
python/dgl/graph.py
View file @
b2e4bdc0
...
...
@@ -8,12 +8,14 @@ import dgl
from
.base
import
ALL
,
is_all
,
__MSG__
,
__REPR__
from
.
import
backend
as
F
from
.backend
import
Tensor
from
.graph_index
import
GraphIndex
from
.frame
import
FrameRef
,
merge_frames
from
.
import
scheduler
from
.
import
utils
from
.function.message
import
BundledMessageFunction
from
.function.reducer
import
BundledReduceFunction
from
.graph_index
import
GraphIndex
from
.
import
scheduler
from
.
import
utils
__all__
=
[
'DLGraph'
]
class
DGLGraph
(
object
):
"""Base graph class specialized for neural networks on graphs.
...
...
@@ -63,9 +65,11 @@ class DGLGraph(object):
Optional node representations.
"""
self
.
_graph
.
add_nodes
(
num
)
self
.
_msg_graph
.
add_nodes
(
num
)
#TODO(minjie): change frames
assert
reprs
is
None
def
add_edge
(
self
,
u
,
v
,
repr
=
None
):
def
add_edge
(
self
,
u
,
v
,
repr
s
=
None
):
"""Add one edge.
Parameters
...
...
@@ -74,11 +78,12 @@ class DGLGraph(object):
The src node.
v : int
The dst node.
repr : dict
repr
s
: dict
Optional edge representation.
"""
self
.
_graph
.
add_edge
(
u
,
v
)
#TODO(minjie): change frames
assert
reprs
is
None
def
add_edges
(
self
,
u
,
v
,
reprs
=
None
):
"""Add many edges.
...
...
@@ -96,6 +101,7 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
self
.
_graph
.
add_edges
(
u
,
v
)
#TODO(minjie): change frames
assert
reprs
is
None
def
clear
(
self
):
"""Clear the graph and its storage."""
...
...
@@ -483,6 +489,8 @@ class DGLGraph(object):
dict
Representation dict
"""
if
len
(
self
.
node_attr_schemes
())
==
0
:
return
dict
()
if
is_all
(
u
):
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
return
self
.
_node_frame
[
__REPR__
]
...
...
@@ -535,7 +543,7 @@ class DGLGraph(object):
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
if
u_is_all
:
num_edges
=
self
.
cached_graph
.
num
_edges
()
num_edges
=
self
.
number_of
_edges
()
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
...
...
@@ -553,7 +561,7 @@ class DGLGraph(object):
else
:
self
.
_edge_frame
[
__REPR__
]
=
h_uv
else
:
eid
=
self
.
cached
_graph
.
get_
edge_id
(
u
,
v
)
eid
=
self
.
_graph
.
edge_id
s
(
u
,
v
)
if
utils
.
is_dict_like
(
h_uv
):
self
.
_edge_frame
[
eid
]
=
h_uv
else
:
...
...
@@ -571,7 +579,7 @@ class DGLGraph(object):
"""
# sanity check
if
is_all
(
eid
):
num_edges
=
self
.
cached_graph
.
num
_edges
()
num_edges
=
self
.
number_of
_edges
()
else
:
eid
=
utils
.
toindex
(
eid
)
num_edges
=
len
(
eid
)
...
...
@@ -611,6 +619,8 @@ class DGLGraph(object):
u_is_all
=
is_all
(
u
)
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
if
u_is_all
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
[
__REPR__
]
...
...
@@ -619,7 +629,7 @@ class DGLGraph(object):
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
eid
=
self
.
cached
_graph
.
get_
edge_id
(
u
,
v
)
eid
=
self
.
_graph
.
edge_id
s
(
u
,
v
)
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
.
select_rows
(
eid
)[
__REPR__
]
else
:
...
...
@@ -653,6 +663,8 @@ class DGLGraph(object):
dict
Representation dict
"""
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
if
is_all
(
eid
):
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
[
__REPR__
]
...
...
@@ -843,8 +855,8 @@ class DGLGraph(object):
def
_batch_send
(
self
,
u
,
v
,
message_func
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached
_graph
.
edges
()
self
.
msg_graph
.
add_edges
(
u
,
v
)
u
,
v
,
_
=
self
.
_graph
.
edges
(
sorted
=
True
)
self
.
_
msg_graph
.
add_edges
(
u
,
v
)
# call UDF
src_reprs
=
self
.
get_n_repr
(
u
)
edge_reprs
=
self
.
get_e_repr
()
...
...
@@ -853,11 +865,10 @@ class DGLGraph(object):
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
eid
=
self
.
cached_graph
.
get_edge_id
(
u
,
v
)
self
.
msg_graph
.
add_edges
(
u
,
v
)
self
.
_msg_graph
.
add_edges
(
u
,
v
)
# call UDF
src_reprs
=
self
.
get_n_repr
(
u
)
edge_reprs
=
self
.
get_e_repr
_by_id
(
eid
)
edge_reprs
=
self
.
get_e_repr
(
u
,
v
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
if
utils
.
is_dict_like
(
msgs
):
self
.
_msg_frame
.
append
(
msgs
)
...
...
@@ -909,7 +920,7 @@ class DGLGraph(object):
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached
_graph
.
edges
()
u
,
v
=
self
.
_graph
.
edges
(
sorted
=
True
)
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
...
@@ -920,7 +931,7 @@ class DGLGraph(object):
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
eid
=
self
.
cached
_graph
.
get_
edge_id
(
u
,
v
)
eid
=
self
.
_graph
.
edge_id
s
(
u
,
v
)
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
...
@@ -1005,7 +1016,7 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
# degree bucketing
degrees
,
v_buckets
=
scheduler
.
degree_bucketing
(
self
.
msg_graph
,
v
)
degrees
,
v_buckets
=
scheduler
.
degree_bucketing
(
self
.
_
msg_graph
,
v
)
if
degrees
==
[
0
]:
# no message has been sent to the specified node
return
...
...
@@ -1020,8 +1031,7 @@ class DGLGraph(object):
continue
bkt_len
=
len
(
v_bkt
)
dst_reprs
=
self
.
get_n_repr
(
v_bkt
)
uu
,
vv
,
_
=
self
.
msg_graph
.
in_edges
(
v_bkt
)
in_msg_ids
=
self
.
msg_graph
.
get_edge_id
(
uu
,
vv
)
uu
,
vv
,
in_msg_ids
=
self
.
_msg_graph
.
in_edges
(
v_bkt
)
in_msgs
=
self
.
_msg_frame
.
select_rows
(
in_msg_ids
)
# Reshape the column tensor to (B, Deg, ...).
def
_reshape_fn
(
msg
):
...
...
@@ -1033,7 +1043,7 @@ class DGLGraph(object):
else
:
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
reordered_v
.
append
(
v_bkt
.
totensor
())
reordered_v
.
append
(
v_bkt
.
to
user
tensor
())
new_reprs
.
append
(
reduce_func
(
dst_reprs
,
reshaped_in_msgs
))
# TODO: clear partial messages
...
...
@@ -1087,7 +1097,7 @@ class DGLGraph(object):
# no edges to be triggered
assert
len
(
v
)
==
0
return
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
totensor
()))
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
to
user
tensor
()))
# TODO(minjie): better way to figure out `batchable` flag
if
message_func
==
"default"
:
...
...
@@ -1135,10 +1145,10 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
if
len
(
v
)
==
0
:
return
uu
,
vv
,
_
=
self
.
cached
_graph
.
in_edges
(
v
)
uu
,
vv
,
_
=
self
.
_graph
.
in_edges
(
v
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
batchable
=
batchable
)
unique_v
=
F
.
unique
(
v
.
totensor
())
unique_v
=
F
.
unique
(
v
.
to
user
tensor
())
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
def
push
(
self
,
...
...
@@ -1165,7 +1175,7 @@ class DGLGraph(object):
u
=
utils
.
toindex
(
u
)
if
len
(
u
)
==
0
:
return
uu
,
vv
,
_
=
self
.
cached
_graph
.
out_edges
(
u
)
uu
,
vv
,
_
=
self
.
_graph
.
out_edges
(
u
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
,
batchable
=
batchable
)
...
...
@@ -1309,8 +1319,10 @@ class DGLGraph(object):
reduce_func
)
def
clear_messages
(
self
):
"""Clear all messages."""
self
.
_msg_graph
.
clear
()
self
.
_msg_frame
.
clear
()
self
.
_msg_graph
.
add_nodes
(
self
.
number_of_nodes
())
def
_get_repr
(
attr_dict
):
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
...
...
python/dgl/ndarray.py
View file @
b2e4bdc0
...
...
@@ -7,6 +7,7 @@ used with C++ library.
from
__future__
import
absolute_import
as
_abs
import
ctypes
import
functools
import
operator
import
numpy
as
_np
...
...
@@ -18,7 +19,7 @@ from . import backend as F
class
NDArray
(
NDArrayBase
):
"""Lightweight NDArray class for DGL framework."""
def
__len__
(
self
):
return
reduce
(
operator
.
mul
,
self
.
shape
,
1
)
return
functools
.
reduce
(
operator
.
mul
,
self
.
shape
,
1
)
def
cpu
(
dev_id
=
0
):
"""Construct a CPU device
...
...
python/dgl/scheduler.py
View file @
b2e4bdc0
...
...
@@ -11,12 +11,12 @@ from . import utils
__all__
=
[
"degree_bucketing"
,
"get_executor"
]
def
degree_bucketing
(
cached_
graph
,
v
):
def
degree_bucketing
(
graph
,
v
):
"""Create degree bucketing scheduling policy.
Parameters
----------
cached_
graph : dgl.
cached_graph.CachedGraph
graph : dgl.
graph_index.GraphIndex
the graph
v : dgl.utils.Index
the nodes to gather messages
...
...
@@ -29,7 +29,7 @@ def degree_bucketing(cached_graph, v):
list of node id buckets; nodes belong to the same bucket have
the same degree
"""
degrees
=
F
.
asnumpy
(
cached_
graph
.
in_degrees
(
v
).
to
tensor
())
degrees
=
np
.
array
(
graph
.
in_degrees
(
v
).
to
list
())
unique_degrees
=
list
(
np
.
unique
(
degrees
))
v_np
=
np
.
array
(
v
.
tolist
())
v_bkt
=
[]
...
...
python/dgl/utils.py
View file @
b2e4bdc0
...
...
@@ -141,9 +141,9 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting
"""
if
len
(
u
)
!=
len
(
v
)
and
len
(
u
)
==
1
:
u
=
toindex
(
F
.
broadcast_to
(
u
.
totensor
(),
v
.
totensor
()))
u
=
toindex
(
F
.
broadcast_to
(
u
.
to
user
tensor
(),
v
.
to
user
tensor
()))
elif
len
(
u
)
!=
len
(
v
)
and
len
(
v
)
==
1
:
v
=
toindex
(
F
.
broadcast_to
(
v
.
totensor
(),
u
.
totensor
()))
v
=
toindex
(
F
.
broadcast_to
(
v
.
to
user
tensor
(),
u
.
to
user
tensor
()))
else
:
assert
len
(
u
)
==
len
(
v
)
return
u
,
v
...
...
@@ -205,7 +205,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id]
"""
x
=
x
.
totensor
()
x
=
x
.
to
user
tensor
()
unique_x
,
_
=
F
.
sort
(
F
.
unique
(
x
))
map_len
=
int
(
F
.
max
(
unique_x
))
+
1
old_to_new
=
F
.
zeros
(
map_len
,
dtype
=
F
.
int64
)
...
...
@@ -312,6 +312,6 @@ def reorder(dict_like, index):
"""
new_dict
=
{}
for
key
,
val
in
dict_like
.
items
():
idx_ctx
=
index
.
totensor
(
F
.
get_context
(
val
))
idx_ctx
=
index
.
to
user
tensor
(
F
.
get_context
(
val
))
new_dict
[
key
]
=
F
.
gather_row
(
val
,
idx_ctx
)
return
new_dict
src/graph/graph.cc
View file @
b2e4bdc0
...
...
@@ -31,7 +31,7 @@ void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src
_ids
->
shape
[
0
];
const
auto
dstlen
=
dst
_ids
->
shape
[
0
];
const
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src_ids
->
data
);
const
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst_ids
->
data
);
if
(
srclen
==
1
)
{
...
...
@@ -78,7 +78,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src
_ids
->
shape
[
0
];
const
auto
dstlen
=
dst
_ids
->
shape
[
0
];
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
BoolArray
rst
=
BoolArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
...
...
@@ -150,7 +150,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src
_ids
->
shape
[
0
];
const
auto
dstlen
=
dst
_ids
->
shape
[
0
];
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
IdArray
rst
=
IdArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
...
...
tests/pytorch/test_batching.py
View file @
b2e4bdc0
...
...
@@ -27,8 +27,7 @@ def apply_node_func(node):
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
g
.
add_nodes
(
10
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
...
...
@@ -198,7 +197,7 @@ def test_update_routines():
def
test_reduce_0deg
():
g
=
DGLGraph
()
g
.
add_nodes
_from
([
0
,
1
,
2
,
3
,
4
]
)
g
.
add_nodes
(
5
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
add_edge
(
3
,
0
)
...
...
@@ -218,7 +217,7 @@ def test_reduce_0deg():
def
test_pull_0deg
():
g
=
DGLGraph
()
g
.
add_nodes
_from
([
0
,
1
]
)
g
.
add_nodes
(
2
)
g
.
add_edge
(
0
,
1
)
def
_message
(
src
,
edge
):
return
src
...
...
@@ -243,16 +242,6 @@ def test_pull_0deg():
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
def
_test_delete
():
g
=
generate_graph
()
ecol
=
Variable
(
th
.
randn
(
17
,
D
),
requires_grad
=
grad
)
g
.
set_e_repr
({
'e'
:
ecol
})
assert
g
.
get_n_repr
()[
'h'
].
shape
[
0
]
==
10
assert
g
.
get_e_repr
()[
'e'
].
shape
[
0
]
==
17
g
.
remove_node
(
0
)
assert
g
.
get_n_repr
()[
'h'
].
shape
[
0
]
==
9
assert
g
.
get_e_repr
()[
'e'
].
shape
[
0
]
==
8
if
__name__
==
'__main__'
:
test_batch_setter_getter
()
test_batch_setter_autograd
()
...
...
@@ -261,4 +250,3 @@ if __name__ == '__main__':
test_update_routines
()
test_reduce_0deg
()
test_pull_0deg
()
#test_delete()
tests/pytorch/test_batching_anonymous.py
View file @
b2e4bdc0
...
...
@@ -23,8 +23,7 @@ def reduce_func(hv, msgs):
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
g
.
add_nodes
(
10
)
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
...
...
tests/pytorch/test_frame.py
View file @
b2e4bdc0
...
...
@@ -113,7 +113,7 @@ def test_append2():
assert
not
f
.
is_span_whole_column
()
assert
f
.
num_rows
==
3
*
N
new_idx
=
list
(
range
(
N
))
+
list
(
range
(
2
*
N
,
4
*
N
))
assert
check_eq
(
f
.
index
().
totensor
(),
th
.
tensor
(
new_idx
))
assert
check_eq
(
f
.
index
().
to
user
tensor
(),
th
.
tensor
(
new_idx
))
assert
data
.
num_rows
==
4
*
N
def
test_row1
():
...
...
tests/pytorch/test_function.py
View file @
b2e4bdc0
...
...
@@ -5,8 +5,7 @@ from dgl.graph import __REPR__
def
generate_graph
():
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
g
.
add_nodes
(
10
)
# 10 nodes.
h
=
th
.
arange
(
1
,
11
)
g
.
set_n_repr
({
'h'
:
h
})
# create a graph where 0 is the source and 9 is the sink
...
...
@@ -23,8 +22,7 @@ def generate_graph():
def
generate_graph1
():
"""graph with anonymous repr"""
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
)
# 10 nodes.
g
.
add_nodes
(
10
)
# 10 nodes.
h
=
th
.
arange
(
1
,
11
)
g
.
set_n_repr
(
h
)
# create a graph where 0 is the source and 9 is the sink
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment