Unverified Commit 8b19c287 authored by Mufei Li, committed by GitHub

Update test_transform.py (#4190)

parent 32f12ee1
@@ -2357,6 +2357,8 @@ def test_module_laplacian_pe(idtype):
def test_module_sign(g):
import torch
+atol = 1e-06
ctx = F.ctx()
g = g.to(ctx)
adj = g.adj(transpose=True, scipy_fmt='coo').todense()
@@ -2372,25 +2374,25 @@ def test_module_sign(g):
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw')
g = transform(g)
target = torch.matmul(adj, g.ndata['h'])
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw')
g = transform(g)
target = torch.matmul(weight_adj, g.ndata['h'])
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
# rw
adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
g = transform(g)
target = torch.matmul(adj_rw, g.ndata['h'])
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw')
g = transform(g)
target = torch.matmul(weight_adj_rw, g.ndata['h'])
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
# gcn
raw_eweight = g.edata['scalar_w']
@@ -2401,7 +2403,7 @@ def test_module_sign(g):
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn')
g = transform(g)
target = torch.matmul(adj_gcn, g.ndata['h'])
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
gcn_norm = dgl.GCNNorm('scalar_w')
g = gcn_norm(g)
@@ -2412,20 +2414,20 @@ def test_module_sign(g):
eweight_name='scalar_w', diffuse_op='gcn')
g = transform(g)
target = torch.matmul(weight_adj_gcn, g.ndata['h'])
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
# ppr
alpha = 0.2
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha)
g = transform(g)
target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w',
diffuse_op='ppr', alpha=alpha)
g = transform(g)
target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
-assert torch.allclose(g.ndata['out_feat_1'], target)
+assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@parametrize_idtype
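For context (not part of the diff): the change passes an explicit `atol` because `torch.allclose` defaults to `rtol=1e-05, atol=1e-08`, which can be too strict when comparing float32 diffusion results against the dense `torch.matmul` reference, especially for near-zero entries. A minimal standalone sketch of the effect:

```python
# Sketch illustrating the tolerance change; the values here are illustrative
# and not taken from the test itself.
import torch

a = torch.tensor([1e-7, 1.0])
b = torch.tensor([2e-7, 1.0 + 1e-6])

# The near-zero entry fails the default tolerances because |a - b| = 1e-7
# exceeds atol + rtol * |b| ~= 1e-8 + 2e-12.
print(torch.allclose(a, b))              # False
# A relaxed absolute tolerance, as used in the updated asserts, accepts it.
print(torch.allclose(a, b, atol=1e-06))  # True
```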