dim_apply.h 8.51 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#pragma once

#include <cstddef>
#include <cstdint>
#include <vector>

#include <torch/torch.h>

/**
 * DIM_APPLY3 -- run CODE once for every 1-D slice along dimension DIM of three
 * tensors that are walked in lock-step.
 *
 * Inside CODE the following names describe the current slice (<T> stands for
 * the identifier passed as TENSOR1, TENSOR2 or TENSOR3):
 *   <T>_data    pointer (TYPEn *) to the element at index 0 along DIM
 *   <T>_size    <T>.size(DIM)
 *   <T>_stride  <T>.stride(DIM)
 *
 * The iteration space (number of dimensions and the extent of every dimension
 * other than DIM) is taken from TENSOR1 only. TENSOR2 and TENSOR3 are advanced
 * with their own strides and are never bounds-checked here.
 *
 * DIM may be a non-negative index or a negative index counted from the last
 * dimension. If TENSOR1 has no elements, CODE is not executed at all.
 *
 * Usage notes:
 *   - The expansion is a sequence of statements that is NOT wrapped in its own
 *     scope: besides the names above it declares `dims`, `counter` and
 *     `has_finished` in the enclosing scope, so it can be expanded at most
 *     once per scope.
 *   - TENSORn must be plain identifiers (they are token-pasted), and every
 *     argument is evaluated several times.
 *   - CODE is pasted directly into the body of a `while` loop: a top-level
 *     `break` in CODE ends the traversal, and a top-level `continue` would
 *     skip the pointer advance.
 */
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE)  \
  TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>();                               \
  auto TENSOR1##_size = TENSOR1.size(DIM);                                     \
  auto TENSOR1##_stride = TENSOR1.stride(DIM);                                 \
                                                                               \
  TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>();                               \
  auto TENSOR2##_size = TENSOR2.size(DIM);                                     \
  auto TENSOR2##_stride = TENSOR2.stride(DIM);                                 \
                                                                               \
  TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>();                               \
  auto TENSOR3##_size = TENSOR3.size(DIM);                                     \
  auto TENSOR3##_stride = TENSOR3.stride(DIM);                                 \
                                                                               \
  auto dims = TENSOR1.dim();                                                   \
  /* Position of the current slice in every dimension other than DIM. */       \
  std::vector<int64_t> counter(static_cast<std::size_t>(dims), 0);             \
  /* An empty tensor has no slice to visit; never run CODE on it. */           \
  bool has_finished = TENSOR1.numel() == 0;                                    \
                                                                               \
  while (!has_finished) {                                                      \
    CODE;                                                                      \
    /* A 1-D tensor is a single slice. */                                      \
    if (dims == 1)                                                             \
      break;                                                                   \
                                                                               \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                     \
      /* DIM may also be negative, i.e. counted from the last dimension. */    \
      if (cur_dim == (DIM) || cur_dim == (DIM) + dims) {                       \
        if (cur_dim == dims - 1) {                                             \
          has_finished = true;                                                 \
          break;                                                               \
        }                                                                      \
        continue;                                                              \
      }                                                                        \
                                                                               \
      /* Step every tensor to the next index along cur_dim. */                 \
      counter[cur_dim]++;                                                      \
      TENSOR1##_data += TENSOR1.stride(cur_dim);                               \
      TENSOR2##_data += TENSOR2.stride(cur_dim);                               \
      TENSOR3##_data += TENSOR3.stride(cur_dim);                               \
                                                                               \
      if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                         \
        if (cur_dim == dims - 1) {                                             \
          has_finished = true;                                                 \
          break;                                                               \
        } else {                                                               \
          /* Rewind this dimension and carry into the next one. */             \
          TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);        \
          TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);        \
          TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);        \
          counter[cur_dim] = 0;                                                \
        }                                                                      \
      } else                                                                   \
        break;                                                                 \
    }                                                                          \
  }

/**
 * DIM_APPLY4 -- run CODE once for every 1-D slice along dimension DIM of four
 * tensors that are walked in lock-step.
 *
 * Inside CODE the following names describe the current slice (<T> stands for
 * the identifier passed as TENSOR1, TENSOR2, TENSOR3 or TENSOR4):
 *   <T>_data    pointer (TYPEn *) to the element at index 0 along DIM
 *   <T>_size    <T>.size(DIM)
 *   <T>_stride  <T>.stride(DIM)
 *
 * The iteration space (number of dimensions and the extent of every dimension
 * other than DIM) is taken from TENSOR1 only. TENSOR2, TENSOR3 and TENSOR4 are
 * advanced with their own strides and are never bounds-checked here.
 *
 * DIM may be a non-negative index or a negative index counted from the last
 * dimension. If TENSOR1 has no elements, CODE is not executed at all.
 *
 * Usage notes:
 *   - The expansion is a sequence of statements that is NOT wrapped in its own
 *     scope: besides the names above it declares `dims`, `counter` and
 *     `has_finished` in the enclosing scope, so it can be expanded at most
 *     once per scope.
 *   - TENSORn must be plain identifiers (they are token-pasted), and every
 *     argument is evaluated several times.
 *   - CODE is pasted directly into the body of a `while` loop: a top-level
 *     `break` in CODE ends the traversal, and a top-level `continue` would
 *     skip the pointer advance.
 */
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4,      \
                   TENSOR4, DIM, CODE)                                         \
  TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>();                               \
  auto TENSOR1##_size = TENSOR1.size(DIM);                                     \
  auto TENSOR1##_stride = TENSOR1.stride(DIM);                                 \
                                                                               \
  TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>();                               \
  auto TENSOR2##_size = TENSOR2.size(DIM);                                     \
  auto TENSOR2##_stride = TENSOR2.stride(DIM);                                 \
                                                                               \
  TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>();                               \
  auto TENSOR3##_size = TENSOR3.size(DIM);                                     \
  auto TENSOR3##_stride = TENSOR3.stride(DIM);                                 \
                                                                               \
  TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>();                               \
  auto TENSOR4##_size = TENSOR4.size(DIM);                                     \
  auto TENSOR4##_stride = TENSOR4.stride(DIM);                                 \
                                                                               \
  auto dims = TENSOR1.dim();                                                   \
  /* Position of the current slice in every dimension other than DIM. */       \
  std::vector<int64_t> counter(static_cast<std::size_t>(dims), 0);             \
  /* An empty tensor has no slice to visit; never run CODE on it. */           \
  bool has_finished = TENSOR1.numel() == 0;                                    \
                                                                               \
  while (!has_finished) {                                                      \
    CODE;                                                                      \
    /* A 1-D tensor is a single slice. */                                      \
    if (dims == 1)                                                             \
      break;                                                                   \
                                                                               \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                     \
      /* DIM may also be negative, i.e. counted from the last dimension. */    \
      if (cur_dim == (DIM) || cur_dim == (DIM) + dims) {                       \
        if (cur_dim == dims - 1) {                                             \
          has_finished = true;                                                 \
          break;                                                               \
        }                                                                      \
        continue;                                                              \
      }                                                                        \
                                                                               \
      /* Step every tensor to the next index along cur_dim. */                 \
      counter[cur_dim]++;                                                      \
      TENSOR1##_data += TENSOR1.stride(cur_dim);                               \
      TENSOR2##_data += TENSOR2.stride(cur_dim);                               \
      TENSOR3##_data += TENSOR3.stride(cur_dim);                               \
      TENSOR4##_data += TENSOR4.stride(cur_dim);                               \
                                                                               \
      if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                         \
        if (cur_dim == dims - 1) {                                             \
          has_finished = true;                                                 \
          break;                                                               \
        } else {                                                               \
          /* Rewind this dimension and carry into the next one. */             \
          TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);        \
          TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);        \
          TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);        \
          TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim);        \
          counter[cur_dim] = 0;                                                \
        }                                                                      \
      } else                                                                   \
        break;                                                                 \
    }                                                                          \
  }