Commit 25700259 authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

kernels faster now - kernel load in initializer

parent c16689ca
...@@ -6,6 +6,8 @@ if torch.cuda.is_available(): ...@@ -6,6 +6,8 @@ if torch.cuda.is_available():
def edgewise_spline_weighting(input, weight, amount, index):
    """Dispatch edgewise B-spline weighting to the CUDA implementation.

    Only a GPU path exists; CPU tensors raise NotImplementedError.
    `weight` is expected to be a 3D tensor (K, M_in, M_out) and
    `amount` 2D with size(1) == k_max — TODO confirm against callers.
    """
    if not input.is_cuda:
        raise NotImplementedError
    # Kernel sizes are read here once so the Function can compile and
    # cache its CUDA kernels in its initializer.
    K, M_in, M_out = weight.size()
    k_max = amount.size(1)
    op = EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out, k_max)
    return op(input, weight)
...@@ -8,9 +8,9 @@ _edgewise_spline_weighting_forward_kernel = kernel_loop + ''' ...@@ -8,9 +8,9 @@ _edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C" extern "C"
__global__ void edgewise_spline_weighting_forward_kernel( __global__ void edgewise_spline_weighting_forward_kernel(
const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output, const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
const ${Dtype}* amount, const long* index) { const ${Dtype}* amount, const long* index, int num_threads) {
CUDA_KERNEL_LOOP(idx, ${num_threads}) { CUDA_KERNEL_LOOP(idx, num_threads) {
const int e_idx = idx / ${M_out}; const int e_idx = idx / ${M_out};
const int m_out_idx = idx % ${M_out}; const int m_out_idx = idx % ${M_out};
...@@ -50,9 +50,9 @@ extern "C" ...@@ -50,9 +50,9 @@ extern "C"
__global__ void edgewise_spline_weighting_backward_kernel( __global__ void edgewise_spline_weighting_backward_kernel(
const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight, const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
const ${Dtype}* input, const ${Dtype}* weight, const ${Dtype}* amount, const ${Dtype}* input, const ${Dtype}* weight, const ${Dtype}* amount,
const long* index) { const long* index, int num_threads) {
CUDA_KERNEL_LOOP(idx, ${num_threads}) { CUDA_KERNEL_LOOP(idx, num_threads) {
const int e_idx = idx / ${M_out}; const int e_idx = idx / ${M_out};
const int m_out_idx = idx % ${M_out}; const int m_out_idx = idx % ${M_out};
...@@ -97,40 +97,52 @@ const long* index) { ...@@ -97,40 +97,52 @@ const long* index) {
class EdgewiseSplineWeightingGPU(Function): class EdgewiseSplineWeightingGPU(Function):
def __init__(self, amount, index, K, M_in, M_out, k_max):
    """Store spline metadata and pre-compile both CUDA kernels.

    Compiling the forward/backward kernels once in the initializer
    (instead of on every forward/backward call) is the point of this
    class; the size parameters are baked into the kernel templates.
    """
    super(EdgewiseSplineWeightingGPU, self).__init__()
    assert amount.is_cuda and index.is_cuda
    self.amount = amount
    self.index = index
    self.M_in = M_in
    self.M_out = M_out
    self.K = K
    with torch.cuda.device_of(amount):
        # Template parameters shared by both kernels.
        common = dict(Dtype=Dtype(amount), M_in=M_in, M_out=M_out, k_max=k_max)
        self.f_fw = load_kernel(
            'edgewise_spline_weighting_forward_kernel',
            _edgewise_spline_weighting_forward_kernel,
            **common)
        # The backward kernel additionally needs K for grad_weight indexing.
        self.f_bw = load_kernel(
            'edgewise_spline_weighting_backward_kernel',
            _edgewise_spline_weighting_backward_kernel,
            K=K,
            **common)
def forward(self, input, weight): def forward(self, input, weight):
assert input.is_cuda and weight.is_cuda assert input.is_cuda and weight.is_cuda
self.save_for_backward(input, weight) self.save_for_backward(input, weight)
_, M_in, M_out = weight.size()
k_max = self.amount.size(1)
output = input.new(input.size(0), M_out) output = input.new(input.size(0), self.M_out)
num_threads = output.numel() num_threads = output.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
f = load_kernel(
'edgewise_spline_weighting_forward_kernel', self.f_fw(block=(cuda_num_threads, 1, 1),
_edgewise_spline_weighting_forward_kernel,
Dtype=Dtype(input),
num_threads=num_threads,
M_in=M_in,
M_out=M_out,
k_max=k_max)
f(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1), grid=(get_blocks(num_threads), 1, 1),
args=[ args=[
input.data_ptr(), input.data_ptr(),
weight.data_ptr(), weight.data_ptr(),
output.data_ptr(), output.data_ptr(),
self.amount.data_ptr(), self.amount.data_ptr(),
self.index.data_ptr() self.index.data_ptr(),
num_threads
], ],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
...@@ -139,25 +151,14 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -139,25 +151,14 @@ class EdgewiseSplineWeightingGPU(Function):
def backward(self, grad_output): def backward(self, grad_output):
input, weight = self.saved_tensors input, weight = self.saved_tensors
K, M_in, M_out = weight.size()
k_max = self.amount.size(1)
grad_input = grad_output.new(input.size(0), M_in).fill_(0) grad_input = grad_output.new(input.size(0), self.M_in).fill_(0)
grad_weight = grad_output.new(K, M_in, M_out).fill_(0) grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)
num_threads = grad_output.numel() num_threads = grad_output.numel()
with torch.cuda.device_of(grad_output): with torch.cuda.device_of(grad_output):
f = load_kernel( self.f_bw(block=(cuda_num_threads, 1, 1),
'edgewise_spline_weighting_backward_kernel',
_edgewise_spline_weighting_backward_kernel,
Dtype=Dtype(input),
num_threads=num_threads,
M_in=M_in,
M_out=M_out,
k_max=k_max,
K=K)
f(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1), grid=(get_blocks(num_threads), 1, 1),
args=[ args=[
grad_output.data_ptr(), grad_output.data_ptr(),
...@@ -166,7 +167,8 @@ class EdgewiseSplineWeightingGPU(Function): ...@@ -166,7 +167,8 @@ class EdgewiseSplineWeightingGPU(Function):
input.data_ptr(), input.data_ptr(),
weight.data_ptr(), weight.data_ptr(),
self.amount.data_ptr(), self.amount.data_ptr(),
self.index.data_ptr() self.index.data_ptr(),
num_threads
], ],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.