THBasis.c 3.13 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THBasis.c"
#else

rusty1s's avatar
rusty1s committed
5
6
7
8
/*
 * Linear spline basis piece selected by kMod.
 *
 * Evaluates 1 - v - kMod + 2 * v * kMod, which is 1 - v for kMod == 0 and
 * v for kMod == 1. v is presumably the fractional offset of the pseudo
 * coordinate within its knot interval, i.e. in [0, 1) -- TODO confirm
 * against the TH_TENSOR_BASIS_* macros.
 *
 * Declared `static inline`: a plain C99 `inline` definition provides no
 * external definition, so a call the compiler decides not to inline (e.g.
 * at -O0) would otherwise fail to link.
 */
static inline real THTensor_(linear)(real v, int64_t kMod) {
  return 1 - v - kMod + 2 * v * kMod;
}

rusty1s's avatar
rusty1s committed
9
10
11
12
/*
 * Derivative of THTensor_(linear) with respect to v.
 *
 * Returns 2 * kMod - 1: -1 for kMod == 0 and +1 for kMod == 1. The slope of
 * a linear piece is constant, so v does not affect the result.
 *
 * `static inline` for the same linkage reason as the other basis helpers in
 * this file: a plain C99 `inline` definition emits no external symbol.
 */
static inline real THTensor_(gradLinear)(real v, int64_t kMod) {
  (void)v; /* unused: constant slope; kept for a uniform helper signature */
  return 2 * kMod - 1;
}

rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
/*
 * Quadratic spline basis piece selected by kMod.
 *
 *   kMod == 0 : (1 - v)^2 / 2       = 0.5 v^2 - v + 0.5
 *   kMod == 1 : -v^2 + v + 0.5
 *   otherwise : v^2 / 2             (intended for kMod == 2)
 *
 * v is presumably in [0, 1) -- TODO confirm against the caller macros.
 * `static inline`: a plain C99 `inline` definition emits no external symbol,
 * so non-inlined calls would fail to link.
 */
static inline real THTensor_(quadratic)(real v, int64_t kMod) {
  if (kMod == 0) {
    return 0.5 * v * v - v + 0.5;
  } else if (kMod == 1) {
    return -v * v + v + 0.5;
  }
  return 0.5 * v * v;
}

rusty1s's avatar
rusty1s committed
19
20
21
22
23
24
/*
 * Derivative of THTensor_(quadratic) with respect to v.
 *
 *   kMod == 0 : v - 1
 *   kMod == 1 : -2 v + 1
 *   otherwise : v                   (intended for kMod == 2)
 *
 * `static inline`: a plain C99 `inline` definition emits no external symbol,
 * so non-inlined calls would fail to link.
 */
static inline real THTensor_(gradQuadratic)(real v, int64_t kMod) {
  if (kMod == 0) {
    return v - 1;
  } else if (kMod == 1) {
    return -2 * v + 1;
  }
  return v;
}

rusty1s's avatar
rusty1s committed
25
26
27
28
29
30
31
/*
 * Cubic spline basis piece selected by kMod.
 *
 *   kMod == 0 : (1 - v)^3 / 6
 *   kMod == 1 : (3 v^3 - 6 v^2 + 4) / 6
 *   kMod == 2 : (-3 v^3 + 3 v^2 + 3 v + 1) / 6
 *   otherwise : v^3 / 6             (intended for kMod == 3)
 *
 * v is presumably in [0, 1) -- TODO confirm against the caller macros.
 * `static inline`: a plain C99 `inline` definition emits no external symbol,
 * so non-inlined calls would fail to link.
 */
static inline real THTensor_(cubic)(real v, int64_t kMod) {
  if (kMod == 0) {
    /* Stored in a `real` local exactly like the original reassignment of v,
     * so float instantiations round identically. */
    real w = 1 - v;
    return w * w * w / 6.0;
  } else if (kMod == 1) {
    return (3 * v * v * v - 6 * v * v + 4) / 6;
  } else if (kMod == 2) {
    return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
  }
  return v * v * v / 6;
}

rusty1s's avatar
rusty1s committed
32
33
34
35
36
37
38
/*
 * Derivative of THTensor_(cubic) with respect to v.
 *
 *   kMod == 0 : (-v^2 + 2 v - 1) / 2    = -(1 - v)^2 / 2
 *   kMod == 1 : (3 v^2 - 4 v) / 2
 *   kMod == 2 : (-3 v^2 + 2 v + 1) / 2
 *   otherwise : v^2 / 2                 (intended for kMod == 3)
 *
 * `static inline`: a plain C99 `inline` definition emits no external symbol,
 * so non-inlined calls would fail to link.
 */
static inline real THTensor_(gradCubic)(real v, int64_t kMod) {
  if (kMod == 0) {
    return (-v * v + 2 * v - 1) / 2;
  } else if (kMod == 1) {
    return (3 * v * v - 4 * v) / 2;
  } else if (kMod == 2) {
    return (-3 * v * v + 2 * v + 1) / 2;
  }
  return v * v / 2;
}

rusty1s's avatar
rusty1s committed
39
40
41
/*
 * Forward pass for linear (degree 1) spline bases.
 *
 * Thin wrapper around TH_TENSOR_BASIS_FORWARD, which is defined outside this
 * file. It passes degree 1 and THTensor_(linear)(v, kMod) as the per-element
 * basis expression. The macro presumably binds `v` and `kMod` for that
 * expression and fills `basis` / `weightIndex` from `pseudo`, `kernelSize`
 * and `isOpenSpline` -- NOTE(review): confirm against the macro definition.
 */
void THTensor_(linearBasisForward)(THTensor *basis, THLongTensor *weightIndex, THTensor *pseudo,
                                   THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_FORWARD(1, basis, weightIndex, pseudo, kernelSize, isOpenSpline,
                          THTensor_(linear)(v, kMod))
}

/*
 * Forward pass for quadratic (degree 2) spline bases.
 *
 * Thin wrapper around TH_TENSOR_BASIS_FORWARD, which is defined outside this
 * file. It passes degree 2 and THTensor_(quadratic)(v, kMod) as the
 * per-element basis expression. The macro presumably binds `v` and `kMod`
 * and fills `basis` / `weightIndex` -- NOTE(review): confirm against the
 * macro definition.
 */
void THTensor_(quadraticBasisForward)(THTensor *basis, THLongTensor *weightIndex, THTensor *pseudo,
                                      THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_FORWARD(2, basis, weightIndex, pseudo, kernelSize, isOpenSpline,
                          THTensor_(quadratic)(v, kMod))
}

/*
 * Forward pass for cubic (degree 3) spline bases.
 *
 * Thin wrapper around TH_TENSOR_BASIS_FORWARD, which is defined outside this
 * file. It passes degree 3 and THTensor_(cubic)(v, kMod) as the per-element
 * basis expression. The macro presumably binds `v` and `kMod` and fills
 * `basis` / `weightIndex` -- NOTE(review): confirm against the macro
 * definition.
 */
void THTensor_(cubicBasisForward)(THTensor *basis, THLongTensor *weightIndex, THTensor *pseudo,
                                  THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_FORWARD(3, basis, weightIndex, pseudo, kernelSize, isOpenSpline,
                          THTensor_(cubic)(v, kMod))
}

rusty1s's avatar
rusty1s committed
57
58
/*
 * Backward pass for linear (degree 1) spline bases.
 *
 * Thin wrapper around TH_TENSOR_BASIS_BACKWARD, which is defined outside this
 * file. It passes degree 1 together with the basis expression
 * THTensor_(linear)(v, kMod) and its derivative THTensor_(gradLinear)(v, kMod).
 * The macro presumably combines `gradBasis` with these to write the gradient
 * with respect to `pseudo` into `self` -- NOTE(review): confirm against the
 * macro definition.
 */
void THTensor_(linearBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
                                    THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_BACKWARD(1, self, gradBasis, pseudo, kernelSize, isOpenSpline,
                           THTensor_(linear)(v, kMod), THTensor_(gradLinear)(v, kMod))
}

/*
 * Backward pass for quadratic (degree 2) spline bases.
 *
 * Thin wrapper around TH_TENSOR_BASIS_BACKWARD, which is defined outside this
 * file. It passes degree 2 together with THTensor_(quadratic)(v, kMod) and its
 * derivative THTensor_(gradQuadratic)(v, kMod). The result is presumably
 * written into `self` -- NOTE(review): confirm against the macro definition.
 */
void THTensor_(quadraticBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
                                       THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_BACKWARD(2, self, gradBasis, pseudo, kernelSize, isOpenSpline,
                           THTensor_(quadratic)(v, kMod), THTensor_(gradQuadratic)(v, kMod))
}

/*
 * Backward pass for cubic (degree 3) spline bases.
 *
 * Thin wrapper around TH_TENSOR_BASIS_BACKWARD, which is defined outside this
 * file. It passes degree 3 together with THTensor_(cubic)(v, kMod) and its
 * derivative THTensor_(gradCubic)(v, kMod). The result is presumably written
 * into `self` -- NOTE(review): confirm against the macro definition.
 */
void THTensor_(cubicBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
                                   THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_BACKWARD(3, self, gradBasis, pseudo, kernelSize, isOpenSpline,
                           THTensor_(cubic)(v, kMod), THTensor_(gradCubic)(v, kMod))
}

rusty1s's avatar
rusty1s committed
75
#endif // TH_GENERIC_FILE