// kernel.cu — committed by rusty1s
#include <THC.h>

#include "kernel.h"

#include "common.cuh"
#include "THCBasisForward.cuh"
#include "THCAtomics.cuh"

// spline_(NAME) token-pastes to spline_<NAME>_kernel_<Real>, where Real is the scalar-name
// macro bound by the THCGenerate*Type.h headers while a generic file is being included
// (presumably Float, Double, ... — confirm against the THC headers in use).
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _kernel_, Real)

// thc_(NAME) token-pastes to thc_<NAME>_<Real>, following the same Real convention as above.
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)

// Instantiate the generic helpers in generic/common.cu once for every type that
// THCGenerateAllTypes.h enumerates.
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"

/*
 * weightingForwardKernel: spline-weighted linear map, one loop iteration per output element.
 *
 * For flat index i, let M_out = output.size[1], e = i / M_out, m_out = i % M_out. The kernel writes
 *
 *   output[i] = sum_{s < S} basis[e, s] * sum_{m_in < M_in} weight[k(e, s), m_in, m_out] * input[e, m_in]
 *
 * with S = basis.size[1], M_in = input.size[1] and k(e, s) = weightIndex[e, s].
 *
 * Layout assumptions visible in the indexing below:
 *  - input is read through its strides (stride[0], stride[1]);
 *  - weight is addressed as a contiguous [K, M_in, M_out] block;
 *  - basis and weightIndex are addressed as contiguous rows of length S;
 *  - output is written at flat index i, i.e. it is assumed contiguous (TODO confirm at call site).
 * KERNEL_LOOP comes from common.cuh; presumably it iterates i over [0, n).
 */
template<typename Real>
__global__ void weightingForwardKernel(TensorInfo<Real> output, TensorInfo<Real> input, TensorInfo<Real> weight, TensorInfo<Real> basis, TensorInfo<int64_t> weightIndex, int n) {
  KERNEL_LOOP(i, n) {
    const int64_t M_out = output.size[1];
    const int64_t M_in = input.size[1];
    const int64_t S = basis.size[1];
    const int64_t edgeOffset = i / M_out;
    const int64_t m_out = i % M_out;
    const int64_t inputOffset = edgeOffset * input.stride[0];

    Real value = 0;
    for (int64_t s = 0; s < S; s++) {
      const int64_t weightOffset = weightIndex.data[edgeOffset * S + s] * M_in * M_out + m_out;

      // Inner product of input row e with column m_out of weight slice k(e, s).
      Real partial = 0;
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        partial += weight.data[weightOffset + m_in * M_out] * input.data[inputOffset + m_in * input.stride[1]];
      }

      // Scale only this s-term by its own basis coefficient. Multiplying the running total
      // would also rescale the contributions of all previous s by basis[e, s].
      value += basis.data[edgeOffset * S + s] * partial;
    }
    output.data[i] = value;
  }
}

/*
 * weightingBackwardInputKernel: gradient of the spline weighting with respect to its input.
 *
 * For flat index i, let M_in = gradInput.size[1], e = i / M_in, m_in = i % M_in. The kernel writes
 *
 *   gradInput[i] = sum_{s < S} basis[e, s] * sum_{m_out < M_out} W(k(e, s), m_in, m_out) * gradOutput[e, m_out]
 *
 * with S = basis.size[1], M_out = gradOutput.size[1], k(e, s) = weightIndex[e, s], and W read
 * from `weight` at offset k * M_in * M_out + m_out * M_in + m_in.
 *
 * Layout assumptions visible in the indexing below:
 *  - gradOutput is read through its strides (stride[0], stride[1]);
 *  - basis and weightIndex are addressed as contiguous rows of length S;
 *  - gradInput is written at flat index i, i.e. it is assumed contiguous (TODO confirm).
 * KERNEL_LOOP comes from common.cuh; presumably it iterates i over [0, n).
 */
template<typename Real>
__global__ void weightingBackwardInputKernel(TensorInfo<Real> gradInput, TensorInfo<Real> gradOutput, TensorInfo<Real> weight, TensorInfo<Real> basis, TensorInfo<int64_t> weightIndex, int n) {
  KERNEL_LOOP(i, n) {
    const int64_t M_in = gradInput.size[1];
    const int64_t M_out = gradOutput.size[1];
    const int64_t S = basis.size[1];
    const int64_t edgeOffset = i / M_in;
    const int64_t m_in = i % M_in;
    const int64_t gradOutputOffset = edgeOffset * gradOutput.stride[0];

    Real value = 0;
    for (int64_t s = 0; s < S; s++) {
      // NOTE(review): weight is indexed as [k][m_out][m_in] here, whereas the forward kernel uses
      // [k][m_in][m_out]. Presumably the caller passes a transposed (contiguous) weight; confirm
      // in generic/kernel.cu.
      const int64_t weightOffset = weightIndex.data[edgeOffset * S + s] * M_in * M_out + m_in;

      Real partial = 0;
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        // Use stride[1] as well as stride[0], so a non-contiguous gradOutput is read correctly.
        partial += weight.data[weightOffset + m_out * M_in] * gradOutput.data[gradOutputOffset + m_out * gradOutput.stride[1]];
      }

      // Each s-term is scaled only by its own basis coefficient.
      value += basis.data[edgeOffset * S + s] * partial;
    }
    gradInput.data[i] = value;
  }
}

/*
 * weightingBackwardWeightKernel: gradient of the spline weighting with respect to its weight.
 *
 * For flat index i, let M_out = gradOutput.size[1], e = i / M_out, m_out = i % M_out. For every
 * s < S = basis.size[1] and m_in < M_in = input.size[1], the kernel atomically adds
 *
 *   basis[e, s] * gradOutput[e, m_out] * input[e, m_in]
 *
 * into gradWeight at offset k(e, s) * M_in * M_out + m_in * M_out + m_out, with k(e, s) = weightIndex[e, s].
 *
 * Different (e, s) pairs can select the same weight slice k, so the writes go through atomicAdd.
 * gradWeight is only accumulated into here; presumably the caller zero-fills it first (confirm).
 * For double, atomicAdd needs SM60+ or an overload; presumably THCAtomics.cuh provides one (confirm).
 * gradOutput and input are read through their strides; basis and weightIndex are addressed as
 * contiguous rows of length S.
 * KERNEL_LOOP comes from common.cuh; presumably it iterates i over [0, n).
 */
template<typename Real>
__global__ void weightingBackwardWeightKernel(TensorInfo<Real> gradWeight, TensorInfo<Real> gradOutput, TensorInfo<Real> input, TensorInfo<Real> basis, TensorInfo<int64_t> weightIndex, int n) {
  KERNEL_LOOP(i, n) {
    const int64_t M_in = input.size[1];
    const int64_t M_out = gradOutput.size[1];
    const int64_t S = basis.size[1];
    const int64_t edgeOffset = i / M_out;
    const int64_t m_out = i % M_out;
    const int64_t inputOffset = edgeOffset * input.stride[0];

    // Read gradOutput through its strides (as input is), rather than assuming a contiguous
    // [E, M_out] layout. An expanded or non-contiguous gradient would otherwise be mis-indexed.
    const Real g = gradOutput.data[edgeOffset * gradOutput.stride[0] + m_out * gradOutput.stride[1]];

    for (int64_t s = 0; s < S; s++) {
      // Loop-invariant factor for the inner loop; evaluated as (b * g) * input, matching
      // the left-to-right order of the product.
      const Real bg = basis.data[edgeOffset * S + s] * g;
      const int64_t weightOffset = weightIndex.data[edgeOffset * S + s] * M_in * M_out + m_out;
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        atomicAdd(&gradWeight.data[weightOffset + m_in * M_out], bg * input.data[inputOffset + m_in * input.stride[1]]);
      }
    }
  }
}

// Instantiate the type-specific code in generic/kernel.cu for float and double only.
// NOTE(review): this follows the THC generic-file pattern. Presumably generic/kernel.cu defines
// THC_GENERIC_FILE on its first inclusion, and each THCGenerate*Type.h header then re-includes it
// with Real/real bound to that scalar type — confirm against generic/kernel.cu.
#include "generic/kernel.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
#include "THCGenerateDoubleType.h"