// rowptr_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "compat.cuh"

#define THREADS 256

// Builds a row-pointer array from a list of row indices.
//
// Parameters:
//   row_data  device pointer to `numel` row indices
//   out_data  device pointer to the output; entries 0..size are written
//   numel     number of entries in row_data
//   size      number of rows (the output holds size + 1 entries)
//
// Work split (1D launch, at least numel + 1 threads required):
//   thread 0          writes 0 into out_data[0 .. row_data[0]]
//   thread t in
//   [1, numel)        writes t into out_data[row_data[t-1]+1 .. row_data[t]]
//   thread numel      writes numel into out_data[row_data[numel-1]+1 .. size]
//
// NOTE(review): this assumes row_data is sorted in non-decreasing order and
// every value lies in [0, size); otherwise entries are skipped or written out
// of bounds. Confirm against callers.
__global__ void rowptr_kernel(const int64_t *row_data, int64_t *out_data,
                              int64_t numel, int64_t size) {

  // Widen before multiplying so the product is not truncated to 32 bits.
  const int64_t thread_idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  if (numel == 0) {
    // No indices: row_data must not be dereferenced, and every pointer is 0.
    if (thread_idx == 0) {
      for (int64_t i = 0; i <= size; i++)
        out_data[i] = 0;
    }
    return;
  }

  if (thread_idx == 0) {
    // Rows up to and including the first referenced row start at offset 0.
    for (int64_t i = 0; i <= row_data[0]; i++)
      out_data[i] = 0;
  } else if (thread_idx < numel) {
    // Loop body runs only where the row index increases between neighbours.
    for (int64_t i = row_data[thread_idx - 1]; i < row_data[thread_idx]; i++)
      out_data[i + 1] = thread_idx;
  } else if (thread_idx == numel) {
    // Everything after the last referenced row points one past the last entry.
    for (int64_t i = row_data[numel - 1] + 1; i < size + 1; i++)
      out_data[i] = numel;
  }
}

// ---- Host entry point ----
// Host wrapper: allocates a (size + 1)-element tensor with the same options
// as `row` and fills it by launching rowptr_kernel on the current CUDA stream.
//
// Parameters:
//   row   one-dimensional tensor of int64 row indices (asserted 1-D; the
//         int64 element type is required by the DATA_PTR<int64_t> access)
//   size  number of rows; the result has size + 1 entries
//
// Returns the row-pointer tensor. For an empty `row` the result is all zeros
// and no kernel is launched.
//
// NOTE(review): `row` is presumably a CUDA tensor sorted in non-decreasing
// order with values in [0, size) — not checked here; confirm against callers.
at::Tensor rowptr_cuda(at::Tensor row, int64_t size) {
  AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");

  auto out = at::empty(size + 1, row.options());

  const int64_t numel = row.numel();

  // The kernel reads row_data[0] unconditionally in thread 0, so an empty
  // input must be handled before launching.
  if (numel == 0)
    return out.zero_();

  auto row_data = row.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  // The kernel uses thread indices 0..numel, so at least numel + 1 threads
  // must be launched; the block count is rounded up accordingly.
  rowptr_kernel<<<(numel + 2 + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
      row_data, out_data, numel, size);

  // Kernel launches do not return a status; surface bad launch configs here.
  AT_ASSERTM(cudaGetLastError() == cudaSuccess, "rowptr_kernel launch failed");

  return out;
}