THWeighting.c 1.72 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THWeighting.c"
#else

/*
 * Forward pass of the basis-weighted linear map.
 *
 * For each row e and output channel mOut:
 *
 *   self[e, mOut] = sum_s basis[e, s] *
 *                   sum_mIn weight[weightIndex[e, s], mIn, mOut] * src[e, mIn]
 *
 * All tensors are addressed through their strides.
 *
 *   self        - output, indexed as (E, M_out); every visited element is overwritten.
 *   src         - input, indexed as (E, M_in).
 *   weight      - indexed as (K, M_in, M_out); weightIndex picks the slice along dim 0.
 *   basis       - coefficients, indexed as (E, S).
 *   weightIndex - int64 indices into weight's dim 0, indexed as (E, S).
 *
 * Loop extents: E from src dim 0, M_in / M_out from weight dims 1 / 2, S from basis dim 1.
 * NOTE(review): weightIndex values are not range-checked here; they are assumed to lie in
 * [0, size(weight, 0)) — confirm the caller validates them.
 */
void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight, THTensor *basis,
                                  THLongTensor *weightIndex) {
  real *selfData = THTensor_(data)(self);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  /* Extents are loop-invariant: query them once instead of on every iteration. */
  const ptrdiff_t nE = THTensor_(size)(src, 0);
  const ptrdiff_t nOut = THTensor_(size)(weight, 2);
  const ptrdiff_t nS = THTensor_(size)(basis, 1);
  const ptrdiff_t nIn = THTensor_(size)(weight, 1);

  for (ptrdiff_t e = 0; e < nE; e++) {
    for (ptrdiff_t mOut = 0; mOut < nOut; mOut++) {
      real v = 0;
      for (ptrdiff_t s = 0; s < nS; s++) {
        const real b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        const int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (ptrdiff_t mIn = 0; mIn < nIn; mIn++) {
          /* Same expression and accumulation order as before, so results are bit-identical. */
          v += b * weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]] *
               srcData[e * src->stride[0] + mIn * src->stride[1]];
        }
      }
      selfData[e * self->stride[0] + mOut * self->stride[1]] = v;
    }
  }
}

/*
 * Gradient of weightingForward with respect to src.
 *
 * Differentiating the forward formula
 *   out[e, mOut] = sum_s basis[e, s] * sum_mIn weight[wi(e, s), mIn, mOut] * src[e, mIn]
 * gives, for each row e and input channel mIn:
 *
 *   self[e, mIn] = sum_s basis[e, s] *
 *                  sum_mOut weight[weightIndex[e, s], mIn, mOut] * gradOutput[e, mOut]
 *
 *   self        - receives the gradient; assumed shaped like src, (E, M_in) — TODO confirm
 *                 with the caller. Every element in that range is overwritten.
 *   gradOutput  - gradient w.r.t. the forward output, indexed as (E, M_out).
 *   weight      - indexed as (K, M_in, M_out).
 *   basis       - indexed as (E, S).
 *   weightIndex - int64 indices into weight's dim 0, indexed as (E, S); not range-checked.
 */
void THTensor_(weightingBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight,
                                      THTensor *basis, THLongTensor *weightIndex) {
  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  /* Output extents come from self so the writes below stay inside it. */
  const ptrdiff_t nE = THTensor_(size)(self, 0);
  const ptrdiff_t nIn = THTensor_(size)(self, 1);
  const ptrdiff_t nS = THTensor_(size)(basis, 1);
  const ptrdiff_t nOut = THTensor_(size)(gradOutput, 1);

  for (ptrdiff_t e = 0; e < nE; e++) {
    for (ptrdiff_t mIn = 0; mIn < nIn; mIn++) {
      real v = 0;
      for (ptrdiff_t s = 0; s < nS; s++) {
        const real b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        const int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (ptrdiff_t mOut = 0; mOut < nOut; mOut++) {
          v += b * weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]] *
               gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
        }
      }
      selfData[e * self->stride[0] + mIn * self->stride[1]] = v;
    }
  }
}

/*
 * Gradient of weightingForward with respect to weight.
 *
 * Differentiating the forward formula
 *   out[e, mOut] = sum_s basis[e, s] * sum_mIn weight[wi(e, s), mIn, mOut] * src[e, mIn]
 * gives, summed over every (e, s) with weightIndex[e, s] == k:
 *
 *   self[k, mIn, mOut] = sum basis[e, s] * src[e, mIn] * gradOutput[e, mOut]
 *
 *   self        - receives the gradient; assumed shaped like weight, (K, M_in, M_out) — TODO
 *                 confirm with the caller. It is zeroed first and then accumulated into.
 *   gradOutput  - gradient w.r.t. the forward output, indexed as (E, M_out).
 *   src         - forward input, indexed as (E, M_in).
 *   basis       - indexed as (E, S).
 *   weightIndex - int64 indices into self's dim 0, indexed as (E, S); not range-checked.
 */
void THTensor_(weightingBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                         THTensor *basis, THLongTensor *weightIndex) {
  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  const ptrdiff_t nK = THTensor_(size)(self, 0);
  const ptrdiff_t nIn = THTensor_(size)(self, 1);
  const ptrdiff_t nOut = THTensor_(size)(self, 2);
  const ptrdiff_t nE = THTensor_(size)(src, 0);
  const ptrdiff_t nS = THTensor_(size)(basis, 1);

  /* Several (e, s) pairs may select the same slice k, so accumulate into a zeroed buffer. */
  for (ptrdiff_t k = 0; k < nK; k++) {
    for (ptrdiff_t mIn = 0; mIn < nIn; mIn++) {
      for (ptrdiff_t mOut = 0; mOut < nOut; mOut++) {
        selfData[k * self->stride[0] + mIn * self->stride[1] + mOut * self->stride[2]] = 0;
      }
    }
  }

  for (ptrdiff_t e = 0; e < nE; e++) {
    for (ptrdiff_t s = 0; s < nS; s++) {
      const real b = basisData[e * basis->stride[0] + s * basis->stride[1]];
      const int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
      for (ptrdiff_t mIn = 0; mIn < nIn; mIn++) {
        const real bx = b * srcData[e * src->stride[0] + mIn * src->stride[1]];
        for (ptrdiff_t mOut = 0; mOut < nOut; mOut++) {
          selfData[wi * self->stride[0] + mIn * self->stride[1] + mOut * self->stride[2]] +=
              bx * gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
        }
      }
    }
  }
}

/*
 * Gradient of weightingForward with respect to basis.
 *
 * Differentiating the forward formula
 *   out[e, mOut] = sum_s basis[e, s] * sum_mIn weight[wi(e, s), mIn, mOut] * src[e, mIn]
 * gives, for each row e and basis slot s:
 *
 *   self[e, s] = sum_mOut gradOutput[e, mOut] *
 *                sum_mIn weight[weightIndex[e, s], mIn, mOut] * src[e, mIn]
 *
 *   self        - receives the gradient; assumed shaped like basis, (E, S) — TODO confirm
 *                 with the caller. Every element in that range is overwritten.
 *   gradOutput  - gradient w.r.t. the forward output, indexed as (E, M_out).
 *   src         - forward input, indexed as (E, M_in).
 *   weight      - indexed as (K, M_in, M_out).
 *   weightIndex - int64 indices into weight's dim 0, indexed as (E, S); not range-checked.
 */
void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                        THTensor *weight, THLongTensor *weightIndex) {
  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  /* Output extents come from self so the writes below stay inside it. */
  const ptrdiff_t nE = THTensor_(size)(self, 0);
  const ptrdiff_t nS = THTensor_(size)(self, 1);
  const ptrdiff_t nIn = THTensor_(size)(weight, 1);
  const ptrdiff_t nOut = THTensor_(size)(weight, 2);

  for (ptrdiff_t e = 0; e < nE; e++) {
    for (ptrdiff_t s = 0; s < nS; s++) {
      const int64_t wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
      real v = 0;
      for (ptrdiff_t mOut = 0; mOut < nOut; mOut++) {
        const real g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
        for (ptrdiff_t mIn = 0; mIn < nIn; mIn++) {
          v += g * weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]] *
               srcData[e * src->stride[0] + mIn * src->stride[1]];
        }
      }
      selfData[e * self->stride[0] + s * self->stride[1]] = v;
    }
  }
}

#endif // TH_GENERIC_FILE