// padding_cuda.cu
#include "padding_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024
#define FULL_MASK 0xffffffff
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

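// Assigns every node to a degree bin: `bin[i]` receives the bin of node `i`
// (degrees beyond the last boundary in `binptr` fall into the last bin),
// `idx[i]` its running position inside that bin, while `node_size[b]` counts
// the nodes and `max_deg[b]` tracks the maximum degree of bin `b`.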
__global__ void bin_kernel(const int64_t *__restrict__ rowcount,
                           const int64_t *__restrict__ binptr,
                           int64_t *__restrict__ bin, int64_t *__restrict__ idx,
                           int *__restrict__ node_size,
                           int *__restrict__ max_deg, const size_t B,
                           const size_t N) {

  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < N; thread_idx += gridDim.x * blockDim.x) {

    int bin_idx = -1, deg = rowcount[thread_idx];
    for (ptrdiff_t b = 1; b <= B; b++) {
      if (deg < __ldg(binptr + b)) {
        bin_idx = b - 1;
        break;
      }
    }

    if (bin_idx == -1) {
      bin_idx = B - 1;
    }

    int old = atomicAdd(node_size + bin_idx, 1);
    atomicMax(max_deg + bin_idx, deg);

    bin[thread_idx] = bin_idx;
    idx[thread_idx] = old;
  }
}

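// One warp per bin: computes exclusive prefix sums of the node counts and of
// the padded edge counts (`node_size[b] * max_deg[b]`) into `node_offset` and
// `edge_offset`, and writes each bin's padded edge count into `edge_size`.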
__global__ void info_kernel(const int *__restrict__ node_size,
                            const int *__restrict__ max_deg,
                            int *__restrict__ edge_size,
                            int *__restrict__ node_offset,
                            int *__restrict__ edge_offset, const size_t B) {

  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  int bin_idx = thread_idx / 32;
  int lane_idx = thread_idx % 32;

  if (bin_idx <= B) { // Computes `node_offset` and `edge_offset`.
    int node_tmp = 0;
    int edge_tmp = 0;

    for (int i = lane_idx; i < bin_idx; i += 32) {
      node_tmp += node_size[i];
      edge_tmp += node_size[i] * max_deg[i];
    }

    for (int i = 32 / 2; i > 0; i /= 2) {
      node_tmp += __shfl_down_sync(FULL_MASK, node_tmp, i);
      edge_tmp += __shfl_down_sync(FULL_MASK, edge_tmp, i);
    }

    if (lane_idx == 0) {
      node_offset[bin_idx] = node_tmp;
      edge_offset[bin_idx] = edge_tmp;
    }
  } else if (bin_idx == B + 1) { // Computes `edge_size`.
    for (int i = lane_idx; i < B; i += 32) {
      edge_size[i] = node_size[i] * max_deg[i];
    }
  }
}

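// Scatters the original node indices into their padded order: node `i` of bin
// `b` ends up at position `node_offset[b] + idx[i]` of `out`.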
__global__ void node_perm_kernel(const int64_t *__restrict__ bin,
                                 const int64_t *__restrict__ idx,
                                 const int *__restrict__ node_offset,
                                 int64_t *__restrict__ out, const size_t N) {

  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
    out[__ldg(node_offset + bin[thread_idx]) + idx[thread_idx]] = thread_idx;
  }
}

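// Writes the padded edge layout with `TB` threads cooperating per node: the
// neighbors of each node are copied from `col` into a slot of length
// `max_deg[bin]` starting at `edge_offset[bin] + max_deg[bin] * idx[node]`;
// unused positions get -1 in `row_perm`/`col_perm` and are flagged in
// `edge_mask`.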
template <int TB>
__global__ void padded_index_kernel(
    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
    const int64_t *__restrict__ rowcount, const int64_t *__restrict__ bin,
    const int64_t *__restrict__ idx, const int *__restrict__ max_deg,
    const int *__restrict__ edge_offset, int64_t *__restrict__ row_perm,
    int64_t *__restrict__ col_perm, bool *__restrict__ edge_mask,
    const size_t B, const size_t N) {

  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < TB * N; thread_idx += gridDim.x * blockDim.x) {

    int row_idx = thread_idx / TB;
    int lane_idx = thread_idx % TB;

    int64_t bin_idx = bin[row_idx];
    int len = __ldg(max_deg + bin_idx);
    int off = __ldg(edge_offset + bin_idx) + len * idx[row_idx];

    int64_t row_start = rowptr[row_idx], deg = rowcount[row_idx];
    int64_t row_tmp, col_tmp;
    for (int i = lane_idx; i < len; i += TB) {
      row_tmp = -1, col_tmp = -1;
      if (i < deg) {
        row_tmp = row_idx;
        col_tmp = col[row_start + i];
      }
      row_perm[off + i] = row_tmp;
      col_perm[off + i] = col_tmp;
      edge_mask[off + i] = row_tmp == -1;
    }
  }
}

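// Bins the nodes of a CSR graph by degree according to `binptr` and builds a
// padded edge layout per bin. Returns the node permutation, the padded row
// and column indices, the padding mask, and the per-bin node and padded edge
// counts.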
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
                  torch::Tensor rowcount, torch::Tensor binptr) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(rowcount);
  CHECK_CUDA(binptr);
  CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);

  cudaSetDevice(rowcount.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t B = binptr.numel() - 1;
  size_t N = rowcount.numel();

  auto bin = torch::empty(N, rowptr.options());
  auto idx = torch::empty(N, rowptr.options());

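  // All per-bin statistics live in a single device buffer laid out as
  // [node_size | edge_size | max_deg | node_offset | edge_offset].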
  auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt));
  auto d_node_size = d_info.narrow(0, 0, B);
  auto d_edge_size = d_info.narrow(0, B, B);
  auto d_max_deg = d_info.narrow(0, 2 * B, B);
  auto d_node_offset = d_info.narrow(0, 3 * B, B + 1);
  auto d_edge_offset = d_info.narrow(0, 4 * B + 1, B + 1);

  bin_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
      rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
      d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(), B, N);

  info_kernel<<<BLOCKS(32 * (B + 2)), THREADS, 0, stream>>>(
      d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(),
      d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(),
      d_edge_offset.data_ptr<int>(), B);

  auto node_perm = torch::empty(N, rowptr.options());

  node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
      d_node_offset.data_ptr<int>(), node_perm.data_ptr<int64_t>(), N);

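  // Copy the per-bin sizes and offsets back to the host in order to allocate
  // the output tensors.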
  auto h_info = torch::empty(
      d_info.numel(), d_info.options().device(torch::kCPU).pinned_memory(true));
  cudaMemcpy(h_info.data_ptr<int>(), d_info.data_ptr<int>(),
             d_info.numel() * sizeof(int), cudaMemcpyDeviceToHost);

  size_t E = h_info.data_ptr<int>()[5 * B + 1];
  auto row_perm = torch::empty(E, col.options());
  auto col_perm = torch::empty(E, col.options());
  auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));

  padded_index_kernel<8>
      <<<std::min(BLOCKS(N * 8), mpc * 8), THREADS, 0, stream>>>(
          rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
          rowcount.data_ptr<int64_t>(), bin.data_ptr<int64_t>(),
          idx.data_ptr<int64_t>(), d_max_deg.data_ptr<int>(),
          d_edge_offset.data_ptr<int>(), row_perm.data_ptr<int64_t>(),
          col_perm.data_ptr<int64_t>(), edge_mask.data_ptr<bool>(), B, N);

  h_info = h_info.to(torch::kLong);
  auto h_info_data = h_info.data_ptr<int64_t>();
  std::vector<int64_t> node_sizes(h_info_data, h_info_data + B);
  std::vector<int64_t> edge_sizes(h_info_data + B, h_info_data + 2 * B);

  return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
                         edge_sizes);
}

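// Gathers rows of `src` according to `index`, writing `fill_value` into every
// position that belongs to a padding slot (`index == -1`).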
template <typename scalar_t>
__global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
                                           const int64_t *__restrict__ index,
                                           scalar_t *__restrict__ out,
                                           const scalar_t fill_value,
                                           const size_t E, const size_t F) {

  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {

    int64_t row_idx = thread_idx / F;
    int64_t lane_idx = thread_idx % F;
    int64_t index_idx = __ldg(index + row_idx);

    scalar_t tmp = fill_value;
    if (index_idx != -1) {
      tmp = src[index_idx * F + lane_idx];
    }

    out[thread_idx] = tmp;
  }
}

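// Selects rows of a 2D `src` tensor according to `index` (e.g., the padded
// indices produced by `padded_index_cuda`), substituting `fill_value` for
// padding entries; the result is returned flattened with `E * F` entries.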
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                       torch::Tensor fill_value) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);

  cudaSetDevice(src.get_device());
  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t E = index.numel();
  size_t F = src.size(-1);

  auto out = torch::empty(E * F, src.options());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // Read the scalar fill value on the host (copied from the device if
    // necessary).
    scalar_t fill;
    if (fill_value.is_cuda()) {
      cudaMemcpy(&fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
                 cudaMemcpyDeviceToHost);
    } else {
      fill = fill_value.data_ptr<scalar_t>()[0];
    }

    padded_index_select_kernel<scalar_t>
        <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
            src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
            out.data_ptr<scalar_t>(), fill, E, F);
  });

  return out;
}
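
// Usage sketch (illustrative only): assuming CUDA tensors `rowptr`, `col`,
// `rowcount`, and `binptr` describing a CSR graph, and a 2D feature matrix
// `x`, the two entry points above might be combined as follows:
//
//   auto [node_perm, row_perm, col_perm, edge_mask, node_sizes, edge_sizes] =
//       padded_index_cuda(rowptr, col, rowcount, binptr);
//   auto x_gathered = padded_index_select_cuda(x, col_perm,
//                                              torch::zeros({1}, x.options()));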