"vscode:/vscode.git/clone" did not exist on "995c1913aa29b9f21dd6bd48cd65445c153d9c45"
rowwise_sampling.cu 12.9 KB
Newer Older
/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cuda/rowwise_sampling.cu
 * @brief uniform rowwise sampling
 */

#include <curand_kernel.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>

#include <numeric>

#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {
using namespace cuda;
using namespace aten::cuda;
namespace aten {
namespace impl {

namespace {

constexpr int BLOCK_SIZE = 128;

/**
 * @brief Compute the size of each row in the sampled CSR, without replacement.
 *
 * @tparam IdType The type of node and edge indexes.
 * @param num_picks The number of non-zero entries to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The index where each row's edges start.
 * @param out_deg The size of each row in the sampled matrix, as indexed by
 * `in_rows` (output).
 */
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
    const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    IdType* const out_deg) {
  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;

  if (tIdx < num_rows) {
    const int64_t in_row = in_rows[tIdx];
    const int64_t out_row = tIdx;
    out_deg[out_row] = min(
        static_cast<IdType>(num_picks), in_ptr[in_row + 1] - in_ptr[in_row]);

    if (out_row == num_rows - 1) {
      // zero the extra (num_rows-th) entry so the exclusive prefix-sum below
      // runs over fully initialized data
      out_deg[num_rows] = 0;
    }
  }
}

/**
 * @brief Compute the size of each row in the sampled CSR, with replacement.
 *
 * @tparam IdType The type of node and edge indexes.
 * @param num_picks The number of non-zero entries to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The index where each row's edges start.
 * @param out_deg The size of each row in the sampled matrix, as indexed by
 * `in_rows` (output).
 */
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
    const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    IdType* const out_deg) {
  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;

  if (tIdx < num_rows) {
    const int64_t in_row = in_rows[tIdx];
    const int64_t out_row = tIdx;

    if (in_ptr[in_row + 1] - in_ptr[in_row] == 0) {
      out_deg[out_row] = 0;
    } else {
      out_deg[out_row] = static_cast<IdType>(num_picks);
    }

    if (out_row == num_rows - 1) {
      // zero the extra (num_rows-th) entry so the exclusive prefix-sum below
      // runs over fully initialized data
      out_deg[num_rows] = 0;
    }
  }
}

/**
 * @brief Perform row-wise uniform sampling on a CSR matrix,
 * and generate a COO matrix, without replacement.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The indptr array of the input CSR.
 * @param in_index The indices array of the input CSR.
 * @param data The data array of the input CSR.
 * @param out_ptr The offset to write each row to in the output COO.
 * @param out_rows The rows of the output COO (output).
 * @param out_cols The columns of the output COO (output).
 * @param out_idxs The data array of the output COO (output).
 */
template <typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformKernel(
    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    const IdType* const in_index, const IdType* const data,
    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,
    IdType* const out_idxs) {
  // we assign one block of BLOCK_SIZE threads to each group of TILE_SIZE rows
  assert(blockDim.x == BLOCK_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
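  // Each block derives its seed from rand_seed and blockIdx.x, and each
  // thread selects its own Philox subsequence via threadIdx.x, so threads
  // draw independent random streams within a launch.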

  while (out_row < last_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;
    const int64_t out_row_start = out_ptr[out_row];

    if (deg <= num_picks) {
      // just copy the row when there are not enough nodes to sample.
      for (int idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const IdType in_idx = in_row_start + idx;
        out_rows[out_row_start + idx] = row;
        out_cols[out_row_start + idx] = in_index[in_idx];
        out_idxs[out_row_start + idx] = data ? data[in_idx] : in_idx;
      }
    } else {
      // generate permutation list via reservoir algorithm
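      // The first num_picks slots are seeded with 0..num_picks-1; every later
      // edge index idx then claims a uniformly random slot in [0, idx] and
      // only lands if that slot is below num_picks. AtomicMax resolves
      // concurrent claims in favor of the larger index.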
      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        out_idxs[out_row_start + idx] = idx;
      }
      __syncthreads();

      for (int idx = num_picks + threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const int num = curand(&rng) % (idx + 1);
        if (num < num_picks) {
          // use max so as to achieve the replacement order the serial
          // algorithm would have
          AtomicMax(out_idxs + out_row_start + num, idx);
        }
      }
      __syncthreads();

      // copy permutation over
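      // out_idxs currently holds positions within the row; translate them back
      // to global edge locations and overwrite out_idxs with the edge data
      // (or the edge index itself when the CSR has no data array).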
      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        const IdType perm_idx = out_idxs[out_row_start + idx] + in_row_start;
        out_rows[out_row_start + idx] = row;
        out_cols[out_row_start + idx] = in_index[perm_idx];
        out_idxs[out_row_start + idx] = data ? data[perm_idx] : perm_idx;
      }
    }
    out_row += 1;
  }
}

/**
 * @brief Perform row-wise uniform sampling on a CSR matrix,
 * and generate a COO matrix, with replacement.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 * @param rand_seed The random seed to use.
 * @param num_picks The number of non-zeros to pick per row.
 * @param num_rows The number of rows to pick.
 * @param in_rows The set of rows to pick.
 * @param in_ptr The indptr array of the input CSR.
 * @param in_index The indices array of the input CSR.
 * @param data The data array of the input CSR.
 * @param out_ptr The offset to write each row to in the output COO.
 * @param out_rows The rows of the output COO (output).
 * @param out_cols The columns of the output COO (output).
 * @param out_idxs The data array of the output COO (output).
 */
template <typename IdType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleUniformReplaceKernel(
    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
    const IdType* const in_rows, const IdType* const in_ptr,
    const IdType* const in_index, const IdType* const data,
    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,
    IdType* const out_idxs) {
  // we assign one block of BLOCK_SIZE threads to each group of TILE_SIZE rows
  assert(blockDim.x == BLOCK_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  while (out_row < last_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t out_row_start = out_ptr[out_row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;

    if (deg > 0) {
      // with replacement, each of the num_picks output slots independently
      // draws one of the row's deg edges
      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        const int64_t edge = curand(&rng) % deg;
        const int64_t out_idx = out_row_start + idx;
        out_rows[out_idx] = row;
        out_cols[out_idx] = in_index[in_row_start + edge];
        out_idxs[out_idx] =
            data ? data[in_row_start + edge] : in_row_start + edge;
      }
    }
    out_row += 1;
  }
}

}  // namespace

///////////////////////////// CSR sampling //////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix _CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {
  const auto& ctx = rows->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_rows = rows->shape[0];
  const IdType* const slice_rows = static_cast<const IdType*>(rows->data);

  IdArray picked_row =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_col =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_idx =
      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
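  // The outputs are over-allocated to num_rows * num_picks entries here and
  // trimmed to the actual number of sampled edges via CreateView at the end.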
  IdType* const out_rows = static_cast<IdType*>(picked_row->data);
  IdType* const out_cols = static_cast<IdType*>(picked_col->data);
  IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);

  const IdType* in_ptr = static_cast<IdType*>(GetDevicePointer(mat.indptr));
  const IdType* in_cols = static_cast<IdType*>(GetDevicePointer(mat.indices));
  const IdType* data = CSRHasData(mat)
                           ? static_cast<IdType*>(GetDevicePointer(mat.data))
                           : nullptr;
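  // A null data pointer means the CSR carries no explicit edge-id array, so
  // the kernels fall back to using the position within `indices` as the edge
  // id.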

  // compute degree
  IdType* out_deg = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  if (replace) {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,
        num_rows, slice_rows, in_ptr, out_deg);
  } else {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,
        num_rows, slice_rows, in_ptr, out_deg);
  }

  // fill out_ptr
  IdType* out_ptr = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  size_t prefix_temp_size = 0;
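  // CUB's DeviceScan is called twice: the first call (with a null workspace)
  // only queries the required temporary-storage size, the second performs the
  // exclusive prefix-sum that turns per-row degrees into output row offsets.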
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
  device->FreeWorkspace(ctx, prefix_temp);
  device->FreeWorkspace(ctx, out_deg);

  cudaEvent_t copyEvent;
  CUDA_CALL(cudaEventCreate(&copyEvent));
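  // The event marks completion of the device-to-host copy of the total edge
  // count below, so the host can defer waiting on it until after the sampling
  // kernels have been queued on the stream.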

  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and
  // wait on a cudaevent
  IdType new_len;
  // copy using the internal current stream
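  // out_ptr[num_rows] holds the exclusive-sum total, i.e. the number of edges
  // actually sampled; it is copied to the host so the over-allocated output
  // arrays can be shrunk to that length below.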
  device->CopyDataFromTo(
      out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,
      DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
  CUDA_CALL(cudaEventRecord(copyEvent, stream));

  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
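  // A single host-side seed is drawn per call; the kernels mix in blockIdx
  // and threadIdx so different blocks and threads use distinct random streams.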

  // select edges
  // the number of rows each thread block will cover
  constexpr int TILE_SIZE = 128 / BLOCK_SIZE;
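  // with BLOCK_SIZE == 128 this evaluates to a single row per thread block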
  if (replace) {  // with replacement
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleUniformReplaceKernel<IdType, TILE_SIZE>), grid, block,
        0, stream, random_seed, num_picks, num_rows, slice_rows, in_ptr,
        in_cols, data, out_ptr, out_rows, out_cols, out_idxs);
  } else {  // without replacement
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleUniformKernel<IdType, TILE_SIZE>), grid, block, 0,
        stream, random_seed, num_picks, num_rows, slice_rows, in_ptr, in_cols,
        data, out_ptr, out_rows, out_cols, out_idxs);
  }
  device->FreeWorkspace(ctx, out_ptr);

  // wait for copying `new_len` to finish
  CUDA_CALL(cudaEventSynchronize(copyEvent));
  CUDA_CALL(cudaEventDestroy(copyEvent));

  picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);

  return COOMatrix(
      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}
}

template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {
  if (num_picks == -1) {
    // Basically this is UnitGraph::InEdges().
    COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);
    IdArray sliced_rows = IndexSelect(rows, coo.row);
    return COOMatrix(
        mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);
  } else {
    return _CSRRowWiseSamplingUniform<XPU, IdType>(
        mat, rows, num_picks, replace);
  }
}
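// A minimal usage sketch (hypothetical caller code, not part of this file):
// given a GPU-resident CSRMatrix `csr` and an IdArray `seed_rows` on the same
// device, sample up to 5 neighbors per seed row without replacement:
//
//   COOMatrix sub = CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
//       csr, seed_rows, /*num_picks=*/5, /*replace=*/false);
//
// Passing num_picks == -1 instead returns every edge of the selected rows.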

template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(
    CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
    CSRMatrix, IdArray, int64_t, bool);

}  // namespace impl
}  // namespace aten
}  // namespace dgl