Commit 5487e31a authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent d9100e71
import torch
from torch_spline_conv import spline_conv
x = torch.Tensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]])
index = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]])
pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
pseudo = torch.Tensor(pseudo)
weight = torch.arange(0.5, 0.5 * 25, step=0.5).view(12, 2, 1)
# print(weight[:, 0].squeeze())
kernel_size = torch.LongTensor([3, 4])
is_open_spline = torch.ByteTensor([1, 0])
root_weight = torch.arange(12.5, 13.5, step=0.5).view(2, 1)
output = spline_conv(x, index, pseudo, weight, kernel_size, is_open_spline,
root_weight)
edgewise_output = [
1 * 0.25 * (0.5 + 1.5 + 4.5 + 5.5) + 2 * 0.25 * (1 + 2 + 5 + 6),
3 * 0.25 * (1.5 + 2.5 + 5.5 + 6.5) + 4 * 0.25 * (2 + 3 + 6 + 7),
5 * 0.25 * (6.5 + 7.5 + 10.5 + 11.5) + 6 * 0.25 * (7 + 8 + 11 + 12),
7 * 0.25 * (7.5 + 4.5 + 11.5 + 8.5) + 8 * 0.25 * (8 + 5 + 12 + 9),
]
expected_output = [
[12.5 * 9 + 13 * 10 + sum(edgewise_output) / 4],
[12.5 * 1 + 13 * 2],
[12.5 * 3 + 13 * 4],
[12.5 * 5 + 13 * 6],
[12.5 * 7 + 13 * 8],
]
...@@ -35,7 +35,8 @@ def spline_conv(x, ...@@ -35,7 +35,8 @@ def spline_conv(x,
output = zero.scatter_add_(0, row, output) output = zero.scatter_add_(0, row, output)
# Normalize output by node degree. # Normalize output by node degree.
output /= node_degree(edge_index, n, out=x.new()).clamp_(min=1) degree = node_degree(edge_index, n, out=x.new())
output /= degree.unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
......
...@@ -64,7 +64,7 @@ void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight, ...@@ -64,7 +64,7 @@ void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight,
b = *(basis_data + s * basis_stride); b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride); i = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
value += b * *(weight_data + i * M_in * M_out + m_in * M_in + m_out) * *(input_data + m_in * input_stride); value += b * *(weight_data + i * M_in * M_out + m_in * M_out + m_out) * *(input_data + m_in * input_stride);
} }
} }
output_data[m_out * output_stride] = value; output_data[m_out * output_stride] = value;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment