"tests/vscode:/vscode.git/clone" did not exist on "dd26ff10a8a1af1bea8fd7b52c7caea1a52e7dc8"
Commit 2f8f1501 authored by rusty1s
Browse files

ignore long edges in spline convolution

parent 2cde9023
......@@ -30,8 +30,9 @@ const long* kernel_size, const long* is_open_spline, int num_threads) {
k_idx_mod = k_idx % 2;
k_idx >>= 1;
value = input[e_idx * ${dim} + d_idx] *
(kernel_size[d_idx] - is_open_spline[d_idx]);
value = input[e_idx * ${dim} + d_idx];
if (value >= 1) { a = 0; i = 0; break; }
value *= kernel_size[d_idx] - is_open_spline[d_idx];
frac = value - floor(value);
......@@ -170,23 +171,23 @@ def compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel):
input = input.unsqueeze(1) if len(input.size()) < 2 else input
num_edges, dim = input.size()
k_max = 2 ** dim
k_max = 2**dim
amount = input.new(num_edges, k_max)
index = input.new(num_edges, k_max).long()
num_threads = amount.numel()
with torch.cuda.device_of(input):
basis_kernel(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
amount.data_ptr(),
index.data_ptr(),
kernel_size.data_ptr(),
is_open_spline.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
basis_kernel(
block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
amount.data_ptr(),
index.data_ptr(),
kernel_size.data_ptr(),
is_open_spline.data_ptr(), num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
return amount, index
import torch
from torch.autograd import Function
from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
kernel_loop, get_blocks)
from ....utils.cuda import (cuda_num_threads, Stream, load_kernel, kernel_loop,
get_blocks)
_edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C"
......@@ -144,18 +144,17 @@ class EdgewiseSplineWeightingGPU(Function):
num_threads = output.numel()
with torch.cuda.device_of(input):
self.f_fw(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
weight.data_ptr(),
output.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(),
num_threads
],
stream=Stream(
ptr=torch.cuda.current_stream().cuda_stream))
self.f_fw(
block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
weight.data_ptr(),
output.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(), num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
return output
......@@ -168,19 +167,18 @@ class EdgewiseSplineWeightingGPU(Function):
num_threads = grad_output.numel()
with torch.cuda.device_of(grad_output):
self.f_bw(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
grad_output.data_ptr(),
grad_input.data_ptr(),
grad_weight.data_ptr(),
input.data_ptr(),
weight.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(),
num_threads
],
stream=Stream(
ptr=torch.cuda.current_stream().cuda_stream))
self.f_bw(
block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
grad_output.data_ptr(),
grad_input.data_ptr(),
grad_weight.data_ptr(),
input.data_ptr(),
weight.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr(), num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
return grad_input, grad_weight
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment