"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "061163142ded92425ef9d6aafabe34c6416806e1"
Unverified Commit 98adca68 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Transform] Fix graph structure corruption with transform (#4753)



* Update

* Update

* Update

* Update

* Update test_transform.py
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-9-26.ap-northeast-1.compute.internal>
parent ea4d9e83
......@@ -1397,6 +1397,7 @@ class NodeShuffle(BaseTransform):
[ 7., 8.]])
"""
def __call__(self, g):
g = g.clone()
for ntype in g.ntypes:
nids = F.astype(g.nodes(ntype), F.int64)
perm = F.rand_shuffle(nids)
......@@ -1440,6 +1441,8 @@ class DropNode(BaseTransform):
self.dist = Bernoulli(p)
def __call__(self, g):
g = g.clone()
# Fast path
if self.p == 0:
return g
......@@ -1485,6 +1488,8 @@ class DropEdge(BaseTransform):
self.dist = Bernoulli(p)
def __call__(self, g):
g = g.clone()
# Fast path
if self.p == 0:
return g
......@@ -1526,6 +1531,7 @@ class AddEdge(BaseTransform):
device = g.device
idtype = g.idtype
g = g.clone()
for c_etype in g.canonical_etypes:
utype, _, vtype = c_etype
num_edges_to_add = int(g.num_edges(c_etype) * self.ratio)
......
......@@ -2379,13 +2379,18 @@ def test_module_gdc(idtype):
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)}
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support a slicing operation")
@parametrize_idtype
def test_module_node_shuffle(idtype):
transform = dgl.NodeShuffle()
g = dgl.heterograph({
('A', 'r', 'B'): ([0, 1], [1, 2]),
}, idtype=idtype, device=F.ctx())
g.nodes['B'].data['h'] = F.randn((g.num_nodes('B'), 2))
old_nfeat = g.nodes['B'].data['h']
new_g = transform(g)
new_nfeat = g.nodes['B'].data['h']
assert F.allclose(old_nfeat, new_nfeat)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_idtype
......@@ -2394,11 +2399,15 @@ def test_module_drop_node(idtype):
g = dgl.heterograph({
('A', 'r', 'B'): ([0, 1], [1, 2]),
}, idtype=idtype, device=F.ctx())
num_nodes_old = g.num_nodes()
new_g = transform(g)
assert new_g.idtype == g.idtype
assert new_g.device == g.device
assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes
num_nodes_new = g.num_nodes()
# Ensure that the original graph is not corrupted
assert num_nodes_old == num_nodes_new
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_idtype
......@@ -2408,11 +2417,15 @@ def test_module_drop_edge(idtype):
('A', 'r1', 'B'): ([0, 1], [1, 2]),
('C', 'r2', 'C'): ([3, 4, 5], [6, 7, 8])
}, idtype=idtype, device=F.ctx())
num_edges_old = g.num_edges()
new_g = transform(g)
assert new_g.idtype == g.idtype
assert new_g.device == g.device
assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes
num_edges_new = g.num_edges()
# Ensure that the original graph is not corrupted
assert num_edges_old == num_edges_new
@parametrize_idtype
def test_module_add_edge(idtype):
......@@ -2421,6 +2434,7 @@ def test_module_add_edge(idtype):
('A', 'r1', 'B'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]),
('C', 'r2', 'C'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5])
}, idtype=idtype, device=F.ctx())
num_edges_old = g.num_edges()
new_g = transform(g)
assert new_g.num_edges(('A', 'r1', 'B')) == 6
assert new_g.num_edges(('C', 'r2', 'C')) == 6
......@@ -2428,6 +2442,9 @@ def test_module_add_edge(idtype):
assert new_g.device == g.device
assert new_g.ntypes == g.ntypes
assert new_g.canonical_etypes == g.canonical_etypes
num_edges_new = g.num_edges()
# Ensure that the original graph is not corrupted
assert num_edges_old == num_edges_new
@parametrize_idtype
def test_module_random_walk_pe(idtype):
......@@ -2483,7 +2500,7 @@ def test_module_laplacian_pe(idtype):
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature']))
def test_module_sign(g):
import torch
atol = 1e-06
ctx = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment