THCBasis.cu 1.36 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include "THCBasis.h"

rusty1s's avatar
rusty1s committed
3
#include "THCBasisForward.cuh"
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10

// Forward evaluation of the linear B-spline basis.
//
// All per-element work is generated by the THC_TENSOR_BASIS_FORWARD_KERNEL
// macro (included from THCBasisForward.cuh). This kernel only chooses the
// leading macro argument and supplies the snippet that evaluates a single
// basis factor.
//
// Parameters. Roles are inferred from the names; NOTE(review): confirm them
// against the macro definition.
//   basis        - TensorInfo<T>, presumably the output basis values.
//   weightIndex  - TensorInfo<int64_t>, presumably the output weight indices
//                  paired with `basis`.
//   pseudo       - TensorInfo<T>, presumably the input pseudo-coordinates.
//   kernelSize   - raw pointer handed to the macro, presumably one entry per
//                  pseudo dimension. If it is dereferenced in the kernel, it
//                  must point to device-accessible memory.
//   isOpenSpline - raw pointer of per-dimension flags, handed to the macro.
//                  The same device-accessibility caveat applies.
//   n            - presumably the total number of work items the macro
//                  processes.
//
// Launch geometry, thread indexing and any bounds check against `n` are not
// visible here. NOTE(review): these are assumed to live inside the macro;
// confirm before changing launch configuration.
template<typename T>
__global__ void linearBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                         TensorInfo<T> pseudo, int64_t *kernelSize,
                                         uint8_t *isOpenSpline, ptrdiff_t n) {
  // The leading argument (1) presumably selects the spline degree.
  // `v` and `kMod` are not declared in this function. They must come from the
  // macro expansion, into which the statement below is spliced verbatim.
  // Keep this argument list free of anything except that statement.
  THC_TENSOR_BASIS_FORWARD_KERNEL(1, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
    v = BasisForward<T>::linear(v, kMod);
  )
}

// Forward evaluation of the quadratic B-spline basis.
//
// All per-element work is generated by the THC_TENSOR_BASIS_FORWARD_KERNEL
// macro (included from THCBasisForward.cuh). This kernel only chooses the
// leading macro argument and supplies the snippet that evaluates a single
// basis factor.
//
// Parameters. Roles are inferred from the names; NOTE(review): confirm them
// against the macro definition.
//   basis        - TensorInfo<T>, presumably the output basis values.
//   weightIndex  - TensorInfo<int64_t>, presumably the output weight indices
//                  paired with `basis`.
//   pseudo       - TensorInfo<T>, presumably the input pseudo-coordinates.
//   kernelSize   - raw pointer handed to the macro, presumably one entry per
//                  pseudo dimension. If it is dereferenced in the kernel, it
//                  must point to device-accessible memory.
//   isOpenSpline - raw pointer of per-dimension flags, handed to the macro.
//                  The same device-accessibility caveat applies.
//   n            - presumably the total number of work items the macro
//                  processes.
//
// Launch geometry, thread indexing and any bounds check against `n` are not
// visible here. NOTE(review): these are assumed to live inside the macro;
// confirm before changing launch configuration.
template<typename T>
__global__ void quadraticBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                            TensorInfo<T> pseudo, int64_t *kernelSize,
                                            uint8_t *isOpenSpline, ptrdiff_t n) {
  // The leading argument (2) presumably selects the spline degree.
  // `v` and `kMod` are not declared in this function. They must come from the
  // macro expansion, into which the statement below is spliced verbatim.
  // Keep this argument list free of anything except that statement.
  THC_TENSOR_BASIS_FORWARD_KERNEL(2, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
    v = BasisForward<T>::quadratic(v, kMod);
  )
}

// Forward evaluation of the cubic B-spline basis.
//
// All per-element work is generated by the THC_TENSOR_BASIS_FORWARD_KERNEL
// macro (included from THCBasisForward.cuh). This kernel only chooses the
// leading macro argument and supplies the snippet that evaluates a single
// basis factor.
//
// Parameters. Roles are inferred from the names; NOTE(review): confirm them
// against the macro definition.
//   basis        - TensorInfo<T>, presumably the output basis values.
//   weightIndex  - TensorInfo<int64_t>, presumably the output weight indices
//                  paired with `basis`.
//   pseudo       - TensorInfo<T>, presumably the input pseudo-coordinates.
//   kernelSize   - raw pointer handed to the macro, presumably one entry per
//                  pseudo dimension. If it is dereferenced in the kernel, it
//                  must point to device-accessible memory.
//   isOpenSpline - raw pointer of per-dimension flags, handed to the macro.
//                  The same device-accessibility caveat applies.
//   n            - presumably the total number of work items the macro
//                  processes.
//
// Launch geometry, thread indexing and any bounds check against `n` are not
// visible here. NOTE(review): these are assumed to live inside the macro;
// confirm before changing launch configuration.
template<typename T>
__global__ void cubicBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                        TensorInfo<T> pseudo, int64_t *kernelSize,
                                        uint8_t *isOpenSpline, ptrdiff_t n) {
  // The leading argument (3) presumably selects the spline degree.
  // `v` and `kMod` are not declared in this function. They must come from the
  // macro expansion, into which the statement below is spliced verbatim.
  // Keep this argument list free of anything except that statement.
  THC_TENSOR_BASIS_FORWARD_KERNEL(3, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
    v = BasisForward<T>::cubic(v, kMod);
  )
}

#include "generic/THCBasis.cu"
#include "THC/THCGenerateFloatTypes.h"