"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1f9ae6686e7762a518172e5ec673508e3b706300"
Unverified Commit 0024c7e1 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] return batch related ids in g.idtype (#6578)

parent 7643e537
...@@ -750,7 +750,7 @@ class DGLGraph(object): ...@@ -750,7 +750,7 @@ class DGLGraph(object):
c_etype_batch_num_edges, one_hot_removed_edges, reducer="sum" c_etype_batch_num_edges, one_hot_removed_edges, reducer="sum"
) )
self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype( self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype(
batch_num_removed_edges, F.int64 batch_num_removed_edges, self.idtype
) )
sub_g = self.edge_subgraph( sub_g = self.edge_subgraph(
...@@ -890,7 +890,7 @@ class DGLGraph(object): ...@@ -890,7 +890,7 @@ class DGLGraph(object):
self._batch_num_nodes[ self._batch_num_nodes[
target_ntype target_ntype
] = c_ntype_batch_num_nodes - F.astype( ] = c_ntype_batch_num_nodes - F.astype(
batch_num_removed_nodes, F.int64 batch_num_removed_nodes, self.idtype
) )
# Record old num_edges to check later whether some edges were removed # Record old num_edges to check later whether some edges were removed
old_num_edges = { old_num_edges = {
...@@ -917,7 +917,7 @@ class DGLGraph(object): ...@@ -917,7 +917,7 @@ class DGLGraph(object):
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if self._graph.num_edges(self.get_etype_id(c_etype)) == 0: if self._graph.num_edges(self.get_etype_id(c_etype)) == 0:
self._batch_num_edges[c_etype] = F.zeros( self._batch_num_edges[c_etype] = F.zeros(
(self.batch_size,), F.int64, self.device (self.batch_size,), self.idtype, self.device
) )
continue continue
...@@ -936,7 +936,7 @@ class DGLGraph(object): ...@@ -936,7 +936,7 @@ class DGLGraph(object):
reducer="sum", reducer="sum",
) )
self._batch_num_edges[c_etype] = F.astype( self._batch_num_edges[c_etype] = F.astype(
batch_num_left_edges, F.int64 batch_num_left_edges, self.idtype
) )
if batched and not store_ids: if batched and not store_ids:
...@@ -1511,7 +1511,7 @@ class DGLGraph(object): ...@@ -1511,7 +1511,7 @@ class DGLGraph(object):
self._batch_num_nodes = {} self._batch_num_nodes = {}
for ty in self.ntypes: for ty in self.ntypes:
bnn = F.copy_to( bnn = F.copy_to(
F.tensor([self.num_nodes(ty)], F.int64), self.device F.tensor([self.num_nodes(ty)], self.idtype), self.device
) )
self._batch_num_nodes[ty] = bnn self._batch_num_nodes[ty] = bnn
if ntype is None: if ntype is None:
...@@ -1601,6 +1601,7 @@ class DGLGraph(object): ...@@ -1601,6 +1601,7 @@ class DGLGraph(object):
batch batch
unbatch unbatch
""" """
val = utils.prepare_tensor_or_dict(self, val, "batch_num_nodes")
if not isinstance(val, Mapping): if not isinstance(val, Mapping):
if len(self.ntypes) != 1: if len(self.ntypes) != 1:
raise DGLError( raise DGLError(
...@@ -1660,7 +1661,7 @@ class DGLGraph(object): ...@@ -1660,7 +1661,7 @@ class DGLGraph(object):
self._batch_num_edges = {} self._batch_num_edges = {}
for ty in self.canonical_etypes: for ty in self.canonical_etypes:
bne = F.copy_to( bne = F.copy_to(
F.tensor([self.num_edges(ty)], F.int64), self.device F.tensor([self.num_edges(ty)], self.idtype), self.device
) )
self._batch_num_edges[ty] = bne self._batch_num_edges[ty] = bne
if etype is None: if etype is None:
...@@ -1752,6 +1753,7 @@ class DGLGraph(object): ...@@ -1752,6 +1753,7 @@ class DGLGraph(object):
batch batch
unbatch unbatch
""" """
val = utils.prepare_tensor_or_dict(self, val, "batch_num_edges")
if not isinstance(val, Mapping): if not isinstance(val, Mapping):
if len(self.etypes) != 1: if len(self.etypes) != 1:
raise DGLError( raise DGLError(
......
...@@ -1608,21 +1608,21 @@ def test_remove_edges(idtype): ...@@ -1608,21 +1608,21 @@ def test_remove_edges(idtype):
assert bg.batch_size == bg_r.batch_size assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes()) assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=F.int64) bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=idtype)
) )
bg_r = dgl.remove_edges(bg, [0, 2]) bg_r = dgl.remove_edges(bg, [0, 2])
assert bg.batch_size == bg_r.batch_size assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes()) assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64) bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=idtype)
) )
bg_r = dgl.remove_edges(bg, F.tensor([0, 2], dtype=idtype)) bg_r = dgl.remove_edges(bg, F.tensor([0, 2], dtype=idtype))
assert bg.batch_size == bg_r.batch_size assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes()) assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64) bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=idtype)
) )
# batched heterogeneous graph # batched heterogeneous graph
...@@ -1659,7 +1659,7 @@ def test_remove_edges(idtype): ...@@ -1659,7 +1659,7 @@ def test_remove_edges(idtype):
for nty in ntypes: for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([1, 2, 0], dtype=F.int64) bg_r.batch_num_edges("follows"), F.tensor([1, 2, 0], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), bg.batch_num_edges("plays") bg_r.batch_num_edges("plays"), bg.batch_num_edges("plays")
...@@ -1673,7 +1673,7 @@ def test_remove_edges(idtype): ...@@ -1673,7 +1673,7 @@ def test_remove_edges(idtype):
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows") bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_edges(bg, [0, 1, 3], etype="follows") bg_r = dgl.remove_edges(bg, [0, 1, 3], etype="follows")
...@@ -1681,7 +1681,7 @@ def test_remove_edges(idtype): ...@@ -1681,7 +1681,7 @@ def test_remove_edges(idtype):
for nty in ntypes: for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64) bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_edges("plays"), bg_r.batch_num_edges("plays") bg.batch_num_edges("plays"), bg_r.batch_num_edges("plays")
...@@ -1695,7 +1695,7 @@ def test_remove_edges(idtype): ...@@ -1695,7 +1695,7 @@ def test_remove_edges(idtype):
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows") bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_edges( bg_r = dgl.remove_edges(
...@@ -1705,7 +1705,7 @@ def test_remove_edges(idtype): ...@@ -1705,7 +1705,7 @@ def test_remove_edges(idtype):
for nty in ntypes: for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty)) assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64) bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_edges("plays"), bg_r.batch_num_edges("plays") bg.batch_num_edges("plays"), bg_r.batch_num_edges("plays")
...@@ -1719,7 +1719,7 @@ def test_remove_edges(idtype): ...@@ -1719,7 +1719,7 @@ def test_remove_edges(idtype):
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows") bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
) )
...@@ -1847,28 +1847,28 @@ def test_remove_nodes(idtype): ...@@ -1847,28 +1847,28 @@ def test_remove_nodes(idtype):
bg_r = dgl.remove_nodes(bg, 1) bg_r = dgl.remove_nodes(bg, 1)
assert bg_r.batch_size == bg.batch_size assert bg_r.batch_size == bg.batch_size
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=F.int64) bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=F.int64) bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=idtype)
) )
bg_r = dgl.remove_nodes(bg, [1, 7]) bg_r = dgl.remove_nodes(bg, [1, 7])
assert bg_r.batch_size == bg.batch_size assert bg_r.batch_size == bg.batch_size
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64) bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64) bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_nodes(bg, F.tensor([1, 7], dtype=idtype)) bg_r = dgl.remove_nodes(bg, F.tensor([1, 7], dtype=idtype))
assert bg_r.batch_size == bg.batch_size assert bg_r.batch_size == bg.batch_size
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64) bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64) bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=idtype)
) )
# batched heterogeneous graph # batched heterogeneous graph
...@@ -1902,16 +1902,16 @@ def test_remove_nodes(idtype): ...@@ -1902,16 +1902,16 @@ def test_remove_nodes(idtype):
bg_r = dgl.remove_nodes(bg, 1, ntype="user") bg_r = dgl.remove_nodes(bg, 1, ntype="user")
assert bg_r.batch_size == bg.batch_size assert bg_r.batch_size == bg.batch_size
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes("user"), F.tensor([3, 6, 3], dtype=F.int64) bg_r.batch_num_nodes("user"), F.tensor([3, 6, 3], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game") bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 2, 0], dtype=F.int64) bg_r.batch_num_edges("follows"), F.tensor([0, 2, 0], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 2], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 2], dtype=idtype)
) )
bg_r = dgl.remove_nodes(bg, 6, ntype="game") bg_r = dgl.remove_nodes(bg, 6, ntype="game")
...@@ -1920,28 +1920,28 @@ def test_remove_nodes(idtype): ...@@ -1920,28 +1920,28 @@ def test_remove_nodes(idtype):
bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user") bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes("game"), F.tensor([3, 2, 2], dtype=F.int64) bg_r.batch_num_nodes("game"), F.tensor([3, 2, 2], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows") bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_nodes(bg, [1, 5, 6, 11], ntype="user") bg_r = dgl.remove_nodes(bg, [1, 5, 6, 11], ntype="user")
assert bg_r.batch_size == bg.batch_size assert bg_r.batch_size == bg.batch_size
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=F.int64) bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game") bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64) bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_nodes(bg, [0, 3, 4, 7], ntype="game") bg_r = dgl.remove_nodes(bg, [0, 3, 4, 7], ntype="game")
...@@ -1950,13 +1950,13 @@ def test_remove_nodes(idtype): ...@@ -1950,13 +1950,13 @@ def test_remove_nodes(idtype):
bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user") bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=F.int64) bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows") bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_nodes( bg_r = dgl.remove_nodes(
...@@ -1964,16 +1964,16 @@ def test_remove_nodes(idtype): ...@@ -1964,16 +1964,16 @@ def test_remove_nodes(idtype):
) )
assert bg_r.batch_size == bg.batch_size assert bg_r.batch_size == bg.batch_size
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=F.int64) bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game") bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64) bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
) )
bg_r = dgl.remove_nodes( bg_r = dgl.remove_nodes(
...@@ -1984,13 +1984,13 @@ def test_remove_nodes(idtype): ...@@ -1984,13 +1984,13 @@ def test_remove_nodes(idtype):
bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user") bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=F.int64) bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=idtype)
) )
assert F.array_equal( assert F.array_equal(
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows") bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
) )
assert F.array_equal( assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64) bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
) )
...@@ -2247,13 +2247,13 @@ def test_remove_selfloop(idtype): ...@@ -2247,13 +2247,13 @@ def test_remove_selfloop(idtype):
idtype=idtype, idtype=idtype,
device=F.ctx(), device=F.ctx(),
) )
g.set_batch_num_nodes(F.tensor([3, 2], dtype=F.int64)) g.set_batch_num_nodes([3, 2])
g.set_batch_num_edges(F.tensor([4, 3], dtype=F.int64)) g.set_batch_num_edges([4, 3])
g = dgl.remove_self_loop(g) g = dgl.remove_self_loop(g)
assert g.num_nodes() == 5 assert g.num_nodes() == 5
assert g.num_edges() == 3 assert g.num_edges() == 3
assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=F.int64)) assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=idtype))
assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], dtype=F.int64)) assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], dtype=idtype))
@parametrize_idtype @parametrize_idtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment