"examples/pytorch/vscode:/vscode.git/clone" did not exist on "421b05e7b8021f7e3020882254e8ed7b139531c4"
Commit 25700259 authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

Kernels are faster now: kernel compilation/loading moved into the initializer.

parent c16689ca
......@@ -6,6 +6,8 @@ if torch.cuda.is_available():
def edgewise_spline_weighting(input, weight, amount, index):
if input.is_cuda:
return EdgewiseSplineWeightingGPU(amount, index)(input, weight)
K, M_in, M_out = weight.size()
k_max = amount.size(1)
return EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out, k_max)(input, weight)
else:
raise NotImplementedError
......@@ -8,9 +8,9 @@ _edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_forward_kernel(
const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
const ${Dtype}* amount, const long* index) {
const ${Dtype}* amount, const long* index, int num_threads) {
CUDA_KERNEL_LOOP(idx, ${num_threads}) {
CUDA_KERNEL_LOOP(idx, num_threads) {
const int e_idx = idx / ${M_out};
const int m_out_idx = idx % ${M_out};
......@@ -50,9 +50,9 @@ extern "C"
__global__ void edgewise_spline_weighting_backward_kernel(
const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
const ${Dtype}* input, const ${Dtype}* weight, const ${Dtype}* amount,
const long* index) {
const long* index, int num_threads) {
CUDA_KERNEL_LOOP(idx, ${num_threads}) {
CUDA_KERNEL_LOOP(idx, num_threads) {
const int e_idx = idx / ${M_out};
const int m_out_idx = idx % ${M_out};
......@@ -97,40 +97,52 @@ const long* index) {
class EdgewiseSplineWeightingGPU(Function):
def __init__(self, amount, index):
def __init__(self, amount, index, K, M_in, M_out, k_max):
super(EdgewiseSplineWeightingGPU, self).__init__()
assert amount.is_cuda and index.is_cuda
self.amount = amount
self.index = index
self.M_in = M_in
self.M_out = M_out
self.K = K
with torch.cuda.device_of(amount):
self.f_fw = load_kernel(
'edgewise_spline_weighting_forward_kernel',
_edgewise_spline_weighting_forward_kernel,
Dtype=Dtype(amount),
M_in=M_in,
M_out=M_out,
k_max=k_max)
self.f_bw = load_kernel(
'edgewise_spline_weighting_backward_kernel',
_edgewise_spline_weighting_backward_kernel,
Dtype=Dtype(amount),
M_in=M_in,
M_out=M_out,
k_max=k_max,
K=K)
def forward(self, input, weight):
assert input.is_cuda and weight.is_cuda
self.save_for_backward(input, weight)
_, M_in, M_out = weight.size()
k_max = self.amount.size(1)
output = input.new(input.size(0), M_out)
output = input.new(input.size(0), self.M_out)
num_threads = output.numel()
with torch.cuda.device_of(input):
f = load_kernel(
'edgewise_spline_weighting_forward_kernel',
_edgewise_spline_weighting_forward_kernel,
Dtype=Dtype(input),
num_threads=num_threads,
M_in=M_in,
M_out=M_out,
k_max=k_max)
f(block=(cuda_num_threads, 1, 1),
self.f_fw(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
input.data_ptr(),
weight.data_ptr(),
output.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr()
self.index.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
......@@ -139,25 +151,14 @@ class EdgewiseSplineWeightingGPU(Function):
def backward(self, grad_output):
input, weight = self.saved_tensors
K, M_in, M_out = weight.size()
k_max = self.amount.size(1)
grad_input = grad_output.new(input.size(0), M_in).fill_(0)
grad_weight = grad_output.new(K, M_in, M_out).fill_(0)
grad_input = grad_output.new(input.size(0), self.M_in).fill_(0)
grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)
num_threads = grad_output.numel()
with torch.cuda.device_of(grad_output):
f = load_kernel(
'edgewise_spline_weighting_backward_kernel',
_edgewise_spline_weighting_backward_kernel,
Dtype=Dtype(input),
num_threads=num_threads,
M_in=M_in,
M_out=M_out,
k_max=k_max,
K=K)
f(block=(cuda_num_threads, 1, 1),
self.f_bw(block=(cuda_num_threads, 1, 1),
grid=(get_blocks(num_threads), 1, 1),
args=[
grad_output.data_ptr(),
......@@ -166,7 +167,8 @@ class EdgewiseSplineWeightingGPU(Function):
input.data_ptr(),
weight.data_ptr(),
self.amount.data_ptr(),
self.index.data_ptr()
self.index.data_ptr(),
num_threads
],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment