Unverified Commit c3103b62 authored by Yongyi's avatar Yongyi Committed by GitHub
Browse files

[BugFix] Fix a bug in TWIRLS, add unittest (#3573)



* Fix a bug in TWIRLS, add unittest

* reformatting the code

* modify unittest for TWIRLS
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent a3ce780d
...@@ -6,7 +6,6 @@ import torch.nn as nn ...@@ -6,7 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .... import function as fn from .... import function as fn
class TWIRLSConv(nn.Module): class TWIRLSConv(nn.Module):
r""" r"""
...@@ -74,13 +73,8 @@ class TWIRLSConv(nn.Module): ...@@ -74,13 +73,8 @@ class TWIRLSConv(nn.Module):
>>> feat = th.ones(6, 10) >>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step = 64) >>> conv = TWIRLSConv(10, 2, 128, prop_step = 64)
>>> res = conv(g , feat) >>> res = conv(g , feat)
>>> res >>> res.size()
tensor([[ 0.4556, -2.6692], torch.Size([6, 2])
[ 0.4556, -2.6692],
[ 0.4556, -2.6692],
[ 1.0112, -5.9241],
[ 0.8011, -4.6935],
[ 0.8844, -5.1814]], grad_fn=<AddmmBackward>)
""" """
def __init__(self, def __init__(self,
...@@ -148,10 +142,10 @@ class TWIRLSConv(nn.Module): ...@@ -148,10 +142,10 @@ class TWIRLSConv(nn.Module):
self.mlp_bef = MLP(self.input_d, self.hidden_d, self.size_bef_unf, self.num_mlp_before, self.mlp_bef = MLP(self.input_d, self.hidden_d, self.size_bef_unf, self.num_mlp_before,
self.dropout, self.norm, init_activate=False) self.dropout, self.norm, init_activate=False)
self.unfolding = UnfoldingAndAttention(self.hidden_d, self.alp, self.lam, self.prop_step, self.unfolding = TWIRLSUnfoldingAndAttention(self.hidden_d, self.alp, self.lam,
self.attn_aft, self.tau, self.T, self.p, self.prop_step, self.attn_aft, self.tau,
self.use_eta, self.init_att, self.attn_dropout, self.T, self.p, self.use_eta, self.init_att,
self.precond) self.attn_dropout, self.precond)
# if there are really transformations before unfolding, then do init_activate in mlp_aft # if there are really transformations before unfolding, then do init_activate in mlp_aft
self.mlp_aft = MLP(self.size_aft_unf, self.hidden_d, self.output_d, self.num_mlp_after, self.mlp_aft = MLP(self.size_aft_unf, self.hidden_d, self.output_d, self.num_mlp_after,
......
...@@ -1323,6 +1323,15 @@ def test_edge_predictor(op): ...@@ -1323,6 +1323,15 @@ def test_edge_predictor(op):
pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx) pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
assert pred(h_src, h_dst).shape == (num_pairs, out_feats) assert pred(h_src, h_dst).shape == (num_pairs, out_feats)
def test_twirls():
    """Smoke-test TWIRLSConv: a forward pass on a small graph must
    produce an output of shape (num_nodes, out_feats)."""
    # 6-node directed graph with 6 edges (src list, dst list).
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = th.ones(6, 10)
    # in_feats=10, out_feats=2, hidden=128; 64 unfolding/propagation steps.
    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)
    res = conv(g, feat)
    # torch.Size compares equal to a plain tuple, so (6, 2) suffices.
    assert res.size() == (6, 2)
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
test_graph_conv_e_weight() test_graph_conv_e_weight()
...@@ -1354,3 +1363,4 @@ if __name__ == '__main__': ...@@ -1354,3 +1363,4 @@ if __name__ == '__main__':
test_atomic_conv() test_atomic_conv()
test_cf_conv() test_cf_conv()
test_hetero_conv() test_hetero_conv()
test_twirls()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment