"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4dce43ccaa3dd0295498bf289490fdb6165fbeb5"
Unverified Commit 8b19c287 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update test_transform.py (#4190)

parent 32f12ee1
...@@ -2356,6 +2356,8 @@ def test_module_laplacian_pe(idtype): ...@@ -2356,6 +2356,8 @@ def test_module_laplacian_pe(idtype):
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'])) @pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature']))
def test_module_sign(g): def test_module_sign(g):
import torch import torch
atol = 1e-06
ctx = F.ctx() ctx = F.ctx()
g = g.to(ctx) g = g.to(ctx)
...@@ -2372,25 +2374,25 @@ def test_module_sign(g): ...@@ -2372,25 +2374,25 @@ def test_module_sign(g):
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw') transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw')
g = transform(g) g = transform(g)
target = torch.matmul(adj, g.ndata['h']) target = torch.matmul(adj, g.ndata['h'])
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw') transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw')
g = transform(g) g = transform(g)
target = torch.matmul(weight_adj, g.ndata['h']) target = torch.matmul(weight_adj, g.ndata['h'])
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
# rw # rw
adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj) adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw') transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
g = transform(g) g = transform(g)
target = torch.matmul(adj_rw, g.ndata['h']) target = torch.matmul(adj_rw, g.ndata['h'])
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj) weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw') transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw')
g = transform(g) g = transform(g)
target = torch.matmul(weight_adj_rw, g.ndata['h']) target = torch.matmul(weight_adj_rw, g.ndata['h'])
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
# gcn # gcn
raw_eweight = g.edata['scalar_w'] raw_eweight = g.edata['scalar_w']
...@@ -2401,7 +2403,7 @@ def test_module_sign(g): ...@@ -2401,7 +2403,7 @@ def test_module_sign(g):
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn') transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn')
g = transform(g) g = transform(g)
target = torch.matmul(adj_gcn, g.ndata['h']) target = torch.matmul(adj_gcn, g.ndata['h'])
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
gcn_norm = dgl.GCNNorm('scalar_w') gcn_norm = dgl.GCNNorm('scalar_w')
g = gcn_norm(g) g = gcn_norm(g)
...@@ -2412,20 +2414,20 @@ def test_module_sign(g): ...@@ -2412,20 +2414,20 @@ def test_module_sign(g):
eweight_name='scalar_w', diffuse_op='gcn') eweight_name='scalar_w', diffuse_op='gcn')
g = transform(g) g = transform(g)
target = torch.matmul(weight_adj_gcn, g.ndata['h']) target = torch.matmul(weight_adj_gcn, g.ndata['h'])
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
# ppr # ppr
alpha = 0.2 alpha = 0.2
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha) transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha)
g = transform(g) g = transform(g)
target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h'] target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w',
diffuse_op='ppr', alpha=alpha) diffuse_op='ppr', alpha=alpha)
g = transform(g) g = transform(g)
target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h'] target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
assert torch.allclose(g.ndata['out_feat_1'], target) assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_idtype @parametrize_idtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment