Commit 2cde9023 authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

complied to coding standards

parent a8109737
......@@ -144,11 +144,11 @@ const long* kernel_size, const long* is_open_spline, int num_threads) {
}
'''
def get_basis_kernel(k_max,K,dim,degree):
if degree==3:
def get_basis_kernel(k_max, K, dim, degree):
if degree == 3:
_spline_kernel = _spline_kernel_cubic
elif degree==2:
elif degree == 2:
_spline_kernel = _spline_kernel_quadratic
else:
_spline_kernel = _spline_kernel_linear
......@@ -164,12 +164,13 @@ def get_basis_kernel(k_max,K,dim,degree):
K=K)
return f
def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
assert input.is_cuda and kernel_size.is_cuda and is_open_spline.is_cuda
input = input.unsqueeze(1) if len(input.size()) < 2 else input
num_edges, dim = input.size()
k_max = 2**dim
k_max = 2 ** dim
amount = input.new(num_edges, k_max)
index = input.new(num_edges, k_max).long()
......@@ -177,15 +178,15 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
with torch.cuda.device_of(input):
basis_kernel(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
amount.data_ptr(),
index.data_ptr(),
kernel_size.data_ptr(),
is_open_spline.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
amount.data_ptr(),
index.data_ptr(),
kernel_size.data_ptr(),
is_open_spline.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
return amount, index
......@@ -8,6 +8,6 @@ def edgewise_spline_weighting(input, weight, amount, index, k_fw, k_bw):
if input.is_cuda:
K, M_in, M_out = weight.size()
return EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out
,k_fw,k_bw)(input, weight)
, k_fw, k_bw)(input, weight)
else:
raise NotImplementedError
......@@ -95,7 +95,8 @@ const long* index, int num_threads) {
}
'''
def get_forward_kernel(M_in,M_out,k_max):
def get_forward_kernel(M_in, M_out, k_max):
cuda_tensor = torch.FloatTensor([1]).cuda()
with torch.cuda.device_of(cuda_tensor):
f_fw = load_kernel(
......@@ -107,7 +108,8 @@ def get_forward_kernel(M_in,M_out,k_max):
k_max=k_max)
return f_fw
def get_backward_kernel(M_in,M_out,k_max, K):
def get_backward_kernel(M_in, M_out, k_max, K):
cuda_tensor = torch.FloatTensor([1]).cuda()
with torch.cuda.device_of(cuda_tensor):
f_bw = load_kernel(
......@@ -133,36 +135,33 @@ class EdgewiseSplineWeightingGPU(Function):
self.f_fw = k_fw
self.f_bw = k_bw
def forward(self, input, weight):
assert input.is_cuda and weight.is_cuda
self.save_for_backward(input, weight)
output = input.new(input.size(0), self.M_out)
num_threads = output.numel()
with torch.cuda.device_of(input):
self.f_fw(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
weight.data_ptr(),
output.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
weight.data_ptr(),
output.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(),
num_threads
],
stream=Stream(
ptr=torch.cuda.current_stream().cuda_stream))
return output
def backward(self, grad_output):
input, weight = self.saved_tensors
grad_input = grad_output.new(input.size(0), self.M_in).fill_(0)
grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)
......@@ -170,17 +169,18 @@ class EdgewiseSplineWeightingGPU(Function):
with torch.cuda.device_of(grad_output):
self.f_bw(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
grad_output.data_ptr(),
grad_input.data_ptr(),
grad_weight.data_ptr(),
input.data_ptr(),
weight.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
grid=(get_blocks(num_threads), 1, 1),
args=[
grad_output.data_ptr(),
grad_input.data_ptr(),
grad_weight.data_ptr(),
input.data_ptr(),
weight.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(),
num_threads
],
stream=Stream(
ptr=torch.cuda.current_stream().cuda_stream))
return grad_input, grad_weight
......@@ -6,6 +6,7 @@ if torch.cuda.is_available():
def spline(input, kernel_size, is_open_spline, K, degree, basis_kernel):
if input.is_cuda:
return compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel)
return compute_spline_basis(input, kernel_size, is_open_spline, K,
basis_kernel)
else:
raise NotImplementedError()
......@@ -18,7 +18,8 @@ class SplineQuadraticGPUTest(unittest.TestCase):
K = 7
dim = 1
basis_kernel = get_basis_kernel(k_max, K, dim, 3)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 7, basis_kernel)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 7,
basis_kernel)
a2 = [
[0.1667, 0.6667, 0.1667, 0],
......@@ -44,7 +45,8 @@ class SplineQuadraticGPUTest(unittest.TestCase):
K = 4
dim = 1
basis_kernel = get_basis_kernel(k_max, K, dim, 3)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4, basis_kernel)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4,
basis_kernel)
a2 = [
[0.1667, 0.6667, 0.1667, 0],
......
......@@ -18,7 +18,8 @@ class SplineLinearGPUTest(unittest.TestCase):
K = 5
dim = 1
basis_kernel = get_basis_kernel(k_max, K, dim, 1)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 5, basis_kernel)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 5,
basis_kernel)
a2 = [[0, 1], [0.2, 0.8], [0, 1], [0, 1], [0, 1], [0.8, 0.2], [0, 1]]
i2 = [[1, 0], [1, 0], [2, 1], [3, 2], [4, 3], [4, 3], [0, 4]]
......@@ -35,7 +36,8 @@ class SplineLinearGPUTest(unittest.TestCase):
K = 4
dim = 1
basis_kernel = get_basis_kernel(k_max, K, dim, 1)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4, basis_kernel)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4,
basis_kernel)
a2 = [[0, 1], [0.2, 0.8], [0, 1], [0, 1], [0, 1], [0.8, 0.2], [0, 1]]
i2 = [[1, 0], [1, 0], [2, 1], [3, 2], [0, 3], [0, 3], [1, 0]]
......
......@@ -16,10 +16,11 @@ class SplineQuadraticGPUTest(unittest.TestCase):
is_open_spline = torch.cuda.LongTensor([1])
k_max = 3
K = 6
dim=1
basis_kernel = get_basis_kernel(k_max,K,dim,2)
dim = 1
basis_kernel = get_basis_kernel(k_max, K, dim, 2)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 6, basis_kernel)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 6,
basis_kernel)
a2 = [[0.5, 0.5, 0], [0.32, 0.66, 0.02], [0.5, 0.5, 0], [0.5, 0.5, 0],
[0.5, 0.5, 0], [0.02, 0.66, 0.32], [0.5, 0.5, 0]]
......@@ -38,7 +39,8 @@ class SplineQuadraticGPUTest(unittest.TestCase):
K = 4
dim = 1
basis_kernel = get_basis_kernel(k_max, K, dim, 2)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4, basis_kernel)
a1, i1 = compute_spline_basis(input, kernel_size, is_open_spline, 4,
basis_kernel)
a2 = [[0.5, 0.5, 0], [0.32, 0.66, 0.02], [0.5, 0.5, 0], [0.5, 0.5, 0],
[0.5, 0.5, 0], [0.02, 0.66, 0.32], [0.5, 0.5, 0]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment