THCWeighting.cu 4.26 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include "THCWeighting.h"

#include "common.cuh"
#include "THCNumerics.cuh"
rusty1s's avatar
rusty1s committed
5
#include "THCAtomics.cuh"
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

// Forward pass of the spline weighting.
//
// Work item i (driven over [0, n) by KERNEL_LOOP from common.cuh) is split into
//   e    = i / self.size[1]
//   mOut = i % self.size[1]
// and writes a single output element:
//
//   self[e, mOut] = sum_s sum_mIn weight[wi, mIn, mOut] * src[e, mIn] * basis[e, s],
//   with wi = weightIndex[e, s].
//
// Each (e, mOut) pair owns its output element, so the store needs no atomics
// (assuming KERNEL_LOOP visits each i exactly once).
// NOTE(review): n is presumably self.size[0] * self.size[1] — confirm at the launch site.
template<typename T>
__global__ void weightingForwardKernel(TensorInfo<T> self, TensorInfo<T> src, TensorInfo<T> weight,
                                       TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
                                       int n) {
  KERNEL_LOOP(i, n) {
    const ptrdiff_t e = i / self.size[1];
    const ptrdiff_t mOut = i % self.size[1];
    T v = ScalarConvert<int, T>::to(0);
    for (ptrdiff_t s = 0; s < basis.size[1]; s++) {
      const T b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
      const int64_t wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
      for (ptrdiff_t mIn = 0; mIn < src.size[1]; mIn++) {
        // Keep the (weight * src) * b product order so results are unchanged.
        T tmp = weight.data[wi * weight.stride[0] + mIn * weight.stride[1] + mOut * weight.stride[2]];
        tmp = THCNumerics<T>::mul(tmp, src.data[e * src.stride[0] + mIn * src.stride[1]]);
        tmp = THCNumerics<T>::mul(tmp, b);
        v = THCNumerics<T>::add(v, tmp);
      }
    }
    self.data[e * self.stride[0] + mOut * self.stride[1]] = v;
  }
}

rusty1s's avatar
rusty1s committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// Gradient of the spline weighting with respect to the input features (src).
//
// Work item i (driven over [0, n) by KERNEL_LOOP) maps to
//   edge = i / self.size[1],  inCh = i % self.size[1]
// and writes one element:
//
//   self[edge, inCh] = sum_k sum_outCh weight[kIdx, outCh, inCh] * gradOutput[edge, outCh] * basis[edge, k],
//   with kIdx = weightIndex[edge, k].
//
// NOTE(review): weight dim 1 is walked with the output channel and dim 2 with the input
// channel, which is the transpose of what weightingForwardKernel reads. Presumably the caller
// passes a transposed view of the weight; confirm against generic/THCWeighting.cu.
template<typename T>
__global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
                                           TensorInfo<T> weight, TensorInfo<T> basis,
                                           TensorInfo<int64_t> weightIndex, int n) {
  KERNEL_LOOP(i, n) {
    const ptrdiff_t edge = i / self.size[1];
    const ptrdiff_t inCh = i % self.size[1];
    const ptrdiff_t numOut = gradOutput.size[1];
    const ptrdiff_t gradBase = edge * gradOutput.stride[0];

    T acc = ScalarConvert<int, T>::to(0);
    for (ptrdiff_t k = 0; k < basis.size[1]; ++k) {
      const T coeff = basis.data[edge * basis.stride[0] + k * basis.stride[1]];
      const int64_t kIdx = weightIndex.data[edge * weightIndex.stride[0] + k * weightIndex.stride[1]];
      const ptrdiff_t weightBase = kIdx * weight.stride[0] + inCh * weight.stride[2];

      for (ptrdiff_t outCh = 0; outCh < numOut; ++outCh) {
        // Same operation order as before: (weight * grad) * coeff, then accumulate.
        const T term = THCNumerics<T>::mul(weight.data[weightBase + outCh * weight.stride[1]],
                                           gradOutput.data[gradBase + outCh * gradOutput.stride[1]]);
        acc = THCNumerics<T>::add(acc, THCNumerics<T>::mul(term, coeff));
      }
    }
    self.data[edge * self.stride[0] + inCh * self.stride[1]] = acc;
  }
}

rusty1s's avatar
rusty1s committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Gradient of the spline weighting with respect to the weight tensor.
//
// Work item i (driven over [0, n) by KERNEL_LOOP) maps to
//   edge = i / gradOutput.size[1],  outCh = i % gradOutput.size[1]
// and scatters, for every k and input channel inCh,
//
//   self[kIdx, inCh, outCh] += src[edge, inCh] * basis[edge, k] * gradOutput[edge, outCh],
//   with kIdx = weightIndex[edge, k].
//
// Different edges can select the same kIdx, so several work items may hit the same element
// of self. That is why the update uses atomicAdd. Because it only adds into self,
// the caller presumably zero-initialises self — TODO confirm.
template<typename T>
__global__ void weightingBackwardWeightKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
                                              TensorInfo<T> src, TensorInfo<T> basis,
                                              TensorInfo<int64_t> weightIndex, int n) {
  KERNEL_LOOP(i, n) {
    const ptrdiff_t edge = i / gradOutput.size[1];
    const ptrdiff_t outCh = i % gradOutput.size[1];
    const T grad = gradOutput.data[edge * gradOutput.stride[0] + outCh * gradOutput.stride[1]];
    const ptrdiff_t srcBase = edge * src.stride[0];

    for (ptrdiff_t k = 0; k < weightIndex.size[1]; ++k) {
      const T coeff = basis.data[edge * basis.stride[0] + k * basis.stride[1]];
      const int64_t kIdx = weightIndex.data[edge * weightIndex.stride[0] + k * weightIndex.stride[1]];
      const ptrdiff_t dstBase = kIdx * self.stride[0] + outCh * self.stride[2];

      for (ptrdiff_t inCh = 0; inCh < src.size[1]; ++inCh) {
        // Same operation order as before: (src * coeff) * grad.
        const T scaled = THCNumerics<T>::mul(src.data[srcBase + inCh * src.stride[1]], coeff);
        atomicAdd(&self.data[dstBase + inCh * self.stride[1]], THCNumerics<T>::mul(scaled, grad));
      }
    }
  }
}

rusty1s's avatar
rusty1s committed
73
74
75
76
77
78
// Gradient of the spline weighting with respect to the basis coefficients.
//
// Work item i (driven over [0, n) by KERNEL_LOOP) maps to
//   e    = i / gradOutput.size[1]
//   mOut = i % gradOutput.size[1]
// and, for every s, contributes
//
//   self[e, s] += gradOutput[e, mOut] * sum_mIn weight[wi, mIn, mOut] * src[e, mIn],
//   with wi = weightIndex[e, s].
//
// All mOut work items of the same e add into the same self[e, s], hence atomicAdd.
// Because the kernel only adds into self, the caller presumably zero-initialises it — TODO confirm.
template<typename T>
__global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
                                             TensorInfo<T> src, TensorInfo<T> weight,
                                             TensorInfo<int64_t> weightIndex, int n) {
  KERNEL_LOOP(i, n) {
    const ptrdiff_t e = i / gradOutput.size[1];
    const ptrdiff_t mOut = i % gradOutput.size[1];
    const T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
    for (ptrdiff_t s = 0; s < weightIndex.size[1]; s++) {
      const int64_t wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
      // Per-s partial sum, reset for every s before being pushed atomically.
      T v = ScalarConvert<int, T>::to(0);
      for (ptrdiff_t mIn = 0; mIn < src.size[1]; mIn++) {
        // Keep the (weight * src) * g product order so results are unchanged.
        T tmp = weight.data[wi * weight.stride[0] + mIn * weight.stride[1] + mOut * weight.stride[2]];
        tmp = THCNumerics<T>::mul(tmp, src.data[e * src.stride[0] + mIn * src.stride[1]]);
        tmp = THCNumerics<T>::mul(tmp, g);
        v = THCNumerics<T>::add(v, tmp);
      }
      atomicAdd(&self.data[e * self.stride[0] + s * self.stride[1]], v);
    }
  }
}

rusty1s's avatar
rusty1s committed
96
97
#include "generic/THCWeighting.cu"
#include "THC/THCGenerateFloatTypes.h"