Commit 91ea59fc authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

bugfix - versions with and without bp_to_u working

parent fb2e852d
......@@ -31,10 +31,16 @@ def spline_conv(
# Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
if output.is_cuda:
output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel,
weighting_kernel, weighting_backward_kernel,
bp_to_adj)(output, weight[:-1], values)
if bp_to_adj:
output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel,
weighting_kernel, weighting_backward_kernel,
bp_to_adj)(output, weight[:-1], values)
else:
output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel,
weighting_kernel, weighting_backward_kernel,
bp_to_adj, values)(output, weight[:-1])
else:
# CPU Implementation not available
raise NotImplementedError()
......@@ -50,8 +56,8 @@ def spline_conv(
output += torch.mm(input, weight[-1])
# Normalize output by degree.
ones = values.data.new(values.size(0)).fill_(1)
zero = values.data.new(output.size(0)).fill_(0)
ones = output.data.new(values.size(0)).fill_(1)
zero = output.data.new(output.size(0)).fill_(0)
degree = zero.scatter_add_(0, row, ones)
degree = torch.clamp(degree, min=1)
output = output / Variable(degree.view(-1, 1))
......
......@@ -471,7 +471,7 @@ class SplineConvGPU(Function):
def __init__(self, kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel,
weighting_kernel, weighting_backward_kernel,
bp_to_adj=False):
bp_to_adj=False, adj_values=None):
super(SplineConvGPU, self).__init__()
self.degree = degree
self.f_weighting_fw = weighting_kernel
......@@ -481,11 +481,15 @@ class SplineConvGPU(Function):
self.f_basis_fw = basis_kernel
self.f_basis_bw = basis_backward_kernel
self.bp_to_adj = bp_to_adj
self.adj_values = adj_values
def forward(self, input, weight, adj_values):
def forward(self, input, weight, adj_values=None):
assert input.is_cuda and weight.is_cuda
self.K, self.M_in, self.M_out = weight.size()
# If bp_to_u is false
if adj_values is None:
adj_values = self.adj_values
# Compute B-spline basis tensor products
adj_values = adj_values.unsqueeze(1) if len(adj_values.size()) < 2 \
else adj_values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment