dim_apply.h 8.84 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#pragma once

rusty1s's avatar
rusty1s committed
3
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
4
5

// DIM_APPLY3: runs CODE once for every 1-D slice taken along dimension DIM of
// three tensors that are walked in lock-step.
//
//   TYPEn / TENSORn  element type and tensor variable (TENSORn must be a plain
//                    identifier, it is token-pasted into the names below).
//   DIM              dimension to slice along; negative values count from the
//                    last dimension.
//   CODE             statement(s) executed per slice.
//
// Inside CODE the following locals are available for each tensor:
//   <TENSORn>_data    pointer to the first element of the current slice
//   <TENSORn>_size    TENSORn.size(DIM)
//   <TENSORn>_stride  TENSORn.stride(DIM)
//
// The number of dimensions and the extent of every non-DIM dimension are read
// from TENSOR1 only; TENSOR2 and TENSOR3 are advanced with their own strides.
// NOTE(review): this presumes all tensors have the same number of dimensions
// and matching sizes outside DIM, and that their data is host-accessible --
// confirm at the call sites.
//
// The macro expands to an immediately-invoked lambda capturing by reference.
// If any non-DIM dimension of TENSOR1 is empty there is no slice to visit and
// CODE is not executed at all.
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE)  \
  [&] {                                                                        \
    TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>();                             \
    auto TENSOR1##_size = TENSOR1.size(DIM);                                   \
    auto TENSOR1##_stride = TENSOR1.stride(DIM);                               \
                                                                               \
    TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>();                             \
    auto TENSOR2##_size = TENSOR2.size(DIM);                                   \
    auto TENSOR2##_stride = TENSOR2.stride(DIM);                               \
                                                                               \
    TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>();                             \
    auto TENSOR3##_size = TENSOR3.size(DIM);                                   \
    auto TENSOR3##_stride = TENSOR3.stride(DIM);                               \
                                                                               \
    auto dims = TENSOR1.dim();                                                 \
    /* Wrap a negative DIM so it can be compared with cur_dim below. */        \
    int64_t dim_apply_axis = (DIM);                                            \
    if (dim_apply_axis < 0)                                                    \
      dim_apply_axis += dims;                                                  \
                                                                               \
    /* One counter per dimension: the index of the current slice. */           \
    auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong));          \
    auto counter = zeros.data<int64_t>();                                      \
    bool has_finished = false;                                                 \
                                                                               \
    /* With an empty non-DIM dimension there is no slice to visit. */          \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                     \
      if (cur_dim != dim_apply_axis && TENSOR1.size(cur_dim) == 0)             \
        has_finished = true;                                                   \
    }                                                                          \
                                                                               \
    while (!has_finished) {                                                    \
      CODE;                                                                    \
      /* A 1-D tensor consists of exactly one slice. */                        \
      if (dims == 1)                                                           \
        break;                                                                 \
                                                                               \
      /* Advance to the next slice: increment the lowest non-DIM counter */    \
      /* and carry into the following dimension whenever one overflows. */     \
      for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                   \
        if (cur_dim == dim_apply_axis) {                                       \
          /* Reaching a trailing DIM means every other counter wrapped. */     \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          }                                                                    \
          continue;                                                            \
        }                                                                      \
                                                                               \
        counter[cur_dim]++;                                                    \
        TENSOR1##_data += TENSOR1.stride(cur_dim);                             \
        TENSOR2##_data += TENSOR2.stride(cur_dim);                             \
        TENSOR3##_data += TENSOR3.stride(cur_dim);                             \
                                                                               \
        if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                       \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          } else {                                                             \
            /* Rewind this dimension and carry into the next one. */           \
            TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);      \
            TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);      \
            TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);      \
            counter[cur_dim] = 0;                                              \
          }                                                                    \
        } else                                                                 \
          break;                                                               \
      }                                                                        \
    }                                                                          \
  }()
rusty1s's avatar
rusty1s committed
58
59
60

// DIM_APPLY4: runs CODE once for every 1-D slice taken along dimension DIM of
// four tensors that are walked in lock-step.
//
//   TYPEn / TENSORn  element type and tensor variable (TENSORn must be a plain
//                    identifier, it is token-pasted into the names below).
//   DIM              dimension to slice along; negative values count from the
//                    last dimension.
//   CODE             statement(s) executed per slice.
//
// Inside CODE the following locals are available for each tensor:
//   <TENSORn>_data    pointer to the first element of the current slice
//   <TENSORn>_size    TENSORn.size(DIM)
//   <TENSORn>_stride  TENSORn.stride(DIM)
//
// The number of dimensions and the extent of every non-DIM dimension are read
// from TENSOR1 only; TENSOR2..TENSOR4 are advanced with their own strides.
// NOTE(review): this presumes all tensors have the same number of dimensions
// and matching sizes outside DIM, and that their data is host-accessible --
// confirm at the call sites.
//
// The macro expands to an immediately-invoked lambda capturing by reference.
// If any non-DIM dimension of TENSOR1 is empty there is no slice to visit and
// CODE is not executed at all.
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4,      \
                   TENSOR4, DIM, CODE)                                         \
  [&] {                                                                        \
    TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>();                             \
    auto TENSOR1##_size = TENSOR1.size(DIM);                                   \
    auto TENSOR1##_stride = TENSOR1.stride(DIM);                               \
                                                                               \
    TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>();                             \
    auto TENSOR2##_size = TENSOR2.size(DIM);                                   \
    auto TENSOR2##_stride = TENSOR2.stride(DIM);                               \
                                                                               \
    TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>();                             \
    auto TENSOR3##_size = TENSOR3.size(DIM);                                   \
    auto TENSOR3##_stride = TENSOR3.stride(DIM);                               \
                                                                               \
    TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>();                             \
    auto TENSOR4##_size = TENSOR4.size(DIM);                                   \
    auto TENSOR4##_stride = TENSOR4.stride(DIM);                               \
                                                                               \
    auto dims = TENSOR1.dim();                                                 \
    /* Wrap a negative DIM so it can be compared with cur_dim below. */        \
    int64_t dim_apply_axis = (DIM);                                            \
    if (dim_apply_axis < 0)                                                    \
      dim_apply_axis += dims;                                                  \
                                                                               \
    /* One counter per dimension: the index of the current slice. */           \
    auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong));          \
    auto counter = zeros.data<int64_t>();                                      \
    bool has_finished = false;                                                 \
                                                                               \
    /* With an empty non-DIM dimension there is no slice to visit. */          \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                     \
      if (cur_dim != dim_apply_axis && TENSOR1.size(cur_dim) == 0)             \
        has_finished = true;                                                   \
    }                                                                          \
                                                                               \
    while (!has_finished) {                                                    \
      CODE;                                                                    \
      /* A 1-D tensor consists of exactly one slice. */                        \
      if (dims == 1)                                                           \
        break;                                                                 \
                                                                               \
      /* Advance to the next slice: increment the lowest non-DIM counter */    \
      /* and carry into the following dimension whenever one overflows. */     \
      for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                   \
        if (cur_dim == dim_apply_axis) {                                       \
          /* Reaching a trailing DIM means every other counter wrapped. */     \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          }                                                                    \
          continue;                                                            \
        }                                                                      \
                                                                               \
        counter[cur_dim]++;                                                    \
        TENSOR1##_data += TENSOR1.stride(cur_dim);                             \
        TENSOR2##_data += TENSOR2.stride(cur_dim);                             \
        TENSOR3##_data += TENSOR3.stride(cur_dim);                             \
        TENSOR4##_data += TENSOR4.stride(cur_dim);                             \
                                                                               \
        if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                       \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          } else {                                                             \
            /* Rewind this dimension and carry into the next one. */           \
            TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);      \
            TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);      \
            TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);      \
            TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim);      \
            counter[cur_dim] = 0;                                              \
          }                                                                    \
        } else                                                                 \
          break;                                                               \
      }                                                                        \
    }                                                                          \
  }()