reordering.cu.h 14.9 KB
Newer Older
traveller59's avatar
traveller59 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef REORDERING_CU_H_
#define REORDERING_CU_H_
17
#include <THC/THCAtomics.cuh>
yanyan's avatar
yanyan committed
18
#include <THC/THCNumerics.cuh>
yanyan's avatar
yanyan committed
19
#include <cuda_fp16.h>
yanyan's avatar
yanyan committed
20
#include <tensorview/kernel_utils.h>
yanyan's avatar
yanyan committed
21
22
23
24
25
26
27

// Select the atomic-add primitive used by the batch scatter kernels in this
// header: plain atomicAdd when PYTORCH_VERSION is below 10500, gpuAtomicAdd
// otherwise.
// NOTE(review): PYTORCH_VERSION is presumably injected by the build system; an
// undefined macro evaluates to 0 inside #if and therefore selects atomicAdd.
#if PYTORCH_VERSION < 10500
#define TH_ATOMIC_ADD atomicAdd
#else
#define TH_ATOMIC_ADD gpuAtomicAdd
#endif

traveller59's avatar
traveller59 committed
28
29
// see http://www.nvidia.com/content/GTC-2010/pdfs/2238_GTC2010.pdf.
namespace spconv {
30

traveller59's avatar
traveller59 committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Row gather: for every row r in [0, size), copies numPlanes elements from
// features row indices[r] into buffer row r.
//
// Rows are iterated with tv::KernelLoopX and columns with tv::KernelLoopY.
// Per X iteration a thread handles up to NumILP rows spaced
// gridDim.x * blockDim.x apart; each of those rows is checked against size.
// NumTLP is not referenced in the body.
// NOTE(review): the exact iteration pattern of tv::KernelLoopX/Y lives in
// tensorview/kernel_utils.h and is not visible here.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void gatherGenericKernel(T *buffer, const T *features,
                                    const Index *indices, int size,
                                    int numPlanes) {
  int rowOffset[NumILP];
  Index srcBase[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
    // Resolve the source row once per ILP slot, already scaled by the row
    // length, so the column loop below only adds iy.
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      if (row < size) {
        srcBase[k] = indices[row] * numPlanes;
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int k = 0; k < NumILP; ++k) {
        const int row = ix + rowOffset[k];
        if (row < size) {
          buffer[row * numPlanes + iy] = features[srcBase[k] + iy];
        }
      }
    }
  }
}

// Vectorised row gather: same copy as the generic gather, but buffer and
// features are reinterpreted as arrays of VecType, so every load/store moves
// one VecType. numPlanes is used as the row length in VecType units.
//
// Each thread handles up to NumILP rows per tv::KernelLoopX iteration, spaced
// gridDim.x * blockDim.x apart, and every row is checked against size.
// NumTLP is not referenced in the body.
// NOTE(review): both pointers must be suitably aligned for VecType — confirm
// at the launch site.
template <typename T, typename Index, int NumTLP, int NumILP, typename VecType>
__global__ void gatherVecKernel(T *buffer, const T *features,
                                const Index *indices, int size, int numPlanes) {
  int rowOffset[NumILP];
  Index srcBase[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  VecType *dstVec = reinterpret_cast<VecType *>(buffer);
  const VecType *srcVec = reinterpret_cast<const VecType *>(features);

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      if (row < size) {
        srcBase[k] = indices[row] * numPlanes;
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int k = 0; k < NumILP; ++k) {
        const int row = ix + rowOffset[k];
        if (row < size) {
          dstVec[row * numPlanes + iy] = srcVec[srcBase[k] + iy];
        }
      }
    }
  }
}

// Block-tiled vectorised gather. blockIdx.y selects a column tile: both
// pointers are advanced by blockIdx.y * NumTLP elements of T before being
// reinterpreted as VecType arrays, and threadIdx.y selects the VecType column
// inside that tile. numPlanes is used as the row length in VecType units.
//
// Unlike the generic/vec gathers, the ILP rows ix + k * gridDim.x * blockDim.x
// are NOT checked against size.
// NOTE(review): presumably the caller passes a size for which every ILP row is
// in range and a launch shape that covers exactly one tile — confirm at the
// launch site.
template <typename T, typename Index, int NumTLP, int NumILP,
          typename VecType = int4>
__global__ void gatherVecBlockKernel(T *buffer, const T *features,
                                     const Index *indices, int size,
                                     int numPlanes) {
  int rowOffset[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  // Shift to this block's column tile (offset counted in elements of T).
  const VecType *srcVec =
      reinterpret_cast<const VecType *>(features + blockIdx.y * NumTLP);
  VecType *dstVec = reinterpret_cast<VecType *>(buffer + blockIdx.y * NumTLP);

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      dstVec[row * numPlanes + threadIdx.y] =
          srcVec[indices[row] * numPlanes + threadIdx.y];
    }
  }
}

109
110
111
// Batched row gather with zero fill. For every output row e in [0, size) the
// source row is read from
//   indices[(e / feature_batch_stride) * indice_batch_stride
//           + e % feature_batch_stride];
// a value of -1 writes T(0) into the whole output row, any other value copies
// numPlanes elements from that features row.
//
// Each thread handles up to NumILP rows per tv::KernelLoopX iteration, spaced
// gridDim.x * blockDim.x apart; every row is checked against size.
// NumTLP is not referenced in the body.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void batchGatherGenericKernel(T *buffer, const T *features,
                                         const Index *indices, int size,
                                         int numPlanes, int indice_batch_stride,
                                         int feature_batch_stride) {
  // size: max indice num * kernel volume
  // inds: [volume, num_elems]
  int rowOffset[NumILP];
  Index srcRow[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      if (row < size) {
        const Index elem = row;
        srcRow[k] =
            indices[(elem / feature_batch_stride) * indice_batch_stride +
                    elem % feature_batch_stride];
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int k = 0; k < NumILP; ++k) {
        const int row = ix + rowOffset[k];
        if (row >= size) {
          continue;
        }
        if (srcRow[k] == -1) {
          // No source row for this slot: emit zeros.
          buffer[row * numPlanes + iy] = T(0);
        } else {
          buffer[row * numPlanes + iy] = features[srcRow[k] * numPlanes + iy];
        }
      }
    }
  }
}

template <typename T, typename Index, int NumTLP, int NumILP, typename VecType>
yanyan's avatar
yanyan committed
151
152
153
154
__global__ void
batchGatherVecKernel(T *buffer, const T *features, const Index *indices,
                     int size, int feature_offset, int numPlanes,
                     int indice_batch_stride, int feature_batch_stride) {
155
156
  int ILPStrideX[NumILP];
  Index inds[NumILP];
yanyan's avatar
yanyan committed
157
158
159
160
161
162
  Index zero[sizeof(VecType) / sizeof(T)];
#pragma unroll
  for (int i = 0; i < sizeof(VecType) / sizeof(T); ++i) {
    zero[i] = T(0);
  }

163
  Index inds_elem;
164
165
166
167
168
169
170
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ilp++) {
171
172
173
174
175
176
      if (ix + ILPStrideX[ilp] < size) {
        inds_elem = ix + ILPStrideX[ilp] + feature_offset;
        inds[ilp] =
            indices[(inds_elem / feature_batch_stride) * indice_batch_stride +
                    inds_elem % feature_batch_stride];
      }
177
178
179
180
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int ilp = 0; ilp < NumILP; ++ilp) {
yanyan's avatar
yanyan committed
181
182
183
184
185
186
187
188
189
190
191
192
193
        if (ix + ILPStrideX[ilp] < size) {
          if (inds[ilp] != -1) {
            reinterpret_cast<VecType *>(
                buffer)[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
                reinterpret_cast<const VecType *>(
                    features)[inds[ilp] * numPlanes + iy];

          } else {
            reinterpret_cast<VecType *>(
                buffer)[(ix + ILPStrideX[ilp]) * numPlanes + iy] =
                reinterpret_cast<const VecType *>(&zero)[0];
          }
        }
194
195
196
197
198
199
200
      }
    }
  }
}

template <typename T, typename Index, int NumTLP, int NumILP,
          typename VecType = int4>
201
202
203
204
__global__ void
batchGatherVecBlockKernel(T *buffer, const T *features, const Index *indices,
                          int size, int numPlanes, int indice_batch_stride,
                          int feature_batch_stride) {
yanyan's avatar
yanyan committed
205
  int ILPStrideX[NumILP];
206
  Index inds;
207
208
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
yanyan's avatar
yanyan committed
209
210
211
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  features += blockIdx.y * NumTLP;
  buffer += blockIdx.y * NumTLP;
212

213
  Index inds_elem;
yanyan's avatar
yanyan committed
214
215
216
217
218
  Index zero[sizeof(VecType) / sizeof(T)];
#pragma unroll
  for (int i = 0; i < sizeof(VecType) / sizeof(T); ++i) {
    zero[i] = T(0);
  }
219

yanyan's avatar
yanyan committed
220
  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
221

222
223
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ++ilp) {
yanyan's avatar
yanyan committed
224
      inds_elem = ix + ILPStrideX[ilp];
225
226
227
228
229
      inds = indices[(inds_elem / feature_batch_stride) * indice_batch_stride +
                     inds_elem % feature_batch_stride];

      if (inds != -1) {
        reinterpret_cast<VecType *>(
yanyan's avatar
yanyan committed
230
            buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y] =
231
            reinterpret_cast<const VecType *>(
yanyan's avatar
yanyan committed
232
233
234
235
236
                features)[inds * numPlanes + threadIdx.y];
      } else {
        reinterpret_cast<VecType *>(
            buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y] =
            reinterpret_cast<const VecType *>(&zero)[0];
237
      }
238
239
240
241
    }
  }
}

traveller59's avatar
traveller59 committed
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
// Row scatter-add: for every row r in [0, size), adds the numPlanes elements
// of buffer row r onto outFeatures row indices[r].
//
// The accumulation is a plain `+=` on global memory, not an atomic.
// NOTE(review): that is only race-free if no two rows handled by one launch
// map to the same output row — confirm against the caller.
//
// Each thread handles up to NumILP rows per tv::KernelLoopX iteration, spaced
// gridDim.x * blockDim.x apart; every row is checked against size.
// NumTLP is not referenced in the body.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void scatterAddGenericKernel(T *outFeatures, const T *buffer,
                                        const Index *indices, int size,
                                        int numPlanes) {
  int rowOffset[NumILP];
  Index dstBase[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
    // Resolve the destination row once per ILP slot, pre-scaled by the row
    // length.
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      if (row < size) {
        dstBase[k] = indices[row] * numPlanes;
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int k = 0; k < NumILP; ++k) {
        const int row = ix + rowOffset[k];
        if (row >= size) {
          continue;
        }
        outFeatures[dstBase[k] + iy] += buffer[row * numPlanes + iy];
      }
    }
  }
}

// Block-tiled vectorised scatter-add. blockIdx.y selects a column tile: both
// pointers are advanced by blockIdx.y * NumTLP elements of T and then accessed
// as VecType arrays, with threadIdx.y selecting the VecType column. numPlanes
// is used as the row length in VecType units.
//
// For each row, one VecType is loaded from outFeatures row indices[row] and one
// from buffer row `row`; they are summed lane by lane in T (or as __half2 pairs
// when T is at::Half and __CUDA_ARCH__ >= 530) and the result is stored back
// to outFeatures. The update is a plain load/add/store, not an atomic.
// NOTE(review): that is only race-free if no two rows handled by one launch
// map to the same output row — confirm against the caller.
//
// The ILP rows ix + k * gridDim.x * blockDim.x are NOT checked against size.
// NOTE(review): presumably the caller passes a size for which every ILP row is
// in range — confirm at the launch site.
template <typename T, typename Index, int NumTLP, int NumILP,
          typename VecType = int4>
__global__ void scatterAddVecBlockKernel(T *outFeatures, const T *buffer,
                                         const Index *indices, int size,
                                         int numPlanes) {
  int ILPStrideX[NumILP];
  constexpr int vecloadFactor = sizeof(VecType) / sizeof(T);
  constexpr int vecloadHalf2Factor = sizeof(VecType) / sizeof(__half2);

#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  outFeatures += blockIdx.y * NumTLP;
  buffer += blockIdx.y * NumTLP;
  // Staging registers are written and read through VecType pointers, so they
  // must carry VecType alignment rather than just alignof(T).
  alignas(VecType) alignas(T) T buf[vecloadFactor];
  alignas(VecType) alignas(T) T buf2[vecloadFactor];
  Index idx;
  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ++ilp) {
      idx = indices[ix + ILPStrideX[ilp]] * numPlanes + threadIdx.y;
      reinterpret_cast<VecType *>(buf)[0] =
          reinterpret_cast<VecType *>(outFeatures)[idx];
      reinterpret_cast<VecType *>(buf2)[0] = reinterpret_cast<const VecType *>(
          buffer)[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y];
      if (std::is_same<T, at::Half>::value) {
#if __CUDA_ARCH__ >= 530
        // Native half2 arithmetic: add two halves per instruction.
#pragma unroll
        for (int i = 0; i < vecloadHalf2Factor; i++) {
          reinterpret_cast<__half2 *>(buf)[i] =
              __hadd2(reinterpret_cast<__half2 *>(buf)[i],
                      reinterpret_cast<__half2 *>(buf2)[i]);
        }
#else
#pragma unroll
        for (int i = 0; i < vecloadFactor; i++) {
          buf[i] += buf2[i];
        }
#endif
      } else {
#pragma unroll
        for (int i = 0; i < vecloadFactor; i++) {
          buf[i] += buf2[i];
        }
      }
      reinterpret_cast<VecType *>(outFeatures)[idx] =
          reinterpret_cast<VecType *>(buf)[0];
    }
  }
}

yanyan's avatar
yanyan committed
320
321
// Block-tiled scalar scatter-add. blockIdx.y selects a column tile (both
// pointers are advanced by blockIdx.y * NumTLP elements) and threadIdx.y
// selects the column inside it. For each row, the buffer element is added
// onto the matching element of outFeatures row indices[row].
//
// The accumulation is a plain `+=`, not an atomic, and the ILP rows
// ix + k * gridDim.x * blockDim.x are NOT checked against size.
// NOTE(review): presumably the caller guarantees unique destination rows per
// launch and a size for which every ILP row is in range — confirm.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void scatterAddBlockKernel(T *outFeatures, const T *buffer,
                                      const Index *indices, int size,
                                      int numPlanes) {
  int rowOffset[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  // Shift to this block's column tile.
  outFeatures += blockIdx.y * NumTLP;
  buffer += blockIdx.y * NumTLP;

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      outFeatures[indices[row] * numPlanes + threadIdx.y] +=
          buffer[row * numPlanes + threadIdx.y];
    }
  }
}

yanyan's avatar
yanyan committed
339
// Block-tiled scatter-add that processes data as __half2 pairs. blockIdx.y
// selects a column tile (both pointers are advanced by blockIdx.y * NumTLP
// elements of T) and threadIdx.y selects the __half2 column inside it.
// numPlanes is used as the row length in __half2 units. For each row, the
// __half2 from buffer is added onto the matching __half2 of outFeatures row
// indices[row].
//
// The kernel is declared for every compilation pass; only the arithmetic is
// selected by architecture. Guarding the whole __global__ definition with
// __CUDA_ARCH__ would hide it from the host pass (where the macro is
// undefined), so it could never be launched.
//   - __CUDA_ARCH__ >= 530: native __hadd2.
//   - otherwise: convert to float2, add, round back to __half2.
//
// The update is a plain load/add/store, not an atomic, and the ILP rows
// ix + k * gridDim.x * blockDim.x are NOT checked against size.
// NOTE(review): presumably T is a 16-bit half type, destination rows are
// unique per launch, and every ILP row is in range — confirm at the launch
// site.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void scatterAddHalfBlockKernel(T *outFeatures, const T *buffer,
                                          const Index *indices, int size,
                                          int numPlanes) {
  int ILPStrideX[NumILP];
#pragma unroll
  for (int ilp = 0; ilp < NumILP; ilp++)
    ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x;
  outFeatures += blockIdx.y * NumTLP;
  buffer += blockIdx.y * NumTLP;

  __half2 *out2 = reinterpret_cast<__half2 *>(outFeatures);
  // buffer is const T*: keep the const qualifier, reinterpret_cast cannot
  // drop it.
  const __half2 *in2 = reinterpret_cast<const __half2 *>(buffer);

  Index idx;
  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int ilp = 0; ilp < NumILP; ++ilp) {
      idx = indices[ix + ILPStrideX[ilp]] * numPlanes + threadIdx.y;
      const __half2 addend =
          in2[(ix + ILPStrideX[ilp]) * numPlanes + threadIdx.y];
#if __CUDA_ARCH__ >= 530
      out2[idx] = __hadd2(out2[idx], addend);
#else
      // No native half2 arithmetic here: accumulate in float.
      const float2 acc = __half22float2(out2[idx]);
      const float2 inc = __half22float2(addend);
      out2[idx] = __floats2half2_rn(acc.x + inc.x, acc.y + inc.y);
#endif
    }
  }
}
yanyan's avatar
yanyan committed
363

364
365
366
367
368
369
// Batched row scatter-add using atomics. For buffer row r in [0, size) the
// lookup position is e = r + feature_offset and the destination row is
//   indices[(e / feature_batch_stride) * indice_batch_stride
//           + e % feature_batch_stride];
// rows whose destination is -1 are skipped, all others are accumulated into
// outFeatures element by element with TH_ATOMIC_ADD.
//
// Each thread handles up to NumILP rows per tv::KernelLoopX iteration, spaced
// gridDim.x * blockDim.x apart; every row is checked against size.
// NumTLP is not referenced in the body.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void batchScatterAddGenericKernel(T *outFeatures, const T *buffer,
                                             const Index *indices, int size,
                                             int feature_offset, int numPlanes,
                                             int indice_batch_stride,
                                             int feature_batch_stride) {
  // Batch scatter add is much slower than the native scatter when the number
  // of points is large; this may be due to atomicAdd.
  // Batch scatter add is much faster than the native one when the number of
  // points is small.
  int rowOffset[NumILP];
  Index dstRow[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      if (row < size) {
        const Index elem = row + feature_offset;
        dstRow[k] =
            indices[(elem / feature_batch_stride) * indice_batch_stride +
                    elem % feature_batch_stride];
      }
    }
    for (int iy : tv::KernelLoopY<int>(numPlanes)) {
#pragma unroll
      for (int k = 0; k < NumILP; ++k) {
        const int row = ix + rowOffset[k];
        // The range test must come first: dstRow[k] is only written for
        // in-range rows.
        if (row < size && dstRow[k] != -1) {
          TH_ATOMIC_ADD(outFeatures + dstRow[k] * numPlanes + iy,
                        buffer[row * numPlanes + iy]);
        }
      }
    }
  }
}

// Block-tiled, batched scatter-add using atomics. blockIdx.y selects a column
// tile (both pointers are advanced by blockIdx.y * NumTLP elements) and
// threadIdx.y selects the column inside it. For buffer row e the destination
// row is
//   indices[(e / feature_batch_stride) * indice_batch_stride
//           + e % feature_batch_stride];
// rows whose destination is -1 are skipped, all others are accumulated with
// TH_ATOMIC_ADD.
//
// The ILP rows ix + k * gridDim.x * blockDim.x are NOT checked against size.
// NOTE(review): presumably the caller passes a size for which every ILP row is
// in range — confirm at the launch site.
template <typename T, typename Index, int NumTLP, int NumILP>
__global__ void
batchScatterAddBlockKernel(T *outFeatures, const T *buffer,
                           const Index *indices, int size, int numPlanes,
                           int indice_batch_stride, int feature_batch_stride) {
  int rowOffset[NumILP];
#pragma unroll
  for (int k = 0; k < NumILP; ++k) {
    rowOffset[k] = k * gridDim.x * blockDim.x;
  }

  // Shift to this block's column tile.
  outFeatures += blockIdx.y * NumTLP;
  buffer += blockIdx.y * NumTLP;

  for (int ix : tv::KernelLoopX<int, NumILP>(size)) {
#pragma unroll
    for (int k = 0; k < NumILP; ++k) {
      const int row = ix + rowOffset[k];
      const Index elem = row;
      const Index dstRow =
          indices[(elem / feature_batch_stride) * indice_batch_stride +
                  elem % feature_batch_stride];
      if (dstRow != -1) {
        TH_ATOMIC_ADD(outFeatures + dstRow * numPlanes + threadIdx.y,
                      buffer[row * numPlanes + threadIdx.y]);
      }
    }
  }
}

traveller59's avatar
traveller59 committed
428
429
} // namespace spconv

yanyan's avatar
yanyan committed
430
431
#undef TH_ATOMIC_ADD

traveller59's avatar
traveller59 committed
432
#endif