/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_pick.h
 * @brief Template implementation for rowwise pick operators.
 */
#ifndef DGL_ARRAY_CPU_ROWWISE_PICK_H_
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

// Callback type that selects entries from a single row of a CSR matrix.
//
// The row occupies the half-open range [off, off + len) of both the column
// index array `col` and the data index array `data`. `data` may be nullptr,
// in which case the data index of position p is p itself.
//
// The callback writes `num_picks` values into `out_idx`. Each value is an
// absolute position in [off, off + len), i.e. an index into `col` / `data`,
// not an offset relative to `off`.
//
// *ATTENTION*: callers invoke this from several threads at once, so every
// implementation must be thread-safe.
//
// @param rowid     Id of the row being picked from.
// @param off       Position of the row's first entry.
// @param len       Number of non-zero entries in the row.
// @param num_picks Number of entries to write into out_idx.
// @param col       Column index array of the whole matrix.
// @param data      Data index array of the whole matrix, or nullptr.
// @param out_idx   Output buffer with room for num_picks positions.
template <typename IdxType>
using PickFn = std::function<void(
    IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
    const IdxType* col, const IdxType* data, IdxType* out_idx)>;

// Callback type that decides how many entries to pick from a single row of a
// CSR matrix.
//
// The row occupies the half-open range [off, off + len) of both the column
// index array `col` and the data index array `data`. `data` may be nullptr,
// in which case the data index of position p is p itself.
//
// The returned count is the amount of output space reserved for the row. It
// is handed to the matching PickFn as its `num_picks` argument.
//
// *ATTENTION*: callers invoke this from several threads at once, so every
// implementation must be thread-safe.
//
// @param rowid Id of the row being examined.
// @param off   Position of the row's first entry.
// @param len   Number of non-zero entries in the row.
// @param col   Column index array of the whole matrix.
// @param data  Data index array of the whole matrix, or nullptr.
// @return Number of entries to pick from the row.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
    IdxType rowid, IdxType off, IdxType len, const IdxType* col,
    const IdxType* data)>;

// Callback type that picks entries from one edge-type segment of a CSR row.
//
// For a row whose first entry is at position `off`, `et_idx` is a permutation
// of the row-local positions [0, len), arranged so that entries of the same
// edge type are contiguous. One call handles the segment
//   et_idx[et_offset + i], for i in [0, et_len),
// all of whose entries have edge type `cur_et`. For the i-th entry of the
// segment, with p = et_idx[et_offset + i]:
//   - its homogenized edge id is eid[off + p], or off + p when eid is nullptr;
//   - its edge-type-specific id is et_eid[p].
// The caller reads the column index as indices[off + p].
//
// The callback fills `out_idx`, whose length is the fanout of `cur_et`, with
// SEGMENT-RELATIVE positions i in [0, et_len). The caller pre-fills out_idx
// with -1 and skips slots that are still -1 afterwards.
//
// *ATTENTION*: callers invoke this from several threads at once, so every
// implementation must be thread-safe.
//
// @param off       Position of the row's first entry.
// @param et_offset Start of this segment within et_idx.
// @param cur_et    Edge type shared by all entries of the segment.
// @param et_len    Number of entries in the segment.
// @param et_idx    Row-local positions grouped by edge type.
// @param et_eid    Edge-type-specific id of each row-local position.
// @param eid       Homogenized edge id array of the matrix, or nullptr.
// @param out_idx   Output buffer of picked segment-relative positions.
template <typename IdxType>
using EtypeRangePickFn = std::function<void(
    IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
    const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,
    const IdxType* eid, IdxType* out_idx)>;

97
// Template for picking non-zero values row-wise. The implementation utilizes
98
99
// OpenMP parallelization on rows because each row performs computation
// independently.
100
template <typename IdxType>
101
102
103
COOMatrix CSRRowWisePick(
    CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
104
105
106
  using namespace aten;
  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
107
108
  const IdxType* data =
      CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
109
110
111
  const IdxType* rows_data = static_cast<IdxType*>(rows->data);
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
112
  const auto& idtype = mat.indptr->dtype;
113
114
115
116
117
118
119
120
121

  // To leverage OMP parallelization, we create two arrays to store
  // picked src and dst indices. Each array is of length num_rows * num_picks.
  // For rows whose nnz < num_picks, the indices are padded with -1.
  //
  // We check whether all the given rows
  // have at least num_picks number of nnz when replace is false.
  //
  // If the check holds, remove -1 elements by remove_if operation, which simply
122
123
124
  // moves valid elements to the head of arrays and create a view of the
  // original array. The implementation consumes a little extra memory than the
  // actual requirement.
125
  //
126
127
  // Otherwise, directly use the row and col arrays to construct the result COO
  // matrix.
128
  //
129
130
  // [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism
  // is more
131
  //   significant. (minjie)
132

133
134
  // Do not use omp_get_max_threads() since that doesn't work for compiling
  // without OpenMP.
135
136
137
  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
  std::vector<int64_t> global_prefix(num_threads + 1, 0);

138
139
140
141
  // TODO(BarclayII) Using OMP parallel directly instead of using
  // runtime::parallel_for does not handle exceptions well (directly aborts when
  // an exception pops up). It runs faster though because there is less
  // scheduling.  Need to handle exceptions better.
142
  IdArray picked_row, picked_col, picked_idx;
143
144
145
146
#pragma omp parallel num_threads(num_threads)
  {
    const int thread_id = omp_get_thread_num();

147
148
    const int64_t start_i =
        thread_id * (num_rows / num_threads) +
149
        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
150
151
    const int64_t end_i =
        (thread_id + 1) * (num_rows / num_threads) +
152
153
154
155
        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
    assert(thread_id + 1 < num_threads || end_i == num_rows);

    const int64_t num_local = end_i - start_i;
156

157
158
159
160
161
    // make sure we don't have to pay initialization cost
    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
    local_prefix[0] = 0;
    for (int64_t i = start_i; i < end_i; ++i) {
      // build prefix-sum
162
      const int64_t local_i = i - start_i;
163
      const IdxType rid = rows_data[i];
164
165
      IdxType len = num_picks_fn(
          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
166
167
168
169
      local_prefix[local_i + 1] = local_prefix[local_i] + len;
    }
    global_prefix[thread_id + 1] = local_prefix[num_local];

170
171
#pragma omp barrier
#pragma omp master
172
173
    {
      for (int t = 0; t < num_threads; ++t) {
174
        global_prefix[t + 1] += global_prefix[t];
175
      }
176
177
178
      picked_row = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
      picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
      picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
179
    }
180

181
#pragma omp barrier
182
183
184
185
    IdxType* picked_rdata = picked_row.Ptr<IdxType>();
    IdxType* picked_cdata = picked_col.Ptr<IdxType>();
    IdxType* picked_idata = picked_idx.Ptr<IdxType>();

186
    const IdxType thread_offset = global_prefix[thread_id];
187

188
189
190
191
192
    for (int64_t i = start_i; i < end_i; ++i) {
      const IdxType rid = rows_data[i];

      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;
193
      if (len == 0) continue;
194
195
196

      const int64_t local_i = i - start_i;
      const int64_t row_offset = thread_offset + local_prefix[local_i];
197
198
      const int64_t num_picks =
          thread_offset + local_prefix[local_i + 1] - row_offset;
199

200
201
      pick_fn(
          rid, off, len, num_picks, indices, data, picked_idata + row_offset);
202
203
204
205
206
      for (int64_t j = 0; j < num_picks; ++j) {
        const IdxType picked = picked_idata[row_offset + j];
        picked_rdata[row_offset + j] = rid;
        picked_cdata[row_offset + j] = indices[picked];
        picked_idata[row_offset + j] = data ? data[picked] : picked;
207
208
      }
    }
209
210
  }

211
212
  const int64_t new_len = global_prefix.back();

213
  return COOMatrix(
214
      mat.num_rows, mat.num_cols,
215
216
217
      picked_row.CreateView({new_len}, picked_row->dtype),
      picked_col.CreateView({new_len}, picked_row->dtype),
      picked_idx.CreateView({new_len}, picked_row->dtype));
218
219
}

220
// Template for picking non-zero values row-wise. The implementation utilizes
221
222
// OpenMP parallelization on rows because each row performs computation
// independently.
223
template <typename IdxType, typename DType>
224
225
226
227
228
COOMatrix CSRRowWisePerEtypePick(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_picks, bool replace,
    bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
    const std::vector<NDArray>& prob_or_mask) {
229
  using namespace aten;
230
231
  const IdxType* indptr = mat.indptr.Ptr<IdxType>();
  const IdxType* indices = mat.indices.Ptr<IdxType>();
232
  const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
233
  const IdxType* rows_data = rows.Ptr<IdxType>();
234
235
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
236
  const int64_t num_etypes = num_picks.size();
237
  const bool has_probs = (prob_or_mask.size() > 0);
238
239
240
241
  std::vector<IdArray> picked_rows(rows->shape[0]);
  std::vector<IdArray> picked_cols(rows->shape[0]);
  std::vector<IdArray> picked_idxs(rows->shape[0]);

242
  // Check if the number of picks have the same value.
243
244
  // If so, we can potentially speed up if we have a node with total number of
  // neighbors less than the given number of picks with replace=False.
245
246
247
248
249
250
  bool same_num_pick = true;
  int64_t num_pick_value = num_picks[0];
  for (int64_t num_pick : num_picks) {
    if (num_pick_value != num_pick) {
      same_num_pick = false;
      break;
251
    }
252
  }
253

254
  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
255
    for (size_t i = b; i < e; ++i) {
256
257
258
259
260
261
262
263
264
265
266
      const IdxType rid = rows_data[i];
      CHECK_LT(rid, mat.num_rows);
      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;

      // do something here
      if (len == 0) {
        picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        continue;
267
      }
268
269
270
271
272
273
274
275

      // fast path
      if (same_num_pick && len <= num_pick_value && !replace) {
        IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);
        IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);
        IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
        IdxType* cdata = cols.Ptr<IdxType>();
        IdxType* idata = idx.Ptr<IdxType>();
276
277

        int64_t k = 0;
278
        for (int64_t j = 0; j < len; ++j) {
279
280
          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
          auto it = std::upper_bound(
281
282
              eid2etype_offset.begin(), eid2etype_offset.end(),
              homogenized_eid);
283
          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
284
          const IdxType heterogenized_eid =
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
              homogenized_eid - eid2etype_offset[heterogenized_etype];

          if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
            // No probability array, select all
            cdata[k] = indices[off + j];
            idata[k] = homogenized_eid;
            ++k;
          } else {
            // Select the entries with non-zero probability
            const NDArray& p = prob_or_mask[heterogenized_etype];
            const DType* pdata = p.Ptr<DType>();
            if (pdata[heterogenized_eid] > 0) {
              cdata[k] = indices[off + j];
              idata[k] = homogenized_eid;
              ++k;
            }
          }
302
        }
303
304
305
306

        picked_rows[i] = rows.CreateView({k}, rows->dtype);
        picked_cols[i] = cols.CreateView({k}, cols->dtype);
        picked_idxs[i] = idx.CreateView({k}, idx->dtype);
307
308
309
310
311
312
313
314
      } else {
        // need to do per edge type sample
        std::vector<IdxType> rows;
        std::vector<IdxType> cols;
        std::vector<IdxType> idx;

        std::vector<IdxType> et(len);
        std::vector<IdxType> et_idx(len);
315
        std::vector<IdxType> et_eid(len);
316
317
        std::iota(et_idx.begin(), et_idx.end(), 0);
        for (int64_t j = 0; j < len; ++j) {
318
319
          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
          auto it = std::upper_bound(
320
321
              eid2etype_offset.begin(), eid2etype_offset.end(),
              homogenized_eid);
322
          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
323
          const IdxType heterogenized_eid =
324
325
326
              homogenized_eid - eid2etype_offset[heterogenized_etype];
          et[j] = heterogenized_etype;
          et_eid[j] = heterogenized_eid;
327
        }
328
329
330
331
332
333
334
        if (!rowwise_etype_sorted)  // the edge type is sorted, not need to sort
                                    // it
          std::sort(
              et_idx.begin(), et_idx.end(),
              [&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });
        CHECK_LT(et[et_idx[len - 1]], num_etypes)
            << "etype values exceed the number of fanouts";
335
336
337
338
339

        IdxType cur_et = et[et_idx[0]];
        int64_t et_offset = 0;
        int64_t et_len = 1;
        for (int64_t j = 0; j < len; ++j) {
340
341
          CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))
              << "Edge type is not sorted. Please sort in advance or specify "
342
                 "'rowwise_etype_sorted' as false.";
343
          if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) {
344
345
346
            // 1 end of the current etype
            // 2 end of the row
            // random pick for current etype
347
348
            if ((num_picks[cur_et] == -1) ||
                (et_len <= num_picks[cur_et] && !replace)) {
349
350
              // fast path, select all
              for (int64_t k = 0; k < et_len; ++k) {
351
                const IdxType eid_offset = off + et_idx[et_offset + k];
352
353
                const IdxType homogenized_eid =
                    eid ? eid[eid_offset] : eid_offset;
354
                auto it = std::upper_bound(
355
356
357
358
359
                    eid2etype_offset.begin(), eid2etype_offset.end(),
                    homogenized_eid);
                const IdxType heterogenized_etype =
                    it - eid2etype_offset.begin() - 1;
                const IdxType heterogenized_eid =
360
361
                    homogenized_eid - eid2etype_offset[heterogenized_etype];

362
363
                if (!has_probs ||
                    IsNullArray(prob_or_mask[heterogenized_etype])) {
364
365
366
367
368
369
370
371
372
373
374
375
376
377
                  // No probability, select all
                  rows.push_back(rid);
                  cols.push_back(indices[eid_offset]);
                  idx.push_back(homogenized_eid);
                } else {
                  // Select the entries with non-zero probability
                  const NDArray& p = prob_or_mask[heterogenized_etype];
                  const DType* pdata = p.Ptr<DType>();
                  if (pdata[heterogenized_eid] > 0) {
                    rows.push_back(rid);
                    cols.push_back(indices[eid_offset]);
                    idx.push_back(homogenized_eid);
                  }
                }
378
379
              }
            } else {
380
381
              IdArray picked_idx =
                  Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
382
              IdxType* picked_idata = picked_idx.Ptr<IdxType>();
383
384

              // need call random pick
385
386
387
              pick_fn(
                  off, et_offset, cur_et, et_len, et_idx, et_eid, eid,
                  picked_idata);
388
389
              for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
                const IdxType picked = picked_idata[k];
390
                if (picked == -1) continue;
391
                rows.push_back(rid);
392
                cols.push_back(indices[off + et_idx[et_offset + picked]]);
393
                if (eid) {
394
                  idx.push_back(eid[off + et_idx[et_offset + picked]]);
395
                } else {
396
                  idx.push_back(off + et_idx[et_offset + picked]);
397
                }
398
              }
399
            }
400

401
            if (j + 1 == len) break;
402
            // next etype
403
404
            cur_et = et[et_idx[j + 1]];
            et_offset = j + 1;
405
            et_len = 1;
406
          } else {
407
            et_len++;
408
409
410
          }
        }

411
412
413
414
        picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
        picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
        picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
      }  // end processing one row
415

416
417
418
419
      CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
      CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
    }  // end processing all rows
  });
420
421
422
423

  IdArray picked_row = Concat(picked_rows);
  IdArray picked_col = Concat(picked_cols);
  IdArray picked_idx = Concat(picked_idxs);
424
425
  return COOMatrix(
      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
426
427
}

428
429
430
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
431
template <typename IdxType>
432
433
434
COOMatrix COORowWisePick(
    COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
435
436
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
437
438
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
439
440
  const auto& picked = CSRRowWisePick<IdxType>(
      csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
441
442
443
444
  return COOMatrix(
      mat.num_rows, mat.num_cols,
      IndexSelect(rows, picked.row),  // map the row index to the correct one
      picked.col, picked.data);
445
446
}

447
448
449
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
450
451
452
453
454
455
template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_picks, bool replace,
    EtypeRangePickFn<IdxType> pick_fn,
    const std::vector<NDArray>& prob_or_mask) {
456
457
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
458
459
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
460
  const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
461
462
463
464
465
466
      csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,
      prob_or_mask);
  return COOMatrix(
      mat.num_rows, mat.num_cols,
      IndexSelect(rows, picked.row),  // map the row index to the correct one
      picked.col, picked.data);
467
468
}

469
470
471
472
473
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_ROWWISE_PICK_H_