Commit 2f8f1501 authored by rusty1s
Browse files

ignore long edges in spline convolution

parent 2cde9023
...@@ -30,8 +30,9 @@ const long* kernel_size, const long* is_open_spline, int num_threads) { ...@@ -30,8 +30,9 @@ const long* kernel_size, const long* is_open_spline, int num_threads) {
k_idx_mod = k_idx % 2; k_idx_mod = k_idx % 2;
k_idx >>= 1; k_idx >>= 1;
value = input[e_idx * ${dim} + d_idx] * value = input[e_idx * ${dim} + d_idx];
(kernel_size[d_idx] - is_open_spline[d_idx]); if (value >= 1) { a = 0; i = 0; break; }
value *= kernel_size[d_idx] - is_open_spline[d_idx];
frac = value - floor(value); frac = value - floor(value);
...@@ -170,23 +171,23 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel): ...@@ -170,23 +171,23 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
input = input.unsqueeze(1) if len(input.size()) < 2 else input input = input.unsqueeze(1) if len(input.size()) < 2 else input
num_edges, dim = input.size() num_edges, dim = input.size()
k_max = 2 ** dim k_max = 2**dim
amount = input.new(num_edges, k_max) amount = input.new(num_edges, k_max)
index = input.new(num_edges, k_max).long() index = input.new(num_edges, k_max).long()
num_threads = amount.numel() num_threads = amount.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
basis_kernel(block=(cuda_num_threads, 1, 1), basis_kernel(
grid=(get_blocks(num_threads), 1, 1), block=(cuda_num_threads, 1, 1),
args=[ grid=(get_blocks(num_threads), 1, 1),
input.data_ptr(), args=[
amount.data_ptr(), input.data_ptr(),
index.data_ptr(), amount.data_ptr(),
kernel_size.data_ptr(), index.data_ptr(),
is_open_spline.data_ptr(), kernel_size.data_ptr(),
num_threads is_open_spline.data_ptr(), num_threads
], ],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
return amount, index return amount, index
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel, from ....utils.cuda import (cuda_num_threads, Stream, load_kernel, kernel_loop,
kernel_loop, get_blocks) get_blocks)
_edgewise_spline_weighting_forward_kernel = kernel_loop + ''' _edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C" extern "C"
...@@ -144,18 +144,17 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -144,18 +144,17 @@ class EdgewiseSplineWeightingGPU(Function):
num_threads = output.numel() num_threads = output.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
self.f_fw(block=(cuda_num_threads, 1, 1), self.f_fw(
grid=(get_blocks(num_threads), 1, 1), block=(cuda_num_threads, 1, 1),
args=[ grid=(get_blocks(num_threads), 1, 1),
input.data_ptr(), args=[
weight.data_ptr(), input.data_ptr(),
output.data_ptr(), weight.data_ptr(),
self.amount.data_ptr(), output.data_ptr(),
self.index.data_ptr(), self.amount.data_ptr(),
num_threads self.index.data_ptr(), num_threads
], ],
stream=Stream( stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
ptr=torch.cuda.current_stream().cuda_stream))
return output return output
...@@ -168,19 +167,18 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -168,19 +167,18 @@ class EdgewiseSplineWeightingGPU(Function):
num_threads = grad_output.numel() num_threads = grad_output.numel()
with torch.cuda.device_of(grad_output): with torch.cuda.device_of(grad_output):
self.f_bw(block=(cuda_num_threads, 1, 1), self.f_bw(
grid=(get_blocks(num_threads), 1, 1), block=(cuda_num_threads, 1, 1),
args=[ grid=(get_blocks(num_threads), 1, 1),
grad_output.data_ptr(), args=[
grad_input.data_ptr(), grad_output.data_ptr(),
grad_weight.data_ptr(), grad_input.data_ptr(),
input.data_ptr(), grad_weight.data_ptr(),
weight.data_ptr(), input.data_ptr(),
self.amount.data_ptr(), weight.data_ptr(),
self.index.data_ptr(), self.amount.data_ptr(),
num_threads self.index.data_ptr(), num_threads
], ],
stream=Stream( stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
ptr=torch.cuda.current_stream().cuda_stream))
return grad_input, grad_weight return grad_input, grad_weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment