#pragma once

#include <torch/extension.h>

#include "compat.h"

// DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE)
//
// Expands to an immediately-invoked lambda (capturing by reference) that runs
// CODE once for every 1-D slice along dimension DIM of three tensors.
//
// Inside CODE the following names are available for each tensor argument X
// (X being the identifier passed as TENSOR1 / TENSOR2 / TENSOR3):
//   X_data   - TYPEn* pointing at the first element of the current slice
//   X_size   - X.size(DIM)
//   X_stride - X.stride(DIM)
//
// Notes:
//   * DIM is evaluated exactly once; a negative DIM is wrapped by adding
//     TENSOR1.dim().
//   * The extents of the non-DIM dimensions are taken from TENSOR1 only; the
//     other tensors are assumed to have matching extents there (TODO confirm
//     against callers).
//   * CODE is not executed at all if a non-DIM dimension of TENSOR1 is empty.
//   * The data pointers are dereferenced on the host, so the tensors are
//     presumably expected to live in CPU memory.
//   * DATA_PTR is not defined in this file (presumably provided by compat.h).
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE)  \
  [&] {                                                                        \
    int64_t dim_apply_axis = (DIM);                                            \
    auto dims = TENSOR1.dim();                                                 \
    if (dim_apply_axis < 0)                                                    \
      dim_apply_axis += dims;                                                  \
                                                                               \
    TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>();                         \
    auto TENSOR1##_size = TENSOR1.size(dim_apply_axis);                        \
    auto TENSOR1##_stride = TENSOR1.stride(dim_apply_axis);                    \
                                                                               \
    TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>();                         \
    auto TENSOR2##_size = TENSOR2.size(dim_apply_axis);                        \
    auto TENSOR2##_stride = TENSOR2.stride(dim_apply_axis);                    \
                                                                               \
    TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>();                         \
    auto TENSOR3##_size = TENSOR3.size(dim_apply_axis);                        \
    auto TENSOR3##_stride = TENSOR3.stride(dim_apply_axis);                    \
                                                                               \
    auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong));          \
    auto counter = zeros.DATA_PTR<int64_t>();                                  \
    bool has_finished = false;                                                 \
                                                                               \
    /* Nothing to visit when a dimension other than DIM is empty. */           \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                     \
      if (cur_dim != dim_apply_axis && TENSOR1.size(cur_dim) == 0)             \
        has_finished = true;                                                   \
    }                                                                          \
                                                                               \
    while (!has_finished) {                                                    \
      CODE;                                                                    \
      if (dims == 1)                                                           \
        break;                                                                 \
                                                                               \
      /* Step the multi-index over the non-DIM dimensions. */                  \
      for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                   \
        if (cur_dim == dim_apply_axis) {                                       \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          }                                                                    \
          continue;                                                            \
        }                                                                      \
                                                                               \
        counter[cur_dim]++;                                                    \
        TENSOR1##_data += TENSOR1.stride(cur_dim);                             \
        TENSOR2##_data += TENSOR2.stride(cur_dim);                             \
        TENSOR3##_data += TENSOR3.stride(cur_dim);                             \
                                                                               \
        if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                       \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          } else {                                                             \
            /* Carry: rewind this dimension and move to the next. */           \
            TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);      \
            TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);      \
            TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);      \
            counter[cur_dim] = 0;                                              \
          }                                                                    \
        } else                                                                 \
          break;                                                               \
      }                                                                        \
    }                                                                          \
  }()

// DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, TENSOR4,
//            DIM, CODE)
//
// Expands to an immediately-invoked lambda (capturing by reference) that runs
// CODE once for every 1-D slice along dimension DIM of four tensors.
//
// Inside CODE the following names are available for each tensor argument X
// (X being the identifier passed as TENSOR1 ... TENSOR4):
//   X_data   - TYPEn* pointing at the first element of the current slice
//   X_size   - X.size(DIM)
//   X_stride - X.stride(DIM)
//
// Notes:
//   * DIM is evaluated exactly once; a negative DIM is wrapped by adding
//     TENSOR1.dim().
//   * The extents of the non-DIM dimensions are taken from TENSOR1 only; the
//     other tensors are assumed to have matching extents there (TODO confirm
//     against callers).
//   * CODE is not executed at all if a non-DIM dimension of TENSOR1 is empty.
//   * The data pointers are dereferenced on the host, so the tensors are
//     presumably expected to live in CPU memory.
//   * DATA_PTR is not defined in this file (presumably provided by compat.h).
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4,      \
                   TENSOR4, DIM, CODE)                                         \
  [&] {                                                                        \
    int64_t dim_apply_axis = (DIM);                                            \
    auto dims = TENSOR1.dim();                                                 \
    if (dim_apply_axis < 0)                                                    \
      dim_apply_axis += dims;                                                  \
                                                                               \
    TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>();                         \
    auto TENSOR1##_size = TENSOR1.size(dim_apply_axis);                        \
    auto TENSOR1##_stride = TENSOR1.stride(dim_apply_axis);                    \
                                                                               \
    TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>();                         \
    auto TENSOR2##_size = TENSOR2.size(dim_apply_axis);                        \
    auto TENSOR2##_stride = TENSOR2.stride(dim_apply_axis);                    \
                                                                               \
    TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>();                         \
    auto TENSOR3##_size = TENSOR3.size(dim_apply_axis);                        \
    auto TENSOR3##_stride = TENSOR3.stride(dim_apply_axis);                    \
                                                                               \
    TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>();                         \
    auto TENSOR4##_size = TENSOR4.size(dim_apply_axis);                        \
    auto TENSOR4##_stride = TENSOR4.stride(dim_apply_axis);                    \
                                                                               \
    auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong));          \
    auto counter = zeros.DATA_PTR<int64_t>();                                  \
    bool has_finished = false;                                                 \
                                                                               \
    /* Nothing to visit when a dimension other than DIM is empty. */           \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                     \
      if (cur_dim != dim_apply_axis && TENSOR1.size(cur_dim) == 0)             \
        has_finished = true;                                                   \
    }                                                                          \
                                                                               \
    while (!has_finished) {                                                    \
      CODE;                                                                    \
      if (dims == 1)                                                           \
        break;                                                                 \
                                                                               \
      /* Step the multi-index over the non-DIM dimensions. */                  \
      for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                   \
        if (cur_dim == dim_apply_axis) {                                       \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          }                                                                    \
          continue;                                                            \
        }                                                                      \
                                                                               \
        counter[cur_dim]++;                                                    \
        TENSOR1##_data += TENSOR1.stride(cur_dim);                             \
        TENSOR2##_data += TENSOR2.stride(cur_dim);                             \
        TENSOR3##_data += TENSOR3.stride(cur_dim);                             \
        TENSOR4##_data += TENSOR4.stride(cur_dim);                             \
                                                                               \
        if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                       \
          if (cur_dim == dims - 1) {                                           \
            has_finished = true;                                               \
            break;                                                             \
          } else {                                                             \
            /* Carry: rewind this dimension and move to the next. */           \
            TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);      \
            TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);      \
            TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);      \
            TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim);      \
            counter[cur_dim] = 0;                                              \
          }                                                                    \
        } else                                                                 \
          break;                                                               \
      }                                                                        \
    }                                                                          \
  }()