indptr.cuh 681 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
#pragma once

#include <ATen/cuda/detail/TensorInfo.cuh>
rusty1s's avatar
rusty1s committed
4
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
5
6
7
8

// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indptr` along its final dimension.
// Maps a linear index onto a memory offset into an index-pointer tensor,
// treating the final dimension as if it held one entry fewer than
// `info.sizes[info.dims - 1]`. The last entry along that dimension is
// therefore never addressed by any `idx`.
template <typename scalar_t> struct IndexPtrToOffset {
  // Returns the offset (in elements) that corresponds to linear index `idx`.
  //
  // idx:  linear index; the final dimension is enumerated with
  //       `sizes[last] - 1` entries, all leading dimensions with their full
  //       `sizes[d]`.
  // info: sizes, strides and dimension count of the tensor being addressed.
  //
  // Precondition: `info.sizes[info.dims - 1]` must differ from 1, because it
  // is used (minus one) as a divisor; otherwise the division is undefined.
  // NOTE(review): also assumes `info.dims >= 1` -- confirm against callers.
  static inline __host__ __device__ int
  get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
    const int last = info.dims - 1;
    // Number of addressable entries along the final dimension.
    const auto inner = info.sizes[last] - 1;

    // Split `idx` into its position inside the final dimension and the
    // remaining linear index over the leading dimensions.
    int result = (idx % inner) * info.strides[last];
    int remaining = idx / inner;

    // Peel off the leading dimensions from innermost to outermost.
    int d = last;
    while (--d >= 0) {
      result += (remaining % info.sizes[d]) * info.strides[d];
      remaining /= info.sizes[d];
    }

    return result;
  }
};