/*
 * THWeighting.c (4.77 KB)
 * Last committed by rusty1s.
 */
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THWeighting.c"
#else

/*
 * Forward pass of the weighting operation.
 *
 * For every row e and output channel mOut this computes
 *
 *   self[e][mOut] = sum_s sum_mIn weight[weightIndex[e][s]][mIn][mOut]
 *                                 * (basis[e][s] * src[e][mIn])
 *
 * and stores (overwrites) the result in self.
 *
 * All tensors are addressed through their strides, so they need not be
 * contiguous. The indexing fixes the ranks: self and src are 2-D,
 * weight is 3-D, basis and weightIndex are 2-D.
 * Loop extents: rows = size(src, 0), mOut < size(self, 1),
 * s < size(basis, 1), mIn < size(src, 1).
 *
 * NOTE(review): weightIndex must have at least size(basis, 1) columns, and
 * each entry is used unchecked as the first index into weight. Callers must
 * guarantee 0 <= weightIndex[e][s] < size(weight, 0).
 */
void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight, THTensor *basis,
                                  THLongTensor *weightIndex) {
  real *selfData = THTensor_(data)(self);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  // Loop bounds never change inside the loops; query them once.
  const ptrdiff_t numRows = THTensor_(size)(src, 0);
  const ptrdiff_t numOut = THTensor_(size)(self, 1);
  const ptrdiff_t numBasis = THTensor_(size)(basis, 1);
  const ptrdiff_t numIn = THTensor_(size)(src, 1);

  for (ptrdiff_t e = 0; e < numRows; e++) {
    for (ptrdiff_t mOut = 0; mOut < numOut; mOut++) {
      real v = 0;
      for (ptrdiff_t s = 0; s < numBasis; s++) {
        real b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (ptrdiff_t mIn = 0; mIn < numIn; mIn++) {
          real w = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          // Same grouping as before: w * (b * x), accumulated in (s, mIn) order.
          v += w * (b * srcData[e * src->stride[0] + mIn * src->stride[1]]);
        }
      }
      selfData[e * self->stride[0] + mOut * self->stride[1]] = v;
    }
  }
}

/*
 * Gradient of the weighting operation with respect to src.
 *
 * self is zero-filled first. Then, for every row e, output channel mOut,
 * basis slot s and input channel mIn:
 *
 *   self[e][mIn] += (gradOutput[e][mOut] * basis[e][s])
 *                   * weight[weightIndex[e][s]][mIn][mOut]
 *
 * All tensors are addressed through their strides. Loop extents:
 * rows = size(self, 0), mOut < size(gradOutput, 1), s < size(basis, 1),
 * mIn < size(self, 1).
 *
 * NOTE(review): entries of weightIndex are used unchecked as the first index
 * into weight. Callers must keep them in [0, size(weight, 0)).
 */
void THTensor_(weightingBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight,
                                      THTensor *basis, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  // Loop bounds never change inside the loops; query them once.
  const ptrdiff_t numRows = THTensor_(size)(self, 0);
  const ptrdiff_t numOut = THTensor_(size)(gradOutput, 1);
  const ptrdiff_t numBasis = THTensor_(size)(basis, 1);
  const ptrdiff_t numIn = THTensor_(size)(self, 1);

  for (ptrdiff_t e = 0; e < numRows; e++) {
    for (ptrdiff_t mOut = 0; mOut < numOut; mOut++) {
      real g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (ptrdiff_t s = 0; s < numBasis; s++) {
        real b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (ptrdiff_t mIn = 0; mIn < numIn; mIn++) {
          real w = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          selfData[e * self->stride[0] + mIn * self->stride[1]] += g * b * w;
        }
      }
    }
  }
}

/*
 * Gradient of the weighting operation with respect to weight.
 *
 * self is zero-filled first. Then, for every row e, output channel mOut,
 * basis slot s and input channel mIn:
 *
 *   self[weightIndex[e][s]][mIn][mOut] += (basis[e][s] * gradOutput[e][mOut])
 *                                         * src[e][mIn]
 *
 * Several (e, s) pairs may select the same weightIndex entry; their
 * contributions are summed into the same slice of self.
 * All tensors are addressed through their strides. Loop extents:
 * rows = size(src, 0), mOut < size(gradOutput, 1), s < size(basis, 1),
 * mIn < size(src, 1).
 *
 * NOTE(review): entries of weightIndex are used unchecked as the first index
 * into self. Callers must keep them in [0, size(self, 0)).
 */
void THTensor_(weightingBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                         THTensor *basis, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  // Loop bounds never change inside the loops; query them once.
  const ptrdiff_t numRows = THTensor_(size)(src, 0);
  const ptrdiff_t numOut = THTensor_(size)(gradOutput, 1);
  const ptrdiff_t numBasis = THTensor_(size)(basis, 1);
  const ptrdiff_t numIn = THTensor_(size)(src, 1);

  for (ptrdiff_t e = 0; e < numRows; e++) {
    for (ptrdiff_t mOut = 0; mOut < numOut; mOut++) {
      real g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (ptrdiff_t s = 0; s < numBasis; s++) {
        real b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (ptrdiff_t mIn = 0; mIn < numIn; mIn++) {
          real contrib = b * g * srcData[e * src->stride[0] + mIn * src->stride[1]];
          selfData[wi * self->stride[0] + mIn * self->stride[1] + mOut * self->stride[2]] += contrib;
        }
      }
    }
  }
}

/*
 * Gradient of the weighting operation with respect to basis.
 *
 * self is zero-filled first. Then, for every row e, output channel mOut and
 * slot s:
 *
 *   dot = sum_mIn weight[weightIndex[e][s]][mIn][mOut] * src[e][mIn]
 *   self[e][s] += gradOutput[e][mOut] * dot
 *
 * All tensors are addressed through their strides. Loop extents:
 * rows = size(src, 0), mOut < size(gradOutput, 1), s < size(weightIndex, 1),
 * mIn < size(src, 1).
 *
 * NOTE(review): entries of weightIndex are used unchecked as the first index
 * into weight. Callers must keep them in [0, size(weight, 0)).
 */
void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                        THTensor *weight, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  // Loop bounds never change inside the loops; query them once.
  const ptrdiff_t numRows = THTensor_(size)(src, 0);
  const ptrdiff_t numOut = THTensor_(size)(gradOutput, 1);
  const ptrdiff_t numSlots = THLongTensor_size(weightIndex, 1);
  const ptrdiff_t numIn = THTensor_(size)(src, 1);

  for (ptrdiff_t e = 0; e < numRows; e++) {
    for (ptrdiff_t mOut = 0; mOut < numOut; mOut++) {
      real g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (ptrdiff_t s = 0; s < numSlots; s++) {
        int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        // Dot product of the selected weight column with this row of src.
        real dot = 0;
        for (ptrdiff_t mIn = 0; mIn < numIn; mIn++) {
          real w = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          dot += w * srcData[e * src->stride[0] + mIn * src->stride[1]];
        }
        selfData[e * self->stride[0] + s * self->stride[1]] += g * dot;
      }
    }
  }
}

#endif // TH_GENERIC_FILE