"tests/vscode:/vscode.git/clone" did not exist on "fd2f970bdfd3eada289b4e19a3adcf2c352a4d8f"
rowwise_sampling.cu 12.6 KB
Newer Older
1
2
3
/*!
 *  Copyright (c) 2021 by Contributors
 * \file array/cuda/rowwise_sampling.cu
 * \brief uniform rowwise sampling
 */

#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <numeric>

Jinjing Zhou's avatar
Jinjing Zhou committed
12
#include "./dgl_cub.cuh"
13
#include "../../array/cuda/atomic.cuh"
14
15
#include "../../runtime/cuda/cuda_common.h"

16

17
using namespace dgl::aten::cuda;
18
19
20
21
22
23
24

namespace dgl {
namespace aten {
namespace impl {

namespace {

25
// Number of threads per block used when launching the row-wise sampling
// kernels in this file; those kernels assert that blockDim.x equals this
// value and use it as the stride of their per-row loops.
constexpr int BLOCK_SIZE = 128;
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

/**
* @brief Compute the size of each row in the sampled CSR, without replacement.
*
* Each thread handles at most one entry of `in_rows`, selected by its flat
* 1D thread index; threads whose index is `>= num_rows` do nothing. The
* thread handling the last row additionally writes a trailing zero so that
* an exclusive prefix sum over `num_rows + 1` entries yields the total.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by
* `in_rows` (output). Entry `num_rows` is also written, so the buffer must
* hold `num_rows + 1` elements.
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    IdType * const out_deg) {
  // Widen before multiplying so the flat index cannot wrap in 32 bits.
  const int64_t tIdx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (tIdx < num_rows) {
    // Keep the row id at 64 bits: IdType may be int64_t, and narrowing it to
    // `int` would index the wrong row of `in_ptr` for ids >= 2^31.
    const int64_t in_row = in_rows[tIdx];
    const int64_t out_row = tIdx;

    // A row cannot contribute more entries than it has.
    out_deg[out_row] = min(
        static_cast<IdType>(num_picks),
        in_ptr[in_row + 1] - in_ptr[in_row]);

    if (out_row == num_rows - 1) {
      // make the prefixsum work
      out_deg[num_rows] = 0;
    }
  }
}

/**
* @brief Compute the size of each row in the sampled CSR, with replacement.
*
* Each thread handles at most one entry of `in_rows`, selected by its flat
* 1D thread index; threads whose index is `>= num_rows` do nothing. A row
* with no edges yields zero samples, any other row yields exactly
* `num_picks`. The thread handling the last row additionally writes a
* trailing zero so that an exclusive prefix sum over `num_rows + 1` entries
* yields the total.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by
* `in_rows` (output). Entry `num_rows` is also written, so the buffer must
* hold `num_rows + 1` elements.
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    IdType * const out_deg) {
  // Widen before multiplying so the flat index cannot wrap in 32 bits.
  const int64_t tIdx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (tIdx < num_rows) {
    const int64_t in_row = in_rows[tIdx];
    const int64_t out_row = tIdx;

    // Nothing can be drawn from an empty row, even with replacement.
    if (in_ptr[in_row + 1] - in_ptr[in_row] == 0) {
      out_deg[out_row] = 0;
    } else {
      out_deg[out_row] = static_cast<IdType>(num_picks);
    }

    if (out_row == num_rows - 1) {
      // make the prefixsum work
      out_deg[num_rows] = 0;
    }
  }
}

/**
* @brief Perform row-wise uniform sampling on a CSR matrix,
* and generate a COO matrix, without replacement.
*
* Must be launched with `BLOCK_SIZE` threads per block (asserted). Each
* thread block walks up to `TILE_SIZE` consecutive entries of `in_rows`, one
* after another, and all threads of the block cooperate on the current row.
*
* @tparam IdType The ID type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_index The indices array of the input CSR.
* @param data The data array of the input CSR. May be null, in which case
* the position of the entry in the input CSR is emitted instead.
* @param out_ptr The offset to write each row to in the output COO.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output). Also used as
* scratch space for the selected positions while a row is being sampled.
*/
template<typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformKernel(
    const uint64_t rand_seed,
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    const IdType * const in_index,
    const IdType * const data,
    const IdType * const out_ptr,
    IdType * const out_rows,
    IdType * const out_cols,
    IdType * const out_idxs) {
  // the whole thread block cooperates on one row at a time
  assert(blockDim.x == BLOCK_SIZE);

  // Widen before multiplying, matching how `last_row` is computed.
  const int64_t first_row = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  // One generator per thread: the seed is unique per block and the thread
  // index selects the subsequence.
  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  for (int64_t out_row = first_row; out_row < last_row; ++out_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;
    const int64_t out_row_start = out_ptr[out_row];

    // `deg` depends only on the block's current row, so every thread of the
    // block takes the same branch and the barriers below are reached by all.
    if (deg <= num_picks) {
      // just copy row when there is not enough nodes to sample.
      for (int64_t idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const IdType in_idx = in_row_start + idx;
        out_rows[out_row_start + idx] = row;
        out_cols[out_row_start + idx] = in_index[in_idx];
        out_idxs[out_row_start + idx] = data ? data[in_idx] : in_idx;
      }
    } else {
      // generate permutation list via reservoir algorithm
      for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        out_idxs[out_row_start + idx] = static_cast<IdType>(idx);
      }
      __syncthreads();

      for (int64_t idx = num_picks + threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const int64_t num = curand(&rng) % (idx + 1);
        if (num < num_picks) {
          // use max so as to achieve the replacement order the serial
          // algorithm would have
          AtomicMax(out_idxs + out_row_start + num, static_cast<IdType>(idx));
        }
      }
      __syncthreads();

      // copy permutation over: turn each selected in-row position into the
      // final (row, col, data) triple, overwriting the scratch value.
      for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        const IdType perm_idx = out_idxs[out_row_start + idx] + in_row_start;
        out_rows[out_row_start + idx] = row;
        out_cols[out_row_start + idx] = in_index[perm_idx];
        out_idxs[out_row_start + idx] = data ? data[perm_idx] : perm_idx;
      }
    }
  }
}

/**
* @brief Perform row-wise uniform sampling on a CSR matrix,
* and generate a COO matrix, with replacement.
*
* Must be launched with `BLOCK_SIZE` threads per block (asserted). Each
* thread block walks up to `TILE_SIZE` consecutive entries of `in_rows`, one
* after another, and all threads of the block cooperate on the current row.
* Rows without any edges produce no output.
*
* @tparam IdType The ID type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_index The indices array of the input CSR.
* @param data The data array of the input CSR. May be null, in which case
* the position of the entry in the input CSR is emitted instead.
* @param out_ptr The offset to write each row to in the output COO.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformReplaceKernel(
    const uint64_t rand_seed,
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    const IdType * const in_index,
    const IdType * const data,
    const IdType * const out_ptr,
    IdType * const out_rows,
    IdType * const out_cols,
    IdType * const out_idxs) {
  // the whole thread block cooperates on one row at a time
  assert(blockDim.x == BLOCK_SIZE);

  // Widen before multiplying, matching how `last_row` is computed.
  const int64_t first_row = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  // One generator per thread: the seed is unique per block and the thread
  // index selects the subsequence.
  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  for (int64_t out_row = first_row; out_row < last_row; ++out_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t out_row_start = out_ptr[out_row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;

    if (deg > 0) {
      // each thread then blindly copies in rows only if deg > 0.
      for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        // Every pick is an independent draw over the row's edges.
        const int64_t edge = curand(&rng) % deg;
        const int64_t out_idx = out_row_start + idx;
        out_rows[out_idx] = row;
        out_cols[out_idx] = in_index[in_row_start + edge];
        out_idxs[out_idx] = data ? data[in_row_start + edge] : in_row_start + edge;
      }
    }
  }
}
}  // namespace

240
241

///////////////////////////// CSR sampling //////////////////////////
242

243
/**
* @brief Uniformly sample non-zeros from selected rows of a CSR matrix on the
* GPU and return them as a COO matrix.
*
* The work is done in three steps, all enqueued on the current CUDA stream:
* a per-row output size is computed, an exclusive prefix sum turns those
* sizes into write offsets (the last offset being the total output length),
* and a sampling kernel fills the output arrays. The total length is copied
* back to the host and used to trim the output arrays before returning.
*
* @tparam XPU The device type (not referenced in the body).
* @tparam IdType The ID type used for matrices.
* @param mat The input CSR matrix.
* @param rows The rows to sample from; its context selects the device.
* @param num_picks The number of non-zeros to pick per row.
* @param replace Whether to sample with replacement.
*
* @return A COO matrix with the shape of `mat` holding the sampled entries.
*/
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
                                    IdArray rows,
                                    const int64_t num_picks,
                                    const bool replace) {
  const auto& ctx = rows->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_rows = rows->shape[0];
  const IdType * const slice_rows = static_cast<const IdType*>(rows->data);

  // Outputs are sized for the worst case (every row yields `num_picks`
  // entries) and trimmed to the real length at the end.
  IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  const IdType * const in_ptr = static_cast<const IdType*>(mat.indptr->data);
  const IdType * const in_cols = static_cast<const IdType*>(mat.indices->data);
  IdType* const out_rows = static_cast<IdType*>(picked_row->data);
  IdType* const out_cols = static_cast<IdType*>(picked_col->data);
  IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);

  // A null `data` tells the kernels to emit CSR positions instead.
  const IdType* data = nullptr;
  if (CSRHasData(mat)) {
    data = static_cast<IdType*>(mat.data->data);
  }

  // Step 1: number of sampled entries per row, plus one trailing slot.
  IdType * out_deg = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  {
    const dim3 deg_block(512);
    const dim3 deg_grid((num_rows + deg_block.x - 1) / deg_block.x);
    if (replace) {
      CUDA_KERNEL_CALL(
          _CSRRowWiseSampleDegreeReplaceKernel,
          deg_grid, deg_block, 0, stream,
          num_picks, num_rows, slice_rows, in_ptr, out_deg);
    } else {
      CUDA_KERNEL_CALL(
          _CSRRowWiseSampleDegreeKernel,
          deg_grid, deg_block, 0, stream,
          num_picks, num_rows, slice_rows, in_ptr, out_deg);
    }
  }

  // Step 2: exclusive prefix sum of the sizes gives each row's write offset;
  // entry `num_rows` ends up holding the total number of sampled entries.
  IdType * out_ptr = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  size_t scan_bytes = 0;
  // First call with a null buffer only reports the scratch size needed.
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, scan_bytes, out_deg, out_ptr, num_rows + 1, stream));
  void * scan_scratch = device->AllocWorkspace(ctx, scan_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      scan_scratch, scan_bytes, out_deg, out_ptr, num_rows + 1, stream));
  device->FreeWorkspace(ctx, scan_scratch);
  device->FreeWorkspace(ctx, out_deg);

  cudaEvent_t copyEvent;
  CUDA_CALL(cudaEventCreate(&copyEvent));

  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
  // a cudaevent
  IdType new_len;
  // copy using the internal current stream
  device->CopyDataFromTo(
      out_ptr, num_rows * sizeof(new_len),
      &new_len, 0,
      sizeof(new_len),
      ctx,
      DGLContext{kDGLCPU, 0},
      mat.indptr->dtype);
  // Recorded after the copy is issued so the host can wait for it later
  // while the sampling kernel is already queued behind it.
  CUDA_CALL(cudaEventRecord(copyEvent, stream));

  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

  // Step 3: select edges.
  // the number of rows each thread block will cover
  constexpr int TILE_SIZE = 128 / BLOCK_SIZE;
  {
    const dim3 sample_block(BLOCK_SIZE);
    const dim3 sample_grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    if (replace) {  // with replacement
      CUDA_KERNEL_CALL(
          (_CSRRowWiseSampleUniformReplaceKernel<IdType, TILE_SIZE>),
          sample_grid, sample_block, 0, stream,
          random_seed,
          num_picks,
          num_rows,
          slice_rows,
          in_ptr,
          in_cols,
          data,
          out_ptr,
          out_rows,
          out_cols,
          out_idxs);
    } else {  // without replacement
      CUDA_KERNEL_CALL(
          (_CSRRowWiseSampleUniformKernel<IdType, TILE_SIZE>),
          sample_grid, sample_block, 0, stream,
          random_seed,
          num_picks,
          num_rows,
          slice_rows,
          in_ptr,
          in_cols,
          data,
          out_ptr,
          out_rows,
          out_cols,
          out_idxs);
    }
  }
  device->FreeWorkspace(ctx, out_ptr);

  // wait for copying `new_len` to finish
  CUDA_CALL(cudaEventSynchronize(copyEvent));
  CUDA_CALL(cudaEventDestroy(copyEvent));

  // Trim the worst-case sized outputs down to what was actually produced.
  picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);

  return COOMatrix(mat.num_rows, mat.num_cols, picked_row,
      picked_col, picked_idx);
}

// Explicit instantiations for the two supported ID widths on CUDA.
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(
    CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
    CSRMatrix, IdArray, int64_t, bool);

}  // namespace impl
}  // namespace aten
}  // namespace dgl