Commit 98adca68 (unverified), authored by Mufei Li, committed by GitHub
Browse files

[Transform] Fix graph structure corruption with transform (#4753)



* Update

* Update

* Update

* Update

* Update test_transform.py
Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-26.ap-northeast-1.compute.internal>
parent ea4d9e83
...@@ -1397,6 +1397,7 @@ class NodeShuffle(BaseTransform): ...@@ -1397,6 +1397,7 @@ class NodeShuffle(BaseTransform):
[ 7., 8.]]) [ 7., 8.]])
""" """
def __call__(self, g): def __call__(self, g):
g = g.clone()
for ntype in g.ntypes: for ntype in g.ntypes:
nids = F.astype(g.nodes(ntype), F.int64) nids = F.astype(g.nodes(ntype), F.int64)
perm = F.rand_shuffle(nids) perm = F.rand_shuffle(nids)
...@@ -1440,6 +1441,8 @@ class DropNode(BaseTransform): ...@@ -1440,6 +1441,8 @@ class DropNode(BaseTransform):
self.dist = Bernoulli(p) self.dist = Bernoulli(p)
def __call__(self, g): def __call__(self, g):
g = g.clone()
# Fast path # Fast path
if self.p == 0: if self.p == 0:
return g return g
...@@ -1485,6 +1488,8 @@ class DropEdge(BaseTransform): ...@@ -1485,6 +1488,8 @@ class DropEdge(BaseTransform):
self.dist = Bernoulli(p) self.dist = Bernoulli(p)
def __call__(self, g): def __call__(self, g):
g = g.clone()
# Fast path # Fast path
if self.p == 0: if self.p == 0:
return g return g
...@@ -1526,6 +1531,7 @@ class AddEdge(BaseTransform): ...@@ -1526,6 +1531,7 @@ class AddEdge(BaseTransform):
device = g.device device = g.device
idtype = g.idtype idtype = g.idtype
g = g.clone()
for c_etype in g.canonical_etypes: for c_etype in g.canonical_etypes:
utype, _, vtype = c_etype utype, _, vtype = c_etype
num_edges_to_add = int(g.num_edges(c_etype) * self.ratio) num_edges_to_add = int(g.num_edges(c_etype) * self.ratio)
......
...@@ -2379,13 +2379,18 @@ def test_module_gdc(idtype): ...@@ -2379,13 +2379,18 @@ def test_module_gdc(idtype):
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)} assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)}
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support a slicing operation")
@parametrize_idtype @parametrize_idtype
def test_module_node_shuffle(idtype): def test_module_node_shuffle(idtype):
transform = dgl.NodeShuffle() transform = dgl.NodeShuffle()
g = dgl.heterograph({ g = dgl.heterograph({
('A', 'r', 'B'): ([0, 1], [1, 2]), ('A', 'r', 'B'): ([0, 1], [1, 2]),
}, idtype=idtype, device=F.ctx()) }, idtype=idtype, device=F.ctx())
g.nodes['B'].data['h'] = F.randn((g.num_nodes('B'), 2))
old_nfeat = g.nodes['B'].data['h']
new_g = transform(g) new_g = transform(g)
new_nfeat = g.nodes['B'].data['h']
assert F.allclose(old_nfeat, new_nfeat)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_idtype @parametrize_idtype
...@@ -2394,11 +2399,15 @@ def test_module_drop_node(idtype): ...@@ -2394,11 +2399,15 @@ def test_module_drop_node(idtype):
g = dgl.heterograph({ g = dgl.heterograph({
('A', 'r', 'B'): ([0, 1], [1, 2]), ('A', 'r', 'B'): ([0, 1], [1, 2]),
}, idtype=idtype, device=F.ctx()) }, idtype=idtype, device=F.ctx())
num_nodes_old = g.num_nodes()
new_g = transform(g) new_g = transform(g)
assert new_g.idtype == g.idtype assert new_g.idtype == g.idtype
assert new_g.device == g.device assert new_g.device == g.device
assert new_g.ntypes == g.ntypes assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes assert new_g.canonical_etypes == g.canonical_etypes
num_nodes_new = g.num_nodes()
# Ensure that the original graph is not corrupted
assert num_nodes_old == num_nodes_new
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_idtype @parametrize_idtype
...@@ -2408,11 +2417,15 @@ def test_module_drop_edge(idtype): ...@@ -2408,11 +2417,15 @@ def test_module_drop_edge(idtype):
('A', 'r1', 'B'): ([0, 1], [1, 2]), ('A', 'r1', 'B'): ([0, 1], [1, 2]),
('C', 'r2', 'C'): ([3, 4, 5], [6, 7, 8]) ('C', 'r2', 'C'): ([3, 4, 5], [6, 7, 8])
}, idtype=idtype, device=F.ctx()) }, idtype=idtype, device=F.ctx())
num_edges_old = g.num_edges()
new_g = transform(g) new_g = transform(g)
assert new_g.idtype == g.idtype assert new_g.idtype == g.idtype
assert new_g.device == g.device assert new_g.device == g.device
assert new_g.ntypes == g.ntypes assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes assert new_g.canonical_etypes == g.canonical_etypes
num_edges_new = g.num_edges()
# Ensure that the original graph is not corrupted
assert num_edges_old == num_edges_new
@parametrize_idtype @parametrize_idtype
def test_module_add_edge(idtype): def test_module_add_edge(idtype):
...@@ -2421,6 +2434,7 @@ def test_module_add_edge(idtype): ...@@ -2421,6 +2434,7 @@ def test_module_add_edge(idtype):
('A', 'r1', 'B'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]), ('A', 'r1', 'B'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]),
('C', 'r2', 'C'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]) ('C', 'r2', 'C'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5])
}, idtype=idtype, device=F.ctx()) }, idtype=idtype, device=F.ctx())
num_edges_old = g.num_edges()
new_g = transform(g) new_g = transform(g)
assert new_g.num_edges(('A', 'r1', 'B')) == 6 assert new_g.num_edges(('A', 'r1', 'B')) == 6
assert new_g.num_edges(('C', 'r2', 'C')) == 6 assert new_g.num_edges(('C', 'r2', 'C')) == 6
...@@ -2428,6 +2442,9 @@ def test_module_add_edge(idtype): ...@@ -2428,6 +2442,9 @@ def test_module_add_edge(idtype):
assert new_g.device == g.device assert new_g.device == g.device
assert new_g.ntypes == g.ntypes assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes assert new_g.canonical_etypes == g.canonical_etypes
num_edges_new = g.num_edges()
# Ensure that the original graph is not corrupted
assert num_edges_old == num_edges_new
@parametrize_idtype @parametrize_idtype
def test_module_random_walk_pe(idtype): def test_module_random_walk_pe(idtype):
...@@ -2483,7 +2500,7 @@ def test_module_laplacian_pe(idtype): ...@@ -2483,7 +2500,7 @@ def test_module_laplacian_pe(idtype):
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'])) @pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature']))
def test_module_sign(g): def test_module_sign(g):
import torch import torch
atol = 1e-06 atol = 1e-06
ctx = F.ctx() ctx = F.ctx()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment