// index_info.h (author: rusty1s)
#pragma once

#include <torch/extension.h>

#include "compat.h"

// Capacity of the fixed-size sizes/strides arrays held by TensorInfo.
#define MAX_TENSORINFO_DIMS 25

// Plain description of a strided tensor: a raw (non-owning) data pointer, the
// rank, and per-dimension sizes and strides stored inline in fixed arrays.
// Only the first `dims` entries of `sizes` / `strides` are written; any
// remaining entries are left uninitialized.
template <typename scalar_t> struct TensorInfo {
  // Stores `p` and `dim`, then copies the first `dim` entries of `sz` and `st`.
  // Fails via AT_ASSERT unless `dim` is strictly below MAX_TENSORINFO_DIMS.
  TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
             int st[MAX_TENSORINFO_DIMS])
      : data(p), dims(dim) {
    AT_ASSERT(dims < MAX_TENSORINFO_DIMS);

    for (int d = 0; d < dims; ++d) {
      sizes[d] = sz[d];
      strides[d] = st[d];
    }
  }

  scalar_t *data;                    // first element (not owned)
  int dims;                          // number of valid entries below
  int sizes[MAX_TENSORINFO_DIMS];    // extent of each dimension
  int strides[MAX_TENSORINFO_DIMS];  // element stride of each dimension
};

template <typename scalar_t>
rusty1s's avatar
rusty1s committed
29
TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
rusty1s's avatar
rusty1s committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];

  int dims = tensor.dim();
  for (int i = 0; i < dims; ++i) {
    sizes[i] = tensor.size(i);
    strides[i] = tensor.stride(i);
  }

  return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes,
                              strides);
}

// Maps a linear element index (innermost dimension varying fastest) to the
// element offset in the strided layout described by `info`. One coordinate is
// peeled off per dimension, from the innermost outwards, and each is scaled by
// that dimension's stride.
template <typename scalar_t> struct IndexToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int result = 0;
    int d = info.dims;
    while (d-- > 0) {
      const int extent = info.sizes[d];
      result += (idx % extent) * info.strides[d];
      idx /= extent;
    }
    return result;
  }
};

// Same mapping as IndexToOffset, except the innermost dimension is treated as
// having `sizes[dims - 1] - 1` elements, so its final entry is never
// addressed.
// NOTE(review): presumably intended for CSR-style index-pointer tensors whose
// last dimension carries one extra trailing element; confirm against callers.
// Unchecked preconditions: info.dims >= 1 (else sizes[-1] is read) and
// info.sizes[dims - 1] != 1 (else the modulo/division divides by zero).
template <typename scalar_t> struct IndexPtrToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    const int last = info.dims - 1;
    const int inner = info.sizes[last] - 1;

    int result = (idx % inner) * info.strides[last];
    idx /= inner;

    for (int d = last - 1; d >= 0; --d) {
      const int extent = info.sizes[d];
      result += (idx % extent) * info.strides[d];
      idx /= extent;
    }
    return result;
  }
};