Commit 1038a59e authored by rusty1s's avatar rusty1s
Browse files

degree normalize after root_weight

parent 1b96f53b
...@@ -31,11 +31,11 @@ tests = [{ ...@@ -31,11 +31,11 @@ tests = [{
'root_weight': [[12.5], [13]], 'root_weight': [[12.5], [13]],
'bias': [1], 'bias': [1],
'expected': [ 'expected': [
[1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4], 1 + (12.5 * 9 + 13 * 10 + 8.5 + 40.5 + 107.5 + 101.5) / 5,
[1 + 12.5 * 1 + 13 * 2], 1 + 12.5 * 1 + 13 * 2,
[1 + 12.5 * 3 + 13 * 4], 1 + 12.5 * 3 + 13 * 4,
[1 + 12.5 * 5 + 13 * 6], 1 + 12.5 * 5 + 13 * 6,
[1 + 12.5 * 7 + 13 * 8], 1 + 12.5 * 7 + 13 * 8,
] ]
}] }]
...@@ -53,7 +53,8 @@ def test_spline_conv_forward(test, dtype, device): ...@@ -53,7 +53,8 @@ def test_spline_conv_forward(test, dtype, device):
out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size, out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, root_weight, bias) is_open_spline, 1, root_weight, bias)
assert out.tolist() == test['expected'] assert list(out.size()) == [5, 1]
assert pytest.approx(out.view(-1).tolist()) == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices)) @pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
......
...@@ -62,13 +62,15 @@ class SplineConv(object): ...@@ -62,13 +62,15 @@ class SplineConv(object):
row_expand = row.unsqueeze(-1).expand_as(out) row_expand = row.unsqueeze(-1).expand_as(out)
out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out) out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)
# Normalize out by node degree.
deg = node_degree(row, n, out.dtype, out.device) deg = node_degree(row, n, out.dtype, out.device)
out /= deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
out += torch.mm(src, root_weight) out += torch.mm(src, root_weight)
deg += 1
# Normalize out by node degree.
out /= deg.unsqueeze(-1).clamp(min=1)
# Add bias (if wished). # Add bias (if wished).
if bias is not None: if bias is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment