index.cuh 3.93 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>

// Translates a linear element index into element offsets for `index`, `t1`
// and `t2`, with the number of dimensions fixed at compile time by `Dims`.
//
// The linear index `i` is unravelled against `index.sizes`, from dimension
// `Dims - 1` down to dimension 0. Each resulting coordinate is multiplied by
// the matching stride of `index` and `t1`; for `t2` the same is done for every
// dimension except `dim`. Along `dim`, `t2` is instead addressed by the value
// stored in `index` at the offset that was just computed.
//
// All three outputs are incremented (`+=`), never assigned, so whatever value
// they hold on entry acts as a base offset. Neither `dim` nor the value read
// from `index` is range-checked here.
template <typename scalar1, typename scalar2, int64_t Dims>
struct IndexToScatterOffsets3 {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset) {
    // Aliases for the caller-provided accumulators.
    int64_t &idxOff = *indexOffset;
    int64_t &t1Off = *t1Offset;
    int64_t &t2Off = *t2Offset;

    // Peel coordinates off the linear index, last dimension first.
    int64_t linear = i;
    for (int64_t axis = Dims; axis-- > 0;) {
      const int64_t extent = index.sizes[axis];
      const int64_t coord = linear % extent;

      idxOff += coord * index.strides[axis];
      t1Off += coord * t1.strides[axis];
      // The `dim` coordinate of t2 comes from the index tensor, added below.
      if (axis != dim) {
        t2Off += coord * t2.strides[axis];
      }

      linear /= extent;
    }

    // Position along `dim` in t2 is the value stored in the index tensor.
    const int64_t target = index.data[idxOff];
    t2Off += target * t2.strides[dim];
  }
};

// Specialization for `Dims == -1`: the number of dimensions is not known at
// compile time and is read from `index.dims` instead.
//
// The linear index `i` is unravelled against `index.sizes`, from dimension
// `index.dims - 1` down to dimension 0. Each coordinate contributes
// `coordinate * stride` to the offsets of `index` and `t1`, and to the offset
// of `t2` for every dimension other than `dim`. The `dim` contribution for
// `t2` is taken from the value stored in `index` at the computed offset.
//
// The outputs are accumulated with `+=`, so their incoming values are kept as
// a base. `dim` and the value loaded from `index` are used unchecked.
template <typename scalar1, typename scalar2>
struct IndexToScatterOffsets3<scalar1, scalar2, -1> {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset) {
    int64_t axis = index.dims;
    while (axis-- > 0) {
      const int64_t extent = index.sizes[axis];
      const int64_t coord = i % extent;

      *indexOffset += coord * index.strides[axis];
      *t1Offset += coord * t1.strides[axis];
      // Skip `dim` for t2: that coordinate is supplied by the index value.
      const bool scatterAxis = (axis == dim);
      if (!scatterAxis) {
        *t2Offset += coord * t2.strides[axis];
      }

      i /= extent;
    }

    // Look up where along `dim` this element lands in t2.
    const int64_t target = index.data[*indexOffset];
    *t2Offset += target * t2.strides[dim];
  }
};

// Translates a linear element index into element offsets for `index`, `t1`,
// `t2` and `t3`, with the number of dimensions fixed at compile time by
// `Dims`.
//
// The linear index `i` is unravelled against `index.sizes`, from dimension
// `Dims - 1` down to dimension 0. Each coordinate is scaled by the matching
// stride of `index` and `t1`; `t2` and `t3` receive the same treatment for
// every dimension except `dim`. Along `dim`, both `t2` and `t3` are addressed
// by the value stored in `index` at the offset that was just computed.
//
// All four outputs are incremented (`+=`), never assigned, so their values on
// entry act as base offsets. Neither `dim` nor the value read from `index` is
// range-checked here.
template <typename scalar1, typename scalar2, typename scalar3, int64_t Dims>
struct IndexToScatterOffsets4 {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset,
          const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
          int64_t *t3Offset) {
    // Aliases for the caller-provided accumulators.
    int64_t &idxOff = *indexOffset;
    int64_t &t1Off = *t1Offset;
    int64_t &t2Off = *t2Offset;
    int64_t &t3Off = *t3Offset;

    // Peel coordinates off the linear index, last dimension first.
    int64_t linear = i;
    for (int64_t axis = Dims; axis-- > 0;) {
      const int64_t extent = index.sizes[axis];
      const int64_t coord = linear % extent;

      idxOff += coord * index.strides[axis];
      t1Off += coord * t1.strides[axis];
      // t2 and t3 take their `dim` coordinate from the index tensor instead.
      if (axis != dim) {
        t2Off += coord * t2.strides[axis];
        t3Off += coord * t3.strides[axis];
      }

      linear /= extent;
    }

    // Position along `dim` in t2 and t3 is the value stored in the index.
    const int64_t target = index.data[idxOff];
    t2Off += target * t2.strides[dim];
    t3Off += target * t3.strides[dim];
  }
};

// Specialization for `Dims == -1`: the number of dimensions is not known at
// compile time and is read from `index.dims` instead.
//
// The linear index `i` is unravelled against `index.sizes`, from dimension
// `index.dims - 1` down to dimension 0. Each coordinate contributes
// `coordinate * stride` to the offsets of `index` and `t1`, and to the offsets
// of `t2` and `t3` for every dimension other than `dim`. The `dim`
// contribution for `t2` and `t3` is taken from the value stored in `index` at
// the computed offset.
//
// The outputs are accumulated with `+=`, so their incoming values are kept as
// a base. `dim` and the value loaded from `index` are used unchecked.
template <typename scalar1, typename scalar2, typename scalar3>
struct IndexToScatterOffsets4<scalar1, scalar2, scalar3, -1> {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset,
          const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
          int64_t *t3Offset) {
    int64_t axis = index.dims;
    while (axis-- > 0) {
      const int64_t extent = index.sizes[axis];
      const int64_t coord = i % extent;

      *indexOffset += coord * index.strides[axis];
      *t1Offset += coord * t1.strides[axis];
      // Skip `dim` for t2/t3: that coordinate is supplied by the index value.
      const bool scatterAxis = (axis == dim);
      if (!scatterAxis) {
        *t2Offset += coord * t2.strides[axis];
        *t3Offset += coord * t3.strides[axis];
      }

      i /= extent;
    }

    // Look up where along `dim` this element lands in t2 and t3.
    const int64_t target = index.data[*indexOffset];
    *t2Offset += target * t2.strides[dim];
    *t3Offset += target * t3.strides[dim];
  }
};