#include "degree_padding_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"
// Threads per block for all kernel launches in this file.
#define THREADS 1024
// Ceil-division grid size for N elements. The argument and the whole
// expansion are parenthesized so the macro stays correct inside larger
// expressions (e.g. `x / BLOCKS(n)` or a ternary argument).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
// Scans the sorted per-row degree array and, at every position where the
// degree changes, records for each bin boundary that falls in between:
//   size[b]       — exclusive end offset of bin b in the sorted order,
//   length[b - 1] — largest degree observed inside bin b - 1.
// The last valid position also closes the final bin. Grid-stride loop, so
// any launch configuration with >= 1 block is valid.
__global__ void sizes_kernel(const int64_t *__restrict__ sorted_rowcount,
                             const int64_t *__restrict__ binptr,
                             int64_t *__restrict__ size,
                             int64_t *__restrict__ length,
                             const int64_t num_bins, const int64_t numel) {
  const int64_t stride = (int64_t)gridDim.x * blockDim.x;
  for (int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
       i < numel - 1; i += stride) {

    const int64_t deg_lo = sorted_rowcount[i];
    const int64_t deg_hi = sorted_rowcount[i + 1];

    if (deg_lo != deg_hi) {
      // Every bin boundary strictly above deg_lo and not above deg_hi is
      // crossed here; marking all of them keeps empty bins consistent.
      for (int64_t b = 1; b <= num_bins; b++) {
        const int64_t boundary = __ldg(binptr + b);
        if (deg_lo < boundary && boundary <= deg_hi) {
          size[b] = i + 1;
          length[b - 1] = deg_lo;
        }
      }
    }

    // The thread that sees the last adjacent pair closes the final bin.
    if (i + 2 == numel) {
      size[num_bins] = numel;
      length[num_bins - 1] = deg_hi;
    }
  }
}

// Partitions the rows of a sparse matrix into degree bins described by
// `binptr` (a monotonically increasing boundary vector). Returns, per bin,
// the permutation indices of the rows that fall into it, together with the
// maximum degree (padding length) of each bin.
//
// rowcount: 1-D int64 CUDA tensor of per-row degrees.
// binptr:   1-D int64 tensor of bin boundaries (num_bins + 1 entries).
std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr) {
  CHECK_CUDA(rowcount);
  CHECK_CUDA(binptr);
  CHECK_INPUT(rowcount.dim() == 1);
  CHECK_INPUT(binptr.dim() == 1);

  cudaSetDevice(rowcount.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  // Sort rows by degree so that all rows of one bin become contiguous.
  torch::Tensor sorted_rowcount, perm;
  std::tie(sorted_rowcount, perm) = rowcount.sort();

  // size[b]: exclusive end offset of bin b in the sorted order (filled by
  // the kernel); length[b]: maximum degree inside bin b. Zero-initialized so
  // bins the kernel never touches stay empty.
  auto size = torch::zeros({binptr.numel()}, binptr.options());
  auto length = torch::zeros({binptr.numel() - 1}, binptr.options());

  // Clamp the grid to at least one block: BLOCKS(0) == 0 when
  // rowcount.numel() <= 1, and a zero-block launch is an invalid
  // configuration error.
  int64_t grid =
      std::max<int64_t>(std::min(BLOCKS(rowcount.numel() - 1), mpc * 8),
                        int64_t(1));
  sizes_kernel<<<grid, THREADS, 0, stream>>>(
      sorted_rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
      size.data_ptr<int64_t>(), length.data_ptr<int64_t>(), length.numel(),
      rowcount.numel());

  // Per-bin counts are differences of consecutive offsets. `.cpu()`
  // synchronizes with the kernel via the blocking device-to-host copy.
  size = size.cpu();
  size = size.narrow(0, 1, length.numel()) - size.narrow(0, 0, length.numel());
  auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());

  length = length.cpu();
  int64_t *length_data = length.data_ptr<int64_t>();
  std::vector<int64_t> lengths(length_data, length_data + length.numel());

  return std::make_tuple(perm.split_with_sizes(sizes), lengths);
}

// For every (gathered row, lane) slot of the padded output, writes the
// column index of the lane-th neighbor of the row, or -1 when the lane lies
// beyond the row's degree. `mask` marks those padded (-1) slots.
// `numel` is index.numel() * length; grid-stride loop.
__global__ void padded_mask_select_kernel(
    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
    const int64_t *__restrict__ index, int64_t *__restrict__ out_idx,
    bool *__restrict__ mask, const int64_t length, const int64_t numel) {

  const int64_t stride = (int64_t)gridDim.x * blockDim.x;
  for (int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; i < numel;
       i += stride) {
    const int64_t lane = i % length;            // position inside the row
    const int64_t row = index[i / length];      // which row to gather
    const int64_t start = rowptr[row];
    const int64_t deg = rowptr[row + 1] - start;

    const int64_t c = (lane < deg) ? col[start + lane] : -1;
    out_idx[i] = c;
    mask[i] = (c == -1);
  }
}
// Gathers rows of `src` (shape [*, dim], row-major) by `index`, writing
// `fill_value` wherever the index entry is negative (a padded slot).
// `numel` is out.numel(); grid-stride loop.
template <typename scalar_t>
__global__ void
padded_index_select_kernel(const scalar_t *__restrict__ src,
                           const int64_t *__restrict__ index,
                           scalar_t *__restrict__ out, scalar_t fill_value,
                           const int64_t dim, const int64_t numel) {

  int64_t index_idx, dim_idx, col;
  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {
    index_idx = thread_idx / dim;
    dim_idx = thread_idx % dim;
    col = __ldg(index + index_idx);

    // Use a per-iteration local instead of overwriting `fill_value`:
    // the original assigned into `fill_value` itself, so a later
    // grid-stride iteration hitting a padded slot (col < 0) emitted the
    // stale `src` value of an earlier iteration rather than the fill value.
    scalar_t val = fill_value;
    if (col >= 0) {
      val = src[col * dim + dim_idx];
    }

    out[thread_idx] = val;
  }
}

// Gathers the neighbor features of the rows in `index` into a dense,
// degree-padded tensor of shape [index.size(0), length, src.size(-1)],
// filling padded slots with the scalar `fill_value`. Also returns a
// [index.size(0), length, 1] boolean mask that is true at padded slots.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index, int64_t length,
                         torch::Tensor fill_value) {
  CHECK_CUDA(src);
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(fill_value.numel() == 1);

  cudaSetDevice(src.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  auto out_idx = torch::empty({index.size(0), length}, index.options());
  auto out = torch::empty({index.size(0), length, src.size(-1)}, src.options());
  auto mask = torch::empty({index.size(0), length, 1},
                           src.options().dtype(torch::kBool));

  // Clamp grids to >= 1 block: an empty `index` would otherwise produce a
  // zero-block launch, which is an invalid configuration error.
  int64_t grid = std::max<int64_t>(
      std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8),
      int64_t(1));
  padded_mask_select_kernel<<<grid, THREADS, 0, stream>>>(
      rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
      index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
      mask.data_ptr<bool>(), length, out_idx.numel());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // Fetch the scalar fill value onto the host stack. The original
    // malloc'ed a scalar for the CUDA branch and never freed it, leaking
    // one allocation per call; a local needs no cleanup. The blocking
    // cudaMemcpy also orders this read after any pending writes to
    // `fill_value`, as before.
    scalar_t fill;
    if (fill_value.is_cuda()) {
      cudaMemcpy(&fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
                 cudaMemcpyDeviceToHost);
    } else {
      fill = *fill_value.data_ptr<scalar_t>();
    }

    int64_t grid2 = std::max<int64_t>(
        std::min((out.numel() + THREADS - 1) / THREADS, mpc * 8), int64_t(1));
    padded_index_select_kernel<scalar_t><<<grid2, THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), out_idx.data_ptr<int64_t>(),
        out.data_ptr<scalar_t>(), fill, src.size(-1), out.numel());
  });

  return std::make_tuple(out, mask);
}