#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/execution_policy.h>

#include "atomics.cuh"
#include "compat.cuh"
#include "index.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
#define FULL_MASK 0xffffffff

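// Maps a flat segment index onto the memory offset of the matching `indptr`
// entry, honoring arbitrary strides. The innermost `indptr` dimension stores
// `size` boundaries, i.e. `size - 1` segments, hence the `- 1` terms below.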
template <typename T, typename I> struct IndexPtrToOffset {
  static __host__ __device__ I
  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
    I offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};

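// CSR reduction without a broadcast dimension: TB threads cooperate on one
// output row, striding over its segment and merging their partial sums with a
// warp-level shuffle reduction.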
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = (scalar_t)0;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      val += src_data[offset + src_idx];
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2)
      val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction

    if (lane_idx == 0) {
      out_data[row_idx] = val;
    }
  }
}

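// CSR reduction with a trailing broadcast dimension of size K: one thread per
// (row, feature) pair, so both the reads from `src_data` and the writes to
// `out_data` stay coalesced across the feature dimension.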
template <typename scalar_t>
__global__ void segment_add_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = (scalar_t)0;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      // Coalesced read into `src_data`.
      val += src_data[offset + K * src_idx + lane_idx];
    }

    out_data[thread_idx] = val; // Coalesced write into `out_data`
  }
}

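// Alternative broadcast kernel: a 2D block in which TB threads along y
// cooperate on one output row while x indexes the feature dimension, combining
// partial sums through a shared-memory tree reduction. It is currently not
// dispatched by `segment_add_csr_cuda` below.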
template <typename scalar_t, int TB>
__global__ void segment_add_csr_broadcast_kernel2(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);
  int col_idx = blockIdx.x * blockDim.x + threadIdx.x;

  __shared__ scalar_t vals[32][32];

  scalar_t val = (scalar_t)0;

  if (row_idx < N && col_idx < K) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int i = row_start + lane_idx; i < row_end; i += TB) {
      val += src_data[offset + K * i + col_idx];
    }
  }

  // Every thread of the block takes part in the reduction below, so
  // __syncthreads() is never reached from divergent code paths.
  vals[threadIdx.x][threadIdx.y] = val;
  __syncthreads();

#pragma unroll
  for (int i = TB / 2; i > 0; i /= 2) {
    if (lane_idx < i)
      vals[threadIdx.x][threadIdx.y] += vals[threadIdx.x][threadIdx.y + i];
    __syncthreads();
  }

  if (row_idx < N && col_idx < K && lane_idx == 0) {
    out_data[row_idx * K + col_idx] = vals[threadIdx.x][threadIdx.y];
  }
}

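// Reduces `src` along dimension `indptr.dim() - 1` as described by the CSR
// pointer tensor `indptr`; the reduced dimension of the output has size
// `indptr.size(-1) - 1`. The kernel variant is picked from the average
// segment length and from whether a broadcast dimension (K > 1) is present.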
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  AT_ASSERTM(src.dim() >= indptr.dim(),
             "src must have at least as many dimensions as indptr");
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i),
               "src and indptr must have matching leading dimensions");

  src = src.contiguous();

  auto reduce_dim = indptr.dim() - 1;
  auto sizes = src.sizes().vec();
  sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
  auto out = at::empty(sizes, src.options());

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);
  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (K == 1 && avg_length <= 4) {
      segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else if (K == 1 && avg_length <= 8) {
      segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else if (K == 1 && avg_length <= 16) {
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else if (K == 1) {
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else {
      segment_add_csr_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
                                                     out_data, N, K, E);
    }
  });

  return out;
}

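// COO reduction: every thread loads one source element, a warp-level
// segmented scan via __shfl_up_sync accumulates runs of equal consecutive
// indices, and the last thread of each run atomically adds its partial sum
// into `out_data`.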
template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
                                       const int64_t *index_data,
                                       scalar_t *out_data, size_t numel) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = thread_idx & (TB - 1);

  if (thread_idx < numel) {
    auto idx = __ldg(index_data + thread_idx);
    scalar_t val = src_data[thread_idx], tmp;

#pragma unroll
    for (int offset = 1; offset < TB; offset *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, offset);
      int idx_next =
          lane_idx < offset ? -1 : __ldg(index_data + thread_idx - offset);
      // AT_ASSERTM(lane_idx < offset || idx <= idx_next);
      if (lane_idx >= offset && idx == idx_next) {
        val += tmp;
      }
    }

    if (lane_idx == TB - 1 || thread_idx == numel - 1 ||
        idx != __ldg(index_data + thread_idx + 1)) {
      atomAdd(out_data + idx, val);
    }
  }
}

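// Host entry point for the COO kernel. Results are accumulated into `out`
// with atomic adds, so `out` is expected to be initialized by the caller.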
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto numel = src.numel();
  auto avg_length = (float)numel / (float)out.numel();

  auto index_data = index.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    segment_add_coo_kernel<scalar_t, 32>
        <<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
                                                   out_data, numel);
  });
}

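// Thrust-based reference path: reduces consecutive runs of equal indices with
// thrust::reduce_by_key; `key` receives the index of each reduced run and is
// discarded afterwards.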
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());

    thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
                          src_data, key_data, out_data);
  });
}