Unverified Commit c3103b62 authored by Yongyi's avatar Yongyi Committed by GitHub
Browse files

[BugFix] Fix a bug in TWIRLS, add unittest (#3573)



* Fix a bug in TWIRLS, add unittest

* reformatting the code

* modify unittest for TWIRLS
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent a3ce780d
......@@ -6,7 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F
from .... import function as fn
class TWIRLSConv(nn.Module):
r"""
......@@ -74,13 +73,8 @@ class TWIRLSConv(nn.Module):
>>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step = 64)
>>> res = conv(g , feat)
>>> res
tensor([[ 0.4556, -2.6692],
[ 0.4556, -2.6692],
[ 0.4556, -2.6692],
[ 1.0112, -5.9241],
[ 0.8011, -4.6935],
[ 0.8844, -5.1814]], grad_fn=<AddmmBackward>)
>>> res.size()
torch.Size([6, 2])
"""
def __init__(self,
......@@ -148,10 +142,10 @@ class TWIRLSConv(nn.Module):
self.mlp_bef = MLP(self.input_d, self.hidden_d, self.size_bef_unf, self.num_mlp_before,
self.dropout, self.norm, init_activate=False)
self.unfolding = UnfoldingAndAttention(self.hidden_d, self.alp, self.lam, self.prop_step,
self.attn_aft, self.tau, self.T, self.p,
self.use_eta, self.init_att, self.attn_dropout,
self.precond)
self.unfolding = TWIRLSUnfoldingAndAttention(self.hidden_d, self.alp, self.lam,
self.prop_step, self.attn_aft, self.tau,
self.T, self.p, self.use_eta, self.init_att,
self.attn_dropout, self.precond)
# if there are really transformations before unfolding, then do init_activate in mlp_aft
self.mlp_aft = MLP(self.size_aft_unf, self.hidden_d, self.output_d, self.num_mlp_after,
......
......@@ -1323,6 +1323,15 @@ def test_edge_predictor(op):
pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
assert pred(h_src, h_dst).shape == (num_pairs, out_feats)
def test_twirls():
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
feat = th.ones(6, 10)
conv = nn.TWIRLSConv(10, 2, 128, prop_step = 64)
res = conv(g , feat)
assert ( res.size() == (6,2) )
if __name__ == '__main__':
test_graph_conv()
test_graph_conv_e_weight()
......@@ -1354,3 +1363,4 @@ if __name__ == '__main__':
test_atomic_conv()
test_cf_conv()
test_hetero_conv()
test_twirls()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment