/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cuda/rowwise_sampling.cu
 * @brief uniform rowwise sampling
 */

#include <curand_kernel.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>

#include <cstdint>
#include <numeric>

#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"

using namespace dgl::aten::cuda;

namespace dgl {
namespace aten {
namespace impl {

namespace {

// Number of threads per block used by the row-sampling kernels in this
// file; those kernels assert that blockDim.x equals it and step their
// per-row loops by it.
constexpr int BLOCK_SIZE{128};

/**
 * @brief Compute the size of each row in the sampled CSR, without replacement.
 *
 * Expects a 1-D launch with at least `num_rows` threads in total; thread `i`
 * handles output row `i` and stores `min(num_picks, degree of in_rows[i])`.
 * The thread owning the last row also writes a trailing zero to
 * `out_deg[num_rows]`, so an exclusive prefix sum over `num_rows + 1`
 * entries ends with the total number of sampled entries.
 *
 * @tparam IdType The type of node and edge indexes.
 * @param num_picks The number of non-zero entries to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The index where each row's edges start.
 * @param out_deg The size of each row in the sampled matrix, as indexed by
 * `in_rows` (output). Must hold `num_rows + 1` entries.
 */
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
    const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    IdType* const out_deg) {
  // Compute the global index in 64-bit so very large launches cannot
  // overflow a 32-bit int.
  const int64_t tIdx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (tIdx < num_rows) {
    // Keep the row id at full width: truncating a 64-bit IdType to `int`
    // would index the wrong entries of `in_ptr` for ids >= 2^31.
    const int64_t in_row = in_rows[tIdx];
    const int64_t out_row = tIdx;

    out_deg[out_row] = min(
        static_cast<IdType>(num_picks), in_ptr[in_row + 1] - in_ptr[in_row]);

    if (out_row == num_rows - 1) {
      // make the prefixsum work
      out_deg[num_rows] = 0;
    }
  }
}

/**
 * @brief Compute the size of each row in the sampled CSR, with replacement.
 *
 * Expects a 1-D launch with at least `num_rows` threads in total; thread `i`
 * handles output row `i`. A non-empty input row yields exactly `num_picks`
 * samples and an empty row yields none. The thread owning the last row also
 * writes a trailing zero to `out_deg[num_rows]`, so an exclusive prefix sum
 * over `num_rows + 1` entries ends with the total number of sampled entries.
 *
 * @tparam IdType The type of node and edge indexes.
 * @param num_picks The number of non-zero entries to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The index where each row's edges start.
 * @param out_deg The size of each row in the sampled matrix, as indexed by
 * `in_rows` (output). Must hold `num_rows + 1` entries.
 */
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
    const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    IdType* const out_deg) {
  // Compute the global index in 64-bit so very large launches cannot
  // overflow a 32-bit int (and slip past the bounds check as a negative).
  const int64_t tIdx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (tIdx < num_rows) {
    const int64_t in_row = in_rows[tIdx];
    const int64_t out_row = tIdx;

    if (in_ptr[in_row + 1] - in_ptr[in_row] == 0) {
      // Nothing to draw from in an empty row.
      out_deg[out_row] = 0;
    } else {
      out_deg[out_row] = static_cast<IdType>(num_picks);
    }

    if (out_row == num_rows - 1) {
      // make the prefixsum work
      out_deg[num_rows] = 0;
    }
  }
}

/**
 * @brief Perform row-wise uniform sampling on a CSR matrix,
 * and generate a COO matrix, without replacement.
 *
 * Launch layout: a 1-D grid of blocks with exactly BLOCK_SIZE threads each
 * (asserted). Block `b` processes output rows
 * `[b * TILE_SIZE, min((b + 1) * TILE_SIZE, num_rows))` one at a time, and
 * all threads of the block cooperate on each row.
 *
 * Rows with at most `num_picks` non-zeros are copied whole. Otherwise the
 * block runs a parallel form of reservoir sampling:
 * - The first `num_picks` output slots are seeded with the local offsets
 *   `0..num_picks-1`.
 * - Every later offset `idx` draws a slot in `[0, idx]`.
 * - Each slot keeps the largest offset that drew it (via AtomicMax), which is
 *   the replacement order the serial algorithm would produce.
 *
 * `out_idxs` serves as scratch space for these offsets before it is
 * overwritten with the final data values.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The indptr array of the input CSR.
 * @param in_index The indices array of the input CSR.
 * @param data The data array of the input CSR.
 * @param out_ptr The offset to write each row to in the output COO.
 * @param out_rows The rows of the output COO (output).
 * @param out_cols The columns of the output COO (output).
 * @param out_idxs The data array of the output COO (output).
 */
template <typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformKernel(
    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    const IdType* const in_index, const IdType* const data,
    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,
    IdType* const out_idxs) {
  // One thread block handles up to TILE_SIZE rows, and all of its threads
  // work on the same row at the same time (this is not one warp per row).
  assert(blockDim.x == BLOCK_SIZE);

  // Widen before multiplying, matching the computation of `last_row`.
  int64_t out_row = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  while (out_row < last_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;
    const int64_t out_row_start = out_ptr[out_row];

    // `deg` and `num_picks` are identical for every thread of the block, so
    // this branch is block-uniform and every thread reaches the
    // __syncthreads() calls inside it.
    if (deg <= num_picks) {
      // just copy row when there is not enough nodes to sample.
      for (int idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const IdType in_idx = in_row_start + idx;
        out_rows[out_row_start + idx] = row;
        out_cols[out_row_start + idx] = in_index[in_idx];
        out_idxs[out_row_start + idx] = data ? data[in_idx] : in_idx;
      }
    } else {
      // Seed the reservoir with the first `num_picks` local edge offsets.
      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        out_idxs[out_row_start + idx] = idx;
      }
      __syncthreads();

      // Each remaining offset draws a slot in [0, idx]; draws landing in the
      // reservoir compete for that slot.
      for (int idx = num_picks + threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const int num = curand(&rng) % (idx + 1);
        if (num < num_picks) {
          // use max so as to achieve the replacement order the serial
          // algorithm would have
          AtomicMax(out_idxs + out_row_start + num, idx);
        }
      }
      __syncthreads();

      // Translate the chosen local offsets into output rows/cols/data.
      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        const IdType perm_idx = out_idxs[out_row_start + idx] + in_row_start;
        out_rows[out_row_start + idx] = row;
        out_cols[out_row_start + idx] = in_index[perm_idx];
        out_idxs[out_row_start + idx] = data ? data[perm_idx] : perm_idx;
      }
    }
    out_row += 1;
  }
}

/**
 * @brief Perform row-wise uniform sampling on a CSR matrix,
 * and generate a COO matrix, with replacement.
 *
 * Launch layout: a 1-D grid of blocks with exactly BLOCK_SIZE threads each
 * (asserted). Block `b` processes output rows
 * `[b * TILE_SIZE, min((b + 1) * TILE_SIZE, num_rows))` one at a time, and
 * all threads of the block cooperate on each row. For a non-empty row, each
 * of the `num_picks` output slots independently draws one of the row's
 * non-zeros. Empty rows produce no output.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The indptr array of the input CSR.
 * @param in_index The indices array of the input CSR.
 * @param data The data array of the input CSR.
 * @param out_ptr The offset to write each row to in the output COO.
 * @param out_rows The rows of the output COO (output).
 * @param out_cols The columns of the output COO (output).
 * @param out_idxs The data array of the output COO (output).
 */
template <typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformReplaceKernel(
    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    const IdType* const in_index, const IdType* const data,
    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,
    IdType* const out_idxs) {
  // One thread block handles up to TILE_SIZE rows, and all of its threads
  // work on the same row at the same time (this is not one warp per row).
  assert(blockDim.x == BLOCK_SIZE);

  // Widen before multiplying, matching the computation of `last_row`.
  int64_t out_row = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  while (out_row < last_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t out_row_start = out_ptr[out_row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;

    // Empty rows were given an output size of 0, so there is nothing to write.
    if (deg > 0) {
      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        // Draw one of the row's `deg` non-zeros for this output slot.
        const int64_t edge = curand(&rng) % deg;
        const int64_t out_idx = out_row_start + idx;
        out_rows[out_idx] = row;
        out_cols[out_idx] = in_index[in_row_start + edge];
        out_idxs[out_idx] =
            data ? data[in_row_start + edge] : in_row_start + edge;
      }
    }
    out_row += 1;
  }
}

}  // namespace

///////////////////////////// CSR sampling //////////////////////////

/**
 * @brief Uniformly sample up to `num_picks` non-zeros from each of the given
 * rows of a CSR matrix and return them as a COO matrix.
 *
 * Work is enqueued on the current CUDA stream. The output arrays are first
 * allocated for the worst case (`num_rows * num_picks` entries) and then
 * narrowed to the actual number of sampled entries.
 *
 * @tparam XPU The device type.
 * @tparam IdType The ID type used for matrices.
 * @param mat The input CSR matrix (its arrays may be in pinned host memory).
 * @param rows The rows to sample from; read directly by the kernels.
 * @param num_picks The number of non-zeros to pick per row.
 * @param replace Whether to sample with replacement.
 * @return A COO matrix with the shape of `mat` holding the sampled entries.
 * Its data array holds `mat.data` values, or the positions in `mat.indices`
 * when `mat` has no data array.
 */
template <DGLDeviceType XPU, typename IdType>
COOMatrix _CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {
  const auto& ctx = rows->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_rows = rows->shape[0];
  const IdType* const slice_rows = static_cast<const IdType*>(rows->data);

  // Allocate for the worst case; the arrays are narrowed to the real number
  // of sampled entries at the end.
  IdArray picked_row =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_col =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_idx =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);

  if (num_rows == 0) {
    // Nothing to sample. Returning here also avoids launching the kernels
    // below with an empty grid, which CUDA rejects as an invalid
    // configuration.
    return COOMatrix(
        mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
  }

  IdType* const out_rows = static_cast<IdType*>(picked_row->data);
  IdType* const out_cols = static_cast<IdType*>(picked_col->data);
  IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);

  // For a pinned host matrix, the kernels need the device pointers that map
  // to the pinned buffers.
  const IdType* in_ptr = mat.indptr.Ptr<IdType>();
  const IdType* in_cols = mat.indices.Ptr<IdType>();
  const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
  if (mat.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(&in_ptr, mat.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(&in_cols, mat.indices.Ptr<IdType>(), 0));
    if (CSRHasData(mat)) {
      CUDA_CALL(cudaHostGetDevicePointer(&data, mat.data.Ptr<IdType>(), 0));
    }
  }

  // compute degree: one entry per sampled row plus a trailing zero, so the
  // exclusive scan below also yields the total number of sampled entries.
  IdType* out_deg = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  if (replace) {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,
        num_rows, slice_rows, in_ptr, out_deg);
  } else {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,
        num_rows, slice_rows, in_ptr, out_deg);
  }

  // fill out_ptr: out_ptr[i] is where output row i starts, and
  // out_ptr[num_rows] is the total number of sampled entries.
  IdType* out_ptr = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  size_t prefix_temp_size = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
  device->FreeWorkspace(ctx, prefix_temp);
  device->FreeWorkspace(ctx, out_deg);

  // The event is only used to wait for the copy of `new_len`, so timing is
  // disabled to keep it lightweight.
  cudaEvent_t copyEvent;
  CUDA_CALL(cudaEventCreateWithFlags(&copyEvent, cudaEventDisableTiming));

  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and
  // wait on a cudaevent
  IdType new_len;
  // copy using the internal current stream
  device->CopyDataFromTo(
      out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,
      DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
  CUDA_CALL(cudaEventRecord(copyEvent, stream));

  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

  // select edges
  // the number of rows each thread block will cover
  constexpr int TILE_SIZE = 128 / BLOCK_SIZE;
  if (replace) {  // with replacement
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleUniformReplaceKernel<IdType, TILE_SIZE>), grid, block,
        0, stream, random_seed, num_picks, num_rows, slice_rows, in_ptr,
        in_cols, data, out_ptr, out_rows, out_cols, out_idxs);
  } else {  // without replacement
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleUniformKernel<IdType, TILE_SIZE>), grid, block, 0,
        stream, random_seed, num_picks, num_rows, slice_rows, in_ptr, in_cols,
        data, out_ptr, out_rows, out_cols, out_idxs);
  }
  device->FreeWorkspace(ctx, out_ptr);

  // wait for copying `new_len` to finish
  CUDA_CALL(cudaEventSynchronize(copyEvent));
  CUDA_CALL(cudaEventDestroy(copyEvent));

  picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);

  return COOMatrix(
      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}

template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {
  // Any pick count other than -1 is handled by the sampling kernels.
  if (num_picks != -1) {
    return _CSRRowWiseSamplingUniform<XPU, IdType>(
        mat, rows, num_picks, replace);
  }
  // num_picks == -1 means "take every non-zero" (basically
  // UnitGraph::InEdges()): slice the requested rows, convert them to COO, and
  // map the sliced row positions back to the ids listed in `rows`.
  COOMatrix sliced = CSRToCOO(CSRSliceRows(mat, rows), false);
  IdArray original_rows = IndexSelect(rows, sliced.row);
  return COOMatrix(
      mat.num_rows, mat.num_cols, original_rows, sliced.col, sliced.data);
}

// Explicit instantiations for the CUDA device with 32- and 64-bit ids.
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(
    CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace);
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
    CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace);

}  // namespace impl
}  // namespace aten
}  // namespace dgl