Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0024c7e1
Unverified
Commit
0024c7e1
authored
Nov 17, 2023
by
Rhett Ying
Committed by
GitHub
Nov 17, 2023
Browse files
[BugFix] return batch related ids in g.idtype (#6578)
parent
7643e537
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
40 deletions
+42
-40
python/dgl/heterograph.py
python/dgl/heterograph.py
+8
-6
tests/python/common/transforms/test_transform.py
tests/python/common/transforms/test_transform.py
+34
-34
No files found.
python/dgl/heterograph.py
View file @
0024c7e1
...
@@ -750,7 +750,7 @@ class DGLGraph(object):
...
@@ -750,7 +750,7 @@ class DGLGraph(object):
c_etype_batch_num_edges
,
one_hot_removed_edges
,
reducer
=
"sum"
c_etype_batch_num_edges
,
one_hot_removed_edges
,
reducer
=
"sum"
)
)
self
.
_batch_num_edges
[
c_etype
]
=
c_etype_batch_num_edges
-
F
.
astype
(
self
.
_batch_num_edges
[
c_etype
]
=
c_etype_batch_num_edges
-
F
.
astype
(
batch_num_removed_edges
,
F
.
int64
batch_num_removed_edges
,
self
.
idtype
)
)
sub_g
=
self
.
edge_subgraph
(
sub_g
=
self
.
edge_subgraph
(
...
@@ -890,7 +890,7 @@ class DGLGraph(object):
...
@@ -890,7 +890,7 @@ class DGLGraph(object):
self
.
_batch_num_nodes
[
self
.
_batch_num_nodes
[
target_ntype
target_ntype
]
=
c_ntype_batch_num_nodes
-
F
.
astype
(
]
=
c_ntype_batch_num_nodes
-
F
.
astype
(
batch_num_removed_nodes
,
F
.
int64
batch_num_removed_nodes
,
self
.
idtype
)
)
# Record old num_edges to check later whether some edges were removed
# Record old num_edges to check later whether some edges were removed
old_num_edges
=
{
old_num_edges
=
{
...
@@ -917,7 +917,7 @@ class DGLGraph(object):
...
@@ -917,7 +917,7 @@ class DGLGraph(object):
for
c_etype
in
canonical_etypes
:
for
c_etype
in
canonical_etypes
:
if
self
.
_graph
.
num_edges
(
self
.
get_etype_id
(
c_etype
))
==
0
:
if
self
.
_graph
.
num_edges
(
self
.
get_etype_id
(
c_etype
))
==
0
:
self
.
_batch_num_edges
[
c_etype
]
=
F
.
zeros
(
self
.
_batch_num_edges
[
c_etype
]
=
F
.
zeros
(
(
self
.
batch_size
,),
F
.
int64
,
self
.
device
(
self
.
batch_size
,),
self
.
idtype
,
self
.
device
)
)
continue
continue
...
@@ -936,7 +936,7 @@ class DGLGraph(object):
...
@@ -936,7 +936,7 @@ class DGLGraph(object):
reducer
=
"sum"
,
reducer
=
"sum"
,
)
)
self
.
_batch_num_edges
[
c_etype
]
=
F
.
astype
(
self
.
_batch_num_edges
[
c_etype
]
=
F
.
astype
(
batch_num_left_edges
,
F
.
int64
batch_num_left_edges
,
self
.
idtype
)
)
if
batched
and
not
store_ids
:
if
batched
and
not
store_ids
:
...
@@ -1511,7 +1511,7 @@ class DGLGraph(object):
...
@@ -1511,7 +1511,7 @@ class DGLGraph(object):
self
.
_batch_num_nodes
=
{}
self
.
_batch_num_nodes
=
{}
for
ty
in
self
.
ntypes
:
for
ty
in
self
.
ntypes
:
bnn
=
F
.
copy_to
(
bnn
=
F
.
copy_to
(
F
.
tensor
([
self
.
num_nodes
(
ty
)],
F
.
int64
),
self
.
device
F
.
tensor
([
self
.
num_nodes
(
ty
)],
self
.
idtype
),
self
.
device
)
)
self
.
_batch_num_nodes
[
ty
]
=
bnn
self
.
_batch_num_nodes
[
ty
]
=
bnn
if
ntype
is
None
:
if
ntype
is
None
:
...
@@ -1601,6 +1601,7 @@ class DGLGraph(object):
...
@@ -1601,6 +1601,7 @@ class DGLGraph(object):
batch
batch
unbatch
unbatch
"""
"""
val
=
utils
.
prepare_tensor_or_dict
(
self
,
val
,
"batch_num_nodes"
)
if
not
isinstance
(
val
,
Mapping
):
if
not
isinstance
(
val
,
Mapping
):
if
len
(
self
.
ntypes
)
!=
1
:
if
len
(
self
.
ntypes
)
!=
1
:
raise
DGLError
(
raise
DGLError
(
...
@@ -1660,7 +1661,7 @@ class DGLGraph(object):
...
@@ -1660,7 +1661,7 @@ class DGLGraph(object):
self
.
_batch_num_edges
=
{}
self
.
_batch_num_edges
=
{}
for
ty
in
self
.
canonical_etypes
:
for
ty
in
self
.
canonical_etypes
:
bne
=
F
.
copy_to
(
bne
=
F
.
copy_to
(
F
.
tensor
([
self
.
num_edges
(
ty
)],
F
.
int64
),
self
.
device
F
.
tensor
([
self
.
num_edges
(
ty
)],
self
.
idtype
),
self
.
device
)
)
self
.
_batch_num_edges
[
ty
]
=
bne
self
.
_batch_num_edges
[
ty
]
=
bne
if
etype
is
None
:
if
etype
is
None
:
...
@@ -1752,6 +1753,7 @@ class DGLGraph(object):
...
@@ -1752,6 +1753,7 @@ class DGLGraph(object):
batch
batch
unbatch
unbatch
"""
"""
val
=
utils
.
prepare_tensor_or_dict
(
self
,
val
,
"batch_num_edges"
)
if
not
isinstance
(
val
,
Mapping
):
if
not
isinstance
(
val
,
Mapping
):
if
len
(
self
.
etypes
)
!=
1
:
if
len
(
self
.
etypes
)
!=
1
:
raise
DGLError
(
raise
DGLError
(
...
...
tests/python/common/transforms/test_transform.py
View file @
0024c7e1
...
@@ -1608,21 +1608,21 @@ def test_remove_edges(idtype):
...
@@ -1608,21 +1608,21 @@ def test_remove_edges(idtype):
assert
bg
.
batch_size
==
bg_r
.
batch_size
assert
bg
.
batch_size
==
bg_r
.
batch_size
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(),
bg_r
.
batch_num_nodes
())
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(),
bg_r
.
batch_num_nodes
())
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(),
F
.
tensor
([
2
,
0
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(),
F
.
tensor
([
2
,
0
,
2
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_edges
(
bg
,
[
0
,
2
])
bg_r
=
dgl
.
remove_edges
(
bg
,
[
0
,
2
])
assert
bg
.
batch_size
==
bg_r
.
batch_size
assert
bg
.
batch_size
==
bg_r
.
batch_size
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(),
bg_r
.
batch_num_nodes
())
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(),
bg_r
.
batch_num_nodes
())
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(),
F
.
tensor
([
1
,
0
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(),
F
.
tensor
([
1
,
0
,
2
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_edges
(
bg
,
F
.
tensor
([
0
,
2
],
dtype
=
idtype
))
bg_r
=
dgl
.
remove_edges
(
bg
,
F
.
tensor
([
0
,
2
],
dtype
=
idtype
))
assert
bg
.
batch_size
==
bg_r
.
batch_size
assert
bg
.
batch_size
==
bg_r
.
batch_size
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(),
bg_r
.
batch_num_nodes
())
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(),
bg_r
.
batch_num_nodes
())
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(),
F
.
tensor
([
1
,
0
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(),
F
.
tensor
([
1
,
0
,
2
],
dtype
=
idtype
)
)
)
# batched heterogeneous graph
# batched heterogeneous graph
...
@@ -1659,7 +1659,7 @@ def test_remove_edges(idtype):
...
@@ -1659,7 +1659,7 @@ def test_remove_edges(idtype):
for
nty
in
ntypes
:
for
nty
in
ntypes
:
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
nty
),
bg_r
.
batch_num_nodes
(
nty
))
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
nty
),
bg_r
.
batch_num_nodes
(
nty
))
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
1
,
2
,
0
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
1
,
2
,
0
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
bg
.
batch_num_edges
(
"plays"
)
bg_r
.
batch_num_edges
(
"plays"
),
bg
.
batch_num_edges
(
"plays"
)
...
@@ -1673,7 +1673,7 @@ def test_remove_edges(idtype):
...
@@ -1673,7 +1673,7 @@ def test_remove_edges(idtype):
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
2
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
2
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_edges
(
bg
,
[
0
,
1
,
3
],
etype
=
"follows"
)
bg_r
=
dgl
.
remove_edges
(
bg
,
[
0
,
1
,
3
],
etype
=
"follows"
)
...
@@ -1681,7 +1681,7 @@ def test_remove_edges(idtype):
...
@@ -1681,7 +1681,7 @@ def test_remove_edges(idtype):
for
nty
in
ntypes
:
for
nty
in
ntypes
:
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
nty
),
bg_r
.
batch_num_nodes
(
nty
))
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
nty
),
bg_r
.
batch_num_nodes
(
nty
))
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_edges
(
"plays"
),
bg_r
.
batch_num_edges
(
"plays"
)
bg
.
batch_num_edges
(
"plays"
),
bg_r
.
batch_num_edges
(
"plays"
)
...
@@ -1695,7 +1695,7 @@ def test_remove_edges(idtype):
...
@@ -1695,7 +1695,7 @@ def test_remove_edges(idtype):
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_edges
(
bg_r
=
dgl
.
remove_edges
(
...
@@ -1705,7 +1705,7 @@ def test_remove_edges(idtype):
...
@@ -1705,7 +1705,7 @@ def test_remove_edges(idtype):
for
nty
in
ntypes
:
for
nty
in
ntypes
:
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
nty
),
bg_r
.
batch_num_nodes
(
nty
))
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
nty
),
bg_r
.
batch_num_nodes
(
nty
))
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_edges
(
"plays"
),
bg_r
.
batch_num_edges
(
"plays"
)
bg
.
batch_num_edges
(
"plays"
),
bg_r
.
batch_num_edges
(
"plays"
)
...
@@ -1719,7 +1719,7 @@ def test_remove_edges(idtype):
...
@@ -1719,7 +1719,7 @@ def test_remove_edges(idtype):
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
idtype
)
)
)
...
@@ -1847,28 +1847,28 @@ def test_remove_nodes(idtype):
...
@@ -1847,28 +1847,28 @@ def test_remove_nodes(idtype):
bg_r
=
dgl
.
remove_nodes
(
bg
,
1
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
1
)
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(),
F
.
tensor
([
4
,
0
,
5
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(),
F
.
tensor
([
4
,
0
,
5
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(),
F
.
tensor
([
0
,
0
,
3
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(),
F
.
tensor
([
0
,
0
,
3
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
[
1
,
7
])
bg_r
=
dgl
.
remove_nodes
(
bg
,
[
1
,
7
])
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(),
F
.
tensor
([
4
,
0
,
4
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(),
F
.
tensor
([
4
,
0
,
4
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(),
F
.
tensor
([
0
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(),
F
.
tensor
([
0
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
F
.
tensor
([
1
,
7
],
dtype
=
idtype
))
bg_r
=
dgl
.
remove_nodes
(
bg
,
F
.
tensor
([
1
,
7
],
dtype
=
idtype
))
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(),
F
.
tensor
([
4
,
0
,
4
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(),
F
.
tensor
([
4
,
0
,
4
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(),
F
.
tensor
([
0
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(),
F
.
tensor
([
0
,
0
,
1
],
dtype
=
idtype
)
)
)
# batched heterogeneous graph
# batched heterogeneous graph
...
@@ -1902,16 +1902,16 @@ def test_remove_nodes(idtype):
...
@@ -1902,16 +1902,16 @@ def test_remove_nodes(idtype):
bg_r
=
dgl
.
remove_nodes
(
bg
,
1
,
ntype
=
"user"
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
1
,
ntype
=
"user"
)
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(
"user"
),
F
.
tensor
([
3
,
6
,
3
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(
"user"
),
F
.
tensor
([
3
,
6
,
3
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
"game"
),
bg_r
.
batch_num_nodes
(
"game"
)
bg
.
batch_num_nodes
(
"game"
),
bg_r
.
batch_num_nodes
(
"game"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
2
,
0
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
2
,
0
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
2
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
6
,
ntype
=
"game"
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
6
,
ntype
=
"game"
)
...
@@ -1920,28 +1920,28 @@ def test_remove_nodes(idtype):
...
@@ -1920,28 +1920,28 @@ def test_remove_nodes(idtype):
bg
.
batch_num_nodes
(
"user"
),
bg_r
.
batch_num_nodes
(
"user"
)
bg
.
batch_num_nodes
(
"user"
),
bg_r
.
batch_num_nodes
(
"user"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(
"game"
),
F
.
tensor
([
3
,
2
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(
"game"
),
F
.
tensor
([
3
,
2
,
2
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
2
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
2
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
[
1
,
5
,
6
,
11
],
ntype
=
"user"
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
[
1
,
5
,
6
,
11
],
ntype
=
"user"
)
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(
"user"
),
F
.
tensor
([
3
,
4
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(
"user"
),
F
.
tensor
([
3
,
4
,
2
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
"game"
),
bg_r
.
batch_num_nodes
(
"game"
)
bg
.
batch_num_nodes
(
"game"
),
bg_r
.
batch_num_nodes
(
"game"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
[
0
,
3
,
4
,
7
],
ntype
=
"game"
)
bg_r
=
dgl
.
remove_nodes
(
bg
,
[
0
,
3
,
4
,
7
],
ntype
=
"game"
)
...
@@ -1950,13 +1950,13 @@ def test_remove_nodes(idtype):
...
@@ -1950,13 +1950,13 @@ def test_remove_nodes(idtype):
bg
.
batch_num_nodes
(
"user"
),
bg_r
.
batch_num_nodes
(
"user"
)
bg
.
batch_num_nodes
(
"user"
),
bg_r
.
batch_num_nodes
(
"user"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(
"game"
),
F
.
tensor
([
2
,
0
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(
"game"
),
F
.
tensor
([
2
,
0
,
2
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg_r
=
dgl
.
remove_nodes
(
...
@@ -1964,16 +1964,16 @@ def test_remove_nodes(idtype):
...
@@ -1964,16 +1964,16 @@ def test_remove_nodes(idtype):
)
)
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
bg_r
.
batch_size
==
bg
.
batch_size
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(
"user"
),
F
.
tensor
([
3
,
4
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(
"user"
),
F
.
tensor
([
3
,
4
,
2
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_nodes
(
"game"
),
bg_r
.
batch_num_nodes
(
"game"
)
bg
.
batch_num_nodes
(
"game"
),
bg_r
.
batch_num_nodes
(
"game"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"follows"
),
F
.
tensor
([
0
,
1
,
0
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
idtype
)
)
)
bg_r
=
dgl
.
remove_nodes
(
bg_r
=
dgl
.
remove_nodes
(
...
@@ -1984,13 +1984,13 @@ def test_remove_nodes(idtype):
...
@@ -1984,13 +1984,13 @@ def test_remove_nodes(idtype):
bg
.
batch_num_nodes
(
"user"
),
bg_r
.
batch_num_nodes
(
"user"
)
bg
.
batch_num_nodes
(
"user"
),
bg_r
.
batch_num_nodes
(
"user"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_nodes
(
"game"
),
F
.
tensor
([
2
,
0
,
2
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_nodes
(
"game"
),
F
.
tensor
([
2
,
0
,
2
],
dtype
=
idtype
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
bg
.
batch_num_edges
(
"follows"
),
bg_r
.
batch_num_edges
(
"follows"
)
)
)
assert
F
.
array_equal
(
assert
F
.
array_equal
(
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
F
.
int64
)
bg_r
.
batch_num_edges
(
"plays"
),
F
.
tensor
([
1
,
0
,
1
],
dtype
=
idtype
)
)
)
...
@@ -2247,13 +2247,13 @@ def test_remove_selfloop(idtype):
...
@@ -2247,13 +2247,13 @@ def test_remove_selfloop(idtype):
idtype
=
idtype
,
idtype
=
idtype
,
device
=
F
.
ctx
(),
device
=
F
.
ctx
(),
)
)
g
.
set_batch_num_nodes
(
F
.
tensor
([
3
,
2
],
dtype
=
F
.
int64
)
)
g
.
set_batch_num_nodes
(
[
3
,
2
]
)
g
.
set_batch_num_edges
(
F
.
tensor
([
4
,
3
],
dtype
=
F
.
int64
)
)
g
.
set_batch_num_edges
(
[
4
,
3
]
)
g
=
dgl
.
remove_self_loop
(
g
)
g
=
dgl
.
remove_self_loop
(
g
)
assert
g
.
num_nodes
()
==
5
assert
g
.
num_nodes
()
==
5
assert
g
.
num_edges
()
==
3
assert
g
.
num_edges
()
==
3
assert
F
.
array_equal
(
g
.
batch_num_nodes
(),
F
.
tensor
([
3
,
2
],
dtype
=
F
.
int64
))
assert
F
.
array_equal
(
g
.
batch_num_nodes
(),
F
.
tensor
([
3
,
2
],
dtype
=
idtype
))
assert
F
.
array_equal
(
g
.
batch_num_edges
(),
F
.
tensor
([
2
,
1
],
dtype
=
F
.
int64
))
assert
F
.
array_equal
(
g
.
batch_num_edges
(),
F
.
tensor
([
2
,
1
],
dtype
=
idtype
))
@
parametrize_idtype
@
parametrize_idtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment