/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/rowwise_sampling.cc
 * \brief rowwise sampling
 */
#include <dgl/random.h>
#include <numeric>
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {
// Gathers a window of `array` into a freshly allocated array.
//
// Equivalent to the numpy expression `array[idx[off:off + len]]` when
// `idx_data` is non-null, and to `array[off:off + len]` when it is null.
// The result has `len` elements and the same dtype and context as `array`.
// `FloatType` is only the element type; it is also instantiated with integer
// types when the weights are integer masks.
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
                              IdxType off, IdxType len) {
  const FloatType* array_data = static_cast<FloatType*>(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* ret_data = static_cast<FloatType*>(ret->data);
  // The null check on `idx_data` does not depend on `j`, so branch once
  // instead of once per element.
  if (idx_data) {
    for (int64_t j = 0; j < len; ++j) {
      ret_data[j] = array_data[idx_data[off + j]];
    }
  } else {
    for (int64_t j = 0; j < len; ++j) {
      ret_data[j] = array_data[off + j];
    }
  }
  return ret;
}

// Returns a NumPicksFn that decides how many entries to draw from one row
// when sampling with per-edge weights (`prob_or_mask`).
//
// For the row's entries [off, off + len), an entry is eligible when its
// weight is > 0. The weight is looked up by edge id `data[i]` when `data` is
// non-null and by position `i` otherwise. The cap is `num_samples`, or the
// row length `len` when `num_samples == -1`.
//  - replace == true:  the cap if any entry is eligible, otherwise 0.
//  - replace == false: min(cap, number of eligible entries).
// `rowid` and `col` are part of the NumPicksFn signature but are not read.
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn<IdxType> num_picks_fn =
      [prob_or_mask, num_samples, replace](
          IdxType rowid, IdxType off, IdxType len, const IdxType* col,
          const IdxType* data) {
        const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
        const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
        IdxType nnz = 0;
        for (IdxType i = off; i < off + len; ++i) {
          const IdxType eid = data ? data[i] : i;
          if (prob_or_mask_data[eid] > 0) {
            ++nnz;
          }
        }

        if (replace) {
          return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
        } else {
          return std::min(static_cast<IdxType>(max_num_picks), nnz);
        }
      };
  return num_picks_fn;
}

// Returns a PickFn that draws `num_picks` entries from one row, weighted by
// `prob_or_mask`.
//
// The row's weights are first gathered into a contiguous array of length
// `len` (indexed by `data[off + j]` when `data` is non-null, by `off + j`
// otherwise). RandomEngine::Choice then fills `out_idx` with `num_picks`
// positions, presumably in [0, len). Each position is shifted by `off` so the
// caller receives positions in the matrix's entry arrays.
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  PickFn<IdxType> pick_fn =
      [prob_or_mask, num_samples, replace](
          IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
          const IdxType* col, const IdxType* data, IdxType* out_idx) {
        NDArray prob_or_mask_selected =
            DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
        RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
            num_picks, prob_or_mask_selected, out_idx, replace);
        for (int64_t j = 0; j < num_picks; ++j) {
          out_idx[j] += off;
        }
      };
  return pick_fn;
}

template <typename IdxType, typename FloatType>
73
74
75
76
inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
    const std::vector<int64_t>& num_samples,
    const std::vector<FloatArray>& prob, bool replace) {
  EtypeRangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
77
    (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
78
    const std::vector<IdxType> &et_idx,
79
80
81
82
83
84
    const std::vector<IdxType> &et_eid,
    const IdxType* eid, IdxType* out_idx) {
      const FloatArray& p = prob[cur_et];
      const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
      FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
      FloatType* probs_data = probs.Ptr<FloatType>();
85
      for (int64_t j = 0; j < et_len; ++j) {
86
87
        const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
        probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
88
89
90
      }

      RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
91
          num_samples[cur_et], probs, out_idx, replace);
92
93
94
95
    };
  return pick_fn;
}

96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Returns a NumPicksFn for uniform sampling, where every entry of the row is
// eligible. The cap is `num_samples`, or the row length `len` when
// `num_samples == -1`.
//  - replace == true:  the cap for a non-empty row, 0 for an empty one.
//  - replace == false: min(cap, len).
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn<IdxType> num_picks_fn =
      [num_samples, replace](
          IdxType rowid, IdxType off, IdxType len, const IdxType* col,
          const IdxType* data) {
        const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
        if (replace) {
          return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
        } else {
          return std::min(static_cast<IdxType>(max_num_picks), len);
        }
      };
  return num_picks_fn;
}

// Returns a PickFn that draws `num_picks` entries of one row uniformly.
//
// RandomEngine::UniformChoice fills `out_idx` with positions, presumably in
// [0, len). Each position is then shifted by `off` so the caller receives
// positions in the matrix's entry arrays.
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn<IdxType> pick_fn =
      [num_samples, replace](
          IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
          const IdxType* col, const IdxType* data, IdxType* out_idx) {
        RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
            num_picks, len, out_idx, replace);
        for (int64_t j = 0; j < num_picks; ++j) {
          out_idx[j] += off;
        }
      };
  return pick_fn;
}

template <typename IdxType>
129
inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
130
    const std::vector<int64_t>& num_samples, bool replace) {
131
  EtypeRangePickFn<IdxType> pick_fn = [num_samples, replace]
132
    (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
133
    const std::vector<IdxType> &et_idx,
134
    const std::vector<IdxType> &et_eid,
135
136
    const IdxType* data, IdxType* out_idx) {
      RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
137
          num_samples[cur_et], et_len, out_idx, replace);
138
139
140
141
    };
  return pick_fn;
}

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// Returns a NumPicksFn for tag-biased sampling.
//
// `split` is indexed as a row-major 2-D array with `split->shape[1]` columns.
// Row `rowid` holds `split->shape[1]` offsets delimiting
// `split->shape[1] - 1` tag groups: group j has
// `tag_offset[j + 1] - tag_offset[j]` entries. `bias[j]` is the weight of
// tag j, and entries in tags with bias > 0 are eligible. The cap is
// `num_samples`, or the row length `len` when `num_samples == -1`.
//  - replace == true:  the cap if any entry is eligible, otherwise 0.
//  - replace == false: min(cap, number of eligible entries).
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn<IdxType> num_picks_fn =
      [num_samples, split, bias, replace](
          IdxType rowid, IdxType off, IdxType len, const IdxType* col,
          const IdxType* data) {
        const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
        const int64_t num_tags = split->shape[1] - 1;
        const IdxType* tag_offset =
            split.Ptr<IdxType>() + rowid * split->shape[1];
        const FloatType* bias_data = bias.Ptr<FloatType>();
        IdxType nnz = 0;
        for (int64_t j = 0; j < num_tags; ++j) {
          if (bias_data[j] > 0) {
            nnz += tag_offset[j + 1] - tag_offset[j];
          }
        }

        if (replace) {
          return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
        } else {
          return std::min(static_cast<IdxType>(max_num_picks), nnz);
        }
      };
  return num_picks_fn;
}

// Returns a PickFn for tag-biased sampling of one row.
//
// Row `rowid` of `split` (indexed row-major with `split->shape[1]` columns)
// supplies that row's tag offsets. RandomEngine::BiasedChoice is given these
// offsets and the per-tag `bias`, and fills `out_idx` with `num_picks`
// positions. Each position is then shifted by `off`. NOTE(review): this
// assumes BiasedChoice returns row-relative positions; confirm in
// dgl/random.h.
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn<IdxType> pick_fn =
      [num_samples, split, bias, replace](
          IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
          const IdxType* col, const IdxType* data, IdxType* out_idx) {
        const IdxType* tag_offset =
            split.Ptr<IdxType>() + rowid * split->shape[1];
        RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
            num_picks, tag_offset, bias, out_idx, replace);
        for (int64_t j = 0; j < num_picks; ++j) {
          out_idx[j] += off;
        }
      };
  return pick_fn;
}

}  // namespace

/////////////////////////////// CSR ///////////////////////////////

template <DGLDeviceType XPU, typename IdxType, typename DType>
190
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
191
192
193
194
195
196
197
198
199
                             NDArray prob_or_mask, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
200
201
}

202
// Explicit CPU instantiations: 32/64-bit index types crossed with
// float/double weights and int8_t/uint8_t weights (e.g. masks).
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);

// Per-edge-type variant of CSRRowWiseSampling.
//
// For edge type `et`, the picker draws `num_samples[et]` edges weighted by
// `prob_or_mask[et]`. A null `prob_or_mask[et]` gives every edge of that type
// weight 1. One defined weight array is required per edge type. Handling of
// `num_samples[et] == -1` and of `replace` across types is delegated to
// CSRRowWisePerEtypePick.
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  CHECK(prob_or_mask.size() == num_samples.size()) <<
    "the number of probability tensors does not match the number of edge types.";
  for (const auto& p : prob_or_mask)
    CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return CSRRowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, prob_or_mask);
}

// Explicit CPU instantiations: 32/64-bit index types crossed with
// float/double weights and int8_t/uint8_t weights (e.g. masks).
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool, bool);

template <DGLDeviceType XPU, typename IdxType>
260
261
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
262
263
264
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
265
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
266
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
267
268
}

269
// Explicit CPU instantiations for 32/64-bit index types.
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
    CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
    CSRMatrix, IdArray, int64_t, bool);

template <DGLDeviceType XPU, typename IdxType>
275
276
277
COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted) {
278
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
279
280
  return CSRRowWisePerEtypePick<IdxType, float>(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted, pick_fn, {});
281
282
}

283
// Explicit CPU instantiations for 32/64-bit index types.
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool,
    bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool,
    bool);

template <DGLDeviceType XPU, typename IdxType, typename FloatType>
291
292
293
294
295
296
297
298
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace
) {
299
300
301
302
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
303
304
  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
305
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
306
307
}

308
// Explicit CPU instantiations: 32/64-bit index types crossed with
// float/double bias types.
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

/////////////////////////////// COO ///////////////////////////////

template <DGLDeviceType XPU, typename IdxType, typename DType>
324
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
325
326
327
328
329
330
331
332
333
                             NDArray prob_or_mask, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
334
335
}

336
// Explicit CPU instantiations: 32/64-bit index types crossed with
// float/double weights and int8_t/uint8_t weights (e.g. masks).
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);

// Per-edge-type variant of COORowWiseSampling.
//
// For edge type `et`, the picker draws `num_samples[et]` edges weighted by
// `prob_or_mask[et]`. A null `prob_or_mask[et]` gives every edge of that type
// weight 1. One defined weight array is required per edge type. Handling of
// `num_samples[et] == -1` is delegated to COORowWisePerEtypePick.
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  CHECK(prob_or_mask.size() == num_samples.size()) <<
    "the number of probability tensors do not match the number of edge types.";
  for (const auto& p : prob_or_mask)
    CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return COORowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
}

// Explicit CPU instantiations: 32/64-bit index types crossed with
// float/double weights and int8_t/uint8_t weights (e.g. masks).
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
    const std::vector<NDArray>&, bool);

template <DGLDeviceType XPU, typename IdxType>
393
394
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
395
396
397
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
398
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
399
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
400
401
}

402
// Explicit CPU instantiations for 32/64-bit index types.
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
    COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
    COOMatrix, IdArray, int64_t, bool);

template <DGLDeviceType XPU, typename IdxType>
408
409
410
COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace) {
411
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
412
413
  return COORowWisePerEtypePick<IdxType, float>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});
414
415
}

416
// Explicit CPU instantiations for 32/64-bit index types.
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool);

}  // namespace impl
}  // namespace aten
}  // namespace dgl