Commit 91ea59fc authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

bugfix - versions with and without bp_to_u working

parent fb2e852d
...@@ -31,10 +31,16 @@ def spline_conv( ...@@ -31,10 +31,16 @@ def spline_conv(
# Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out]. # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
if output.is_cuda: if output.is_cuda:
output = SplineConvGPU(kernel_size, is_open_spline, K, degree, if bp_to_adj:
basis_kernel, basis_backward_kernel, output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
weighting_kernel, weighting_backward_kernel, basis_kernel, basis_backward_kernel,
bp_to_adj)(output, weight[:-1], values) weighting_kernel, weighting_backward_kernel,
bp_to_adj)(output, weight[:-1], values)
else:
output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel,
weighting_kernel, weighting_backward_kernel,
bp_to_adj, values)(output, weight[:-1])
else: else:
# CPU Implementation not available # CPU Implementation not available
raise NotImplementedError() raise NotImplementedError()
...@@ -50,8 +56,8 @@ def spline_conv( ...@@ -50,8 +56,8 @@ def spline_conv(
output += torch.mm(input, weight[-1]) output += torch.mm(input, weight[-1])
# Normalize output by degree. # Normalize output by degree.
ones = values.data.new(values.size(0)).fill_(1) ones = output.data.new(values.size(0)).fill_(1)
zero = values.data.new(output.size(0)).fill_(0) zero = output.data.new(output.size(0)).fill_(0)
degree = zero.scatter_add_(0, row, ones) degree = zero.scatter_add_(0, row, ones)
degree = torch.clamp(degree, min=1) degree = torch.clamp(degree, min=1)
output = output / Variable(degree.view(-1, 1)) output = output / Variable(degree.view(-1, 1))
......
...@@ -471,7 +471,7 @@ class SplineConvGPU(Function): ...@@ -471,7 +471,7 @@ class SplineConvGPU(Function):
def __init__(self, kernel_size, is_open_spline, K, degree, def __init__(self, kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel, basis_kernel, basis_backward_kernel,
weighting_kernel, weighting_backward_kernel, weighting_kernel, weighting_backward_kernel,
bp_to_adj=False): bp_to_adj=False, adj_values=None):
super(SplineConvGPU, self).__init__() super(SplineConvGPU, self).__init__()
self.degree = degree self.degree = degree
self.f_weighting_fw = weighting_kernel self.f_weighting_fw = weighting_kernel
...@@ -481,11 +481,15 @@ class SplineConvGPU(Function): ...@@ -481,11 +481,15 @@ class SplineConvGPU(Function):
self.f_basis_fw = basis_kernel self.f_basis_fw = basis_kernel
self.f_basis_bw = basis_backward_kernel self.f_basis_bw = basis_backward_kernel
self.bp_to_adj = bp_to_adj self.bp_to_adj = bp_to_adj
self.adj_values = adj_values
def forward(self, input, weight, adj_values): def forward(self, input, weight, adj_values=None):
assert input.is_cuda and weight.is_cuda assert input.is_cuda and weight.is_cuda
self.K, self.M_in, self.M_out = weight.size() self.K, self.M_in, self.M_out = weight.size()
# If bp_to_u is false
if adj_values is None:
adj_values = self.adj_values
# Compute B-spline basis tensor products # Compute B-spline basis tensor products
adj_values = adj_values.unsqueeze(1) if len(adj_values.size()) < 2 \ adj_values = adj_values.unsqueeze(1) if len(adj_values.size()) < 2 \
else adj_values else adj_values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment