Commit e1fcf1d2 authored by rusty1s's avatar rusty1s
Browse files

cuda function calls working

parent 00ec0037
......@@ -30,3 +30,17 @@ def test_spline_basis_cpu(tensor, i):
assert basis == expected_basis.view(-1).tolist()
assert index.tolist() == expected_index.tolist()
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
def test_spline_basis_gpu(tensor, i):
degree = data[i].get('degree')
pseudo = Tensor(tensor, data[i]['pseudo']).cuda()
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
kernel_size = torch.cuda.LongTensor(data[i]['kernel_size'])
is_open_spline = torch.cuda.ByteTensor(data[i]['is_open_spline'])
K = kernel_size.prod()
basis, index = spline_basis_forward(degree, pseudo, kernel_size,
is_open_spline, K)
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else
void spline_(linear_basis_forward)(THCState *state, THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
printf("linear");
}
void spline_(quadratic_basis_forward)(THCState *state, THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
printf("quadratic");
}
void spline_(cubic_basis_forward)(THCState *state, THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
printf("cubic");
}
#endif
#include <THC.h>
#include "kernel.h"
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _kernel_, Real)
#include "generic/kernel.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
#include "THCGenerateDoubleType.h"
#ifdef __cplusplus
extern "C" {
#endif
void spline_linear_basis_forward_kernel_Float (THCState *state, THCudaTensor *basis, THCudaLongTensor *weight_index, THCudaTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_linear_basis_forward_kernel_Double(THCState *state, THCudaDoubleTensor *basis, THCudaLongTensor *weight_index, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_kernel_Float (THCState *state, THCudaTensor *basis, THCudaLongTensor *weight_index, THCudaTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_kernel_Double(THCState *state, THCudaDoubleTensor *basis, THCudaLongTensor *weight_index, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_kernel_Float (THCState *state, THCudaTensor *basis, THCudaLongTensor *weight_index, THCudaTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_kernel_Double(THCState *state, THCudaDoubleTensor *basis, THCudaLongTensor *weight_index, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
#ifdef __cplusplus
}
#endif
#include <THC/THC.h>
#include "kernel.h"
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _cuda_, Real)
#define spline_kernel_(NAME) TH_CONCAT_4(spline_, NAME, _kernel_, Real)
extern THCState *state;
#include "generic/cuda.c"
#include "THCGenerateFloatType.h"
#include "generic/cuda.c"
......
......@@ -3,12 +3,15 @@
#else
void spline_(linear_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
spline_kernel_(linear_basis_forward)(state, basis, weight_index, pseudo, kernel_size, is_open_spline, K);
}
void spline_(quadratic_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
spline_kernel_(quadratic_basis_forward)(state, basis, weight_index, pseudo, kernel_size, is_open_spline, K);
}
void spline_(cubic_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
spline_kernel_(cubic_basis_forward)(state, basis, weight_index, pseudo, kernel_size, is_open_spline, K);
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment