"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7de2e51b5ec9f21685df56be42e41b5b3e6938a8"
Unverified Commit 2446f2fd authored by caojy1998's avatar caojy1998 Committed by GitHub
Browse files

[Test] Add testcase for function AddSelfLoop (#5858)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-6-31.ap-northeast-1.compute.internal>
parent b35f15b7
...@@ -2427,7 +2427,7 @@ def test_module_add_self_loop(idtype): ...@@ -2427,7 +2427,7 @@ def test_module_add_self_loop(idtype):
assert "h" in new_g.ndata assert "h" in new_g.ndata
assert "w" in new_g.edata assert "w" in new_g.edata
# Case2: Remove self-loops first to avoid duplicate ones # Case2: remove self-loops first to avoid duplicate ones
transform = dgl.AddSelfLoop(allow_duplicate=True) transform = dgl.AddSelfLoop(allow_duplicate=True)
new_g = transform(g) new_g = transform(g)
assert new_g.device == g.device assert new_g.device == g.device
...@@ -2440,6 +2440,17 @@ def test_module_add_self_loop(idtype): ...@@ -2440,6 +2440,17 @@ def test_module_add_self_loop(idtype):
assert "h" in new_g.ndata assert "h" in new_g.ndata
assert "w" in new_g.edata assert "w" in new_g.edata
# Case3: add self-loops for a homogeneous graph (the example in doc)
transform = dgl.AddSelfLoop(fill_data="sum")
g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx())
new_g = transform(g)
assert new_g.device == g.device
assert new_g.idtype == g.idtype
assert new_g.num_nodes() == g.num_nodes()
src, dst = new_g.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 2), (0, 1), (2, 0), (0, 0), (1, 1), (2, 2)}
# Create a heterogeneous graph # Create a heterogeneous graph
g = dgl.heterograph( g = dgl.heterograph(
{ {
...@@ -2454,7 +2465,7 @@ def test_module_add_self_loop(idtype): ...@@ -2454,7 +2465,7 @@ def test_module_add_self_loop(idtype):
g.nodes["game"].data["h2"] = F.randn((2, 4)) g.nodes["game"].data["h2"] = F.randn((2, 4))
g.edges["follows"].data["w2"] = F.randn((1, 5)) g.edges["follows"].data["w2"] = F.randn((1, 5))
# Case3: add self-loops for a heterogeneous graph # Case4: add self-loops for a heterogeneous graph
new_g = transform(g) new_g = transform(g)
assert new_g.device == g.device assert new_g.device == g.device
assert new_g.idtype == g.idtype assert new_g.idtype == g.idtype
...@@ -2469,7 +2480,7 @@ def test_module_add_self_loop(idtype): ...@@ -2469,7 +2480,7 @@ def test_module_add_self_loop(idtype):
assert "w1" in new_g.edges["plays"].data assert "w1" in new_g.edges["plays"].data
assert "w2" in new_g.edges["follows"].data assert "w2" in new_g.edges["follows"].data
# Case4: add self-etypes for a heterogeneous graph # Case5: add self-etypes for a heterogeneous graph
transform = dgl.AddSelfLoop(new_etypes=True) transform = dgl.AddSelfLoop(new_etypes=True)
new_g = transform(g) new_g = transform(g)
assert new_g.device == g.device assert new_g.device == g.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment