convert_kernel.cu 1.89 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
2
#include <torch/extension.h>
rusty1s's avatar
rusty1s committed
3
4
5

#include "compat.cuh"

rusty1s's avatar
fixes  
rusty1s committed
6
// Threads per block used as the block size of the kernel launches in this file.
#define THREADS 256
rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

// Builds a compressed pointer array from an index array.
//
// For every position p in [0, M], out_data[p] is set to the number of leading
// entries of ind_data that lie before p, i.e. the offset at which "row" p
// starts. The code assumes ind_data is sorted in non-decreasing order and that
// its values lie in [0, M); nothing is validated here, and out-of-range values
// lead to out-of-bounds writes.
//
// Parameters:
//   ind_data  device pointer to `numel` indices (not read when numel == 0)
//   out_data  device pointer to M + 1 writable entries
//   M         number of rows; out_data has M + 1 entries
//   numel     number of entries in ind_data
//
// Launch layout: 1-D grid, 1-D blocks, at least numel + 1 threads in total
// (thread `numel` fills the tail of out_data). No shared memory is used.
__global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
                               int64_t M, int64_t numel) {

  // Widen before multiplying: blockDim.x * blockIdx.x is a 32-bit unsigned
  // product and would wrap around for grids with 2^32 or more threads.
  const int64_t thread_idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  if (numel == 0) {
    // No indices at all: ind_data[0] must not be touched, and every pointer
    // is zero. All launched threads share the fill.
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t i = thread_idx; i <= M; i += stride)
      out_data[i] = 0;
    return;
  }

  if (thread_idx == 0) {
    // Positions up to and including the first index have nothing before them.
    for (int64_t i = 0; i <= ind_data[0]; i++)
      out_data[i] = 0;
  } else if (thread_idx < numel) {
    // Positions in (ind[t - 1], ind[t]] start at offset t. The range is empty
    // when two consecutive indices are equal.
    for (int64_t i = ind_data[thread_idx - 1]; i < ind_data[thread_idx]; i++)
      out_data[i + 1] = thread_idx;
  } else if (thread_idx == numel) {
    // Positions after the last index (including the final entry M) all point
    // past the end of the index array.
    for (int64_t i = ind_data[numel - 1] + 1; i < M + 1; i++)
      out_data[i] = numel;
  }
}

rusty1s's avatar
rusty1s committed
25
// Host wrapper: converts an index tensor into a pointer tensor of length
// M + 1 by launching ind2ptr_kernel on the current CUDA stream.
//
// Parameters:
//   ind  int64 CUDA tensor of indices (DATA_PTR<int64_t> rejects other
//        dtypes); presumably sorted in non-decreasing order with values in
//        [0, M) -- the kernel relies on that and does not verify it.
//   M    number of rows; the result has M + 1 entries.
//
// Returns a tensor with the options (dtype/device) of `ind`.
// Throws if a CUDA runtime call or the kernel launch reports an error.
// Note: the device of `ind` is made current and is not restored afterwards.
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
  AT_CUDA_CHECK(cudaSetDevice(ind.get_device()));

  // The kernel indexes through a raw pointer, so the data must be dense.
  ind = ind.contiguous();

  auto out = torch::empty({M + 1}, ind.options());

  const int64_t numel = ind.numel();
  if (numel == 0) {
    // Nothing to scan: every pointer is zero. Launching the kernel here would
    // make it dereference ind_data[0] of an empty tensor.
    return out.zero_();
  }

  auto ind_data = ind.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  // "+ 2" keeps the original sizing; the kernel needs at least numel + 1
  // threads because thread `numel` fills the tail of the output.
  const int64_t blocks = (numel + 2 + THREADS - 1) / THREADS;
  ind2ptr_kernel<<<blocks, THREADS, 0, stream>>>(ind_data, out_data, M, numel);
  // Kernel launches do not return a status; surface bad launches right away.
  AT_CUDA_CHECK(cudaGetLastError());
  return out;
}

// Expands a pointer array into one index per element: thread `r` writes the
// value r into out_data[ptr_data[r] .. ptr_data[r + 1] - 1].
//
// Parameters:
//   ptr_data  device pointer to the pointer array; since entry r + 1 is read
//             for every r < numel, it must hold at least numel + 1 entries
//   out_data  device pointer to the output; must be large enough for every
//             position below ptr_data[numel]
//   E         not used by the kernel body
//   numel     number of segments (rows) to expand
//
// NOTE(review): `numel` has to be the number of segments, i.e. one less than
// the length of the pointer array -- confirm at the launch site.
//
// Launch layout: 1-D grid, 1-D blocks, at least `numel` threads in total.
// No shared memory is used.
__global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
                               int64_t E, int64_t numel) {

  // Widen before multiplying: blockDim.x * blockIdx.x is a 32-bit unsigned
  // product and would wrap around for grids with 2^32 or more threads.
  const int64_t thread_idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  if (thread_idx >= numel)
    return; // surplus threads of the last block

  const int64_t begin = ptr_data[thread_idx];
  const int64_t end = ptr_data[thread_idx + 1];
  for (int64_t i = begin; i < end; i++) {
    out_data[i] = thread_idx;
  }
}

rusty1s's avatar
rusty1s committed
50
// Host wrapper: expands a pointer tensor into an index tensor of length E by
// launching ptr2ind_kernel on the current CUDA stream.
//
// Parameters:
//   ptr  int64 CUDA tensor of offsets (DATA_PTR<int64_t> rejects other
//        dtypes); presumably non-decreasing with its last entry equal to E --
//        this is not verified here.
//   E    number of entries in the result.
//
// Returns a tensor of E entries with the options (dtype/device) of `ptr`.
// Throws if a CUDA runtime call or the kernel launch reports an error.
// Note: the device of `ptr` is made current and is not restored afterwards.
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
  AT_CUDA_CHECK(cudaSetDevice(ptr.get_device()));

  // The kernel indexes through a raw pointer, so the data must be dense.
  ptr = ptr.contiguous();

  auto out = torch::empty({E}, ptr.options());

  // The kernel reads ptr[row] and ptr[row + 1] for every row it handles, so
  // a pointer tensor of n entries describes only n - 1 rows. Passing n itself
  // would make the last thread read one element past the end of `ptr`.
  const int64_t num_rows = ptr.numel() - 1;
  if (num_rows <= 0) {
    // No rows to expand; a launch with zero blocks is an invalid
    // configuration. For a well-formed ptr, E is expected to be 0 here.
    return out;
  }

  auto ptr_data = ptr.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  const int64_t blocks = (num_rows + THREADS - 1) / THREADS;
  ptr2ind_kernel<<<blocks, THREADS, 0, stream>>>(ptr_data, out_data, E,
                                                 num_rows);
  // Kernel launches do not return a status; surface bad launches right away.
  AT_CUDA_CHECK(cudaGetLastError());
  return out;
}