Commit 4b92ba74 authored by rusty1s's avatar rusty1s
Browse files

lots of bug fixes

parent dc7f15fd
......@@ -8,6 +8,7 @@ def edgewise_spline_weighting(input, weight, amount, index):
if input.is_cuda:
K, M_in, M_out = weight.size()
k_max = amount.size(1)
return EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out, k_max)(input, weight)
return EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out,
k_max)(input, weight)
else:
raise NotImplementedError
......@@ -16,6 +16,9 @@ def spline_conv(
degree=1,
bias=None):
if input.dim() == 1:
input = input.unsqueeze(1)
values = adj._values()
row, col = adj._indices()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment