/*!
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_pick.h
 * @brief Template implementation for rowwise pick operators.
 */
#ifndef DGL_ARRAY_CPU_ROWWISE_PICK_H_
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

// User-defined function for picking elements from one row.
//
// The column indices of the given row are stored in
//   [col + off, col + off + len)
//
// Similarly, the data indices are stored in
//   [data + off, data + off + len)
// Data index pointer could be NULL, which means data[i] == i
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param num_picks Number of picks on the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
// \param out_idx Picked indices in [off, off + len). The buffer has room for
//        exactly num_picks entries and CSRRowWisePick reads back every one of
//        them, so all num_picks slots must be written.
template <typename IdxType>
using PickFn = std::function<void(
    IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
    const IdxType* col, const IdxType* data,
    IdxType* out_idx)>;

// User-defined function for determining the number of elements to pick from one row.
//
// The column indices of the given row are stored in
//   [col + off, col + off + len)
//
// Similarly, the data indices are stored in
//   [data + off, data + off + len)
// Data index pointer could be NULL, which means data[i] == i
//
// The returned count is used by CSRRowWisePick to size the output arrays and
// is forwarded as the num_picks argument of PickFn for the same row.
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
// \return Number of elements that will be picked from this row.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
    IdxType rowid, IdxType off, IdxType len,
    const IdxType* col, const IdxType* data)>;

// User-defined function for picking elements from one edge-type range of a row.
//
// The row's edges occupy the CSR positions [off, off + len). et_idx holds the
// local positions (0-based within the row) of those edges grouped by edge
// type; the range to pick from is
//   et_idx[et_offset], ..., et_idx[et_offset + et_len - 1]
// and every element in it has edge type cur_et. For the i-th element of the
// range (0 <= i < et_len) the CSR position is off + et_idx[et_offset + i]; its
// column index is indices[off + et_idx[et_offset + i]] and its homogenized
// edge id is eid[off + et_idx[et_offset + i]], or the CSR position itself when
// the eid pointer is NULL.
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param off Starting offset of this row.
// \param et_offset Starting offset of this range inside et_idx.
// \param cur_et The edge type shared by every element of the range.
// \param et_len Length of the range.
// \param et_idx Local positions of the row's edges, grouped by edge type.
// \param et_eid Edge-type-specific id of each edge of the row, indexed by
//        local position (i.e. et_eid[et_idx[et_offset + i]]).
// \param eid Pointer of the homogenized edge id array (may be NULL).
// \param out_idx Picked elements, written as range-relative indices i in
//        [0, et_len). CSRRowWisePerEtypePick passes a buffer of
//        num_picks[cur_et] entries pre-filled with -1; entries left at -1 are
//        ignored.
template <typename IdxType>
using EtypeRangePickFn = std::function<void(
    IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
    const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,
    const IdxType* eid, IdxType* out_idx)>;

// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
//
// The requested rows are split into one contiguous range per OpenMP thread.
//  1. Each thread asks num_picks_fn how many entries every one of its rows
//     produces and builds a local prefix sum.
//  2. The per-thread totals are combined into a global prefix sum, and the
//     exactly-sized output arrays are allocated once.
//  3. Each thread calls pick_fn for its rows and writes (row, col, data)
//     triples at the precomputed offsets, so no compaction step is needed.
//
// \param mat The CSR matrix to pick from.
// \param rows The rows to pick from.
// \param num_picks Not read by this function; the per-row pick count comes
//        from num_picks_fn.
// \param replace Not read by this function; presumably honoured inside
//        pick_fn / num_picks_fn.
// \param pick_fn Fills the picked CSR positions for one row.
// \param num_picks_fn Returns the number of picks for one row.
// \return COO matrix of the picked (row, col, data) triples.
template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
                         int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
                         NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;
  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
  const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr;
  const IdxType* rows_data = static_cast<IdxType*>(rows->data);
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
  const auto& idtype = mat.indptr->dtype;

  // Do not use omp_get_max_threads() since that doesn't work for compiling without OpenMP.
  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
  // global_prefix[t + 1] first receives the pick count of thread t and is then
  // turned into a prefix sum, so global_prefix[t] is thread t's output offset.
  std::vector<int64_t> global_prefix(num_threads + 1, 0);

  // TODO(BarclayII) Using OMP parallel directly instead of using runtime::parallel_for
  // does not handle exceptions well (directly aborts when an exception pops up).
  // It runs faster though because there is less scheduling.  Need to handle
  // exceptions better.
  IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads)
  {
    const int thread_id = omp_get_thread_num();
    // The num_threads clause is only an upper bound: the runtime may start a
    // smaller team (dynamic adjustment, thread limits, nesting). Partition the
    // rows over the team that actually exists so no row is skipped; the
    // global_prefix slots of threads that were never started stay 0.
    const int64_t team_size = omp_get_num_threads();

    const int64_t start_i = thread_id * (num_rows / team_size) +
        std::min(static_cast<int64_t>(thread_id), num_rows % team_size);
    const int64_t end_i = (thread_id + 1) * (num_rows / team_size) +
        std::min(static_cast<int64_t>(thread_id + 1), num_rows % team_size);
    assert(thread_id + 1 < team_size || end_i == num_rows);

    const int64_t num_local = end_i - start_i;

    // Pass 1: local prefix sum of pick counts. Allocated without value
    // initialization so we don't pay for zeroing it.
    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
    local_prefix[0] = 0;
    for (int64_t i = start_i; i < end_i; ++i) {
      const int64_t local_i = i - start_i;
      const IdxType rid = rows_data[i];
      const IdxType row_count = num_picks_fn(
          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
      local_prefix[local_i + 1] = local_prefix[local_i] + row_count;
    }
    global_prefix[thread_id + 1] = local_prefix[num_local];

    #pragma omp barrier
    #pragma omp master
    {
      for (int t = 0; t < num_threads; ++t) {
        global_prefix[t + 1] += global_prefix[t];
      }
      picked_row = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
      picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
      picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
    }
    // `omp master` has no implied barrier; wait for the allocation.
    #pragma omp barrier

    IdxType* picked_rdata = picked_row.Ptr<IdxType>();
    IdxType* picked_cdata = picked_col.Ptr<IdxType>();
    IdxType* picked_idata = picked_idx.Ptr<IdxType>();

    const int64_t thread_offset = global_prefix[thread_id];

    // Pass 2: pick and write the results at the precomputed offsets.
    for (int64_t i = start_i; i < end_i; ++i) {
      const IdxType rid = rows_data[i];

      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;
      // Empty rows are skipped; this assumes num_picks_fn returned 0 for them,
      // otherwise their output slots would be left unwritten.
      if (len == 0)
        continue;

      const int64_t local_i = i - start_i;
      const int64_t row_offset = thread_offset + local_prefix[local_i];
      const int64_t row_picks = local_prefix[local_i + 1] - local_prefix[local_i];

      // pick_fn writes CSR positions; translate them in place to column and
      // data (edge) ids.
      pick_fn(rid, off, len, row_picks, indices, data, picked_idata + row_offset);
      for (int64_t j = 0; j < row_picks; ++j) {
        const IdxType picked = picked_idata[row_offset + j];
        picked_rdata[row_offset + j] = rid;
        picked_cdata[row_offset + j] = indices[picked];
        picked_idata[row_offset + j] = data ? data[picked] : picked;
      }
    }
  }

  const int64_t new_len = global_prefix.back();

  return COOMatrix(
      mat.num_rows,
      mat.num_cols,
      picked_row.CreateView({new_len}, picked_row->dtype),
      picked_col.CreateView({new_len}, picked_col->dtype),
      picked_idx.CreateView({new_len}, picked_idx->dtype));
}

// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
212
213
214
template <typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
                                 const std::vector<int64_t>& eid2etype_offset,
215
                                 const std::vector<int64_t>& num_picks, bool replace,
216
217
                                 bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
                                 const std::vector<NDArray>& prob_or_mask) {
218
  using namespace aten;
219
220
  const IdxType* indptr = mat.indptr.Ptr<IdxType>();
  const IdxType* indices = mat.indices.Ptr<IdxType>();
221
  const IdxType* eid = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
222
  const IdxType* rows_data = rows.Ptr<IdxType>();
223
224
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
225
  const int64_t num_etypes = num_picks.size();
226
  const bool has_probs = (prob_or_mask.size() > 0);
227
228
229
230
  std::vector<IdArray> picked_rows(rows->shape[0]);
  std::vector<IdArray> picked_cols(rows->shape[0]);
  std::vector<IdArray> picked_idxs(rows->shape[0]);

231
232
233
234
235
236
237
238
239
  // Check if the number of picks have the same value.
  // If so, we can potentially speed up if we have a node with total number of neighbors
  // less than the given number of picks with replace=False.
  bool same_num_pick = true;
  int64_t num_pick_value = num_picks[0];
  for (int64_t num_pick : num_picks) {
    if (num_pick_value != num_pick) {
      same_num_pick = false;
      break;
240
    }
241
  }
242

243
  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
244
    for (size_t i = b; i < e; ++i) {
245
246
247
248
249
250
251
252
253
254
255
      const IdxType rid = rows_data[i];
      CHECK_LT(rid, mat.num_rows);
      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;

      // do something here
      if (len == 0) {
        picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        continue;
256
      }
257
258
259
260
261
262
263
264

      // fast path
      if (same_num_pick && len <= num_pick_value && !replace) {
        IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);
        IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);
        IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
        IdxType* cdata = cols.Ptr<IdxType>();
        IdxType* idata = idx.Ptr<IdxType>();
265
266

        int64_t k = 0;
267
        for (int64_t j = 0; j < len; ++j) {
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
          auto it = std::upper_bound(
              eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
          const IdxType heterogenized_eid = \
              homogenized_eid - eid2etype_offset[heterogenized_etype];

          if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
            // No probability array, select all
            cdata[k] = indices[off + j];
            idata[k] = homogenized_eid;
            ++k;
          } else {
            // Select the entries with non-zero probability
            const NDArray& p = prob_or_mask[heterogenized_etype];
            const DType* pdata = p.Ptr<DType>();
            if (pdata[heterogenized_eid] > 0) {
              cdata[k] = indices[off + j];
              idata[k] = homogenized_eid;
              ++k;
            }
          }
290
        }
291
292
293
294

        picked_rows[i] = rows.CreateView({k}, rows->dtype);
        picked_cols[i] = cols.CreateView({k}, cols->dtype);
        picked_idxs[i] = idx.CreateView({k}, idx->dtype);
295
296
297
298
299
300
301
302
      } else {
        // need to do per edge type sample
        std::vector<IdxType> rows;
        std::vector<IdxType> cols;
        std::vector<IdxType> idx;

        std::vector<IdxType> et(len);
        std::vector<IdxType> et_idx(len);
303
        std::vector<IdxType> et_eid(len);
304
305
        std::iota(et_idx.begin(), et_idx.end(), 0);
        for (int64_t j = 0; j < len; ++j) {
306
307
308
309
310
311
312
313
          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
          auto it = std::upper_bound(
              eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
          const IdxType heterogenized_eid = \
              homogenized_eid - eid2etype_offset[heterogenized_etype];
          et[j] = heterogenized_etype;
          et_eid[j] = heterogenized_eid;
314
        }
315
        if (!rowwise_etype_sorted)  // the edge type is sorted, not need to sort it
316
317
          std::sort(et_idx.begin(), et_idx.end(),
                    [&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
318
        CHECK_LT(et[et_idx[len - 1]], num_etypes) << "etype values exceed the number of fanouts";
319
320
321
322
323

        IdxType cur_et = et[et_idx[0]];
        int64_t et_offset = 0;
        int64_t et_len = 1;
        for (int64_t j = 0; j < len; ++j) {
324
325
          CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))
              << "Edge type is not sorted. Please sort in advance or specify "
326
                 "'rowwise_etype_sorted' as false.";
327
          if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) {
328
329
330
            // 1 end of the current etype
            // 2 end of the row
            // random pick for current etype
331
332
            if ((num_picks[cur_et] == -1) ||
                (et_len <= num_picks[cur_et] && !replace)) {
333
334
              // fast path, select all
              for (int64_t k = 0; k < et_len; ++k) {
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
                const IdxType eid_offset = off + et_idx[et_offset + k];
                const IdxType homogenized_eid = eid ? eid[eid_offset] : eid_offset;
                auto it = std::upper_bound(
                    eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
                const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
                const IdxType heterogenized_eid = \
                    homogenized_eid - eid2etype_offset[heterogenized_etype];

                if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
                  // No probability, select all
                  rows.push_back(rid);
                  cols.push_back(indices[eid_offset]);
                  idx.push_back(homogenized_eid);
                } else {
                  // Select the entries with non-zero probability
                  const NDArray& p = prob_or_mask[heterogenized_etype];
                  const DType* pdata = p.Ptr<DType>();
                  if (pdata[heterogenized_eid] > 0) {
                    rows.push_back(rid);
                    cols.push_back(indices[eid_offset]);
                    idx.push_back(homogenized_eid);
                  }
                }
358
359
360
              }
            } else {
              IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
361
              IdxType* picked_idata = picked_idx.Ptr<IdxType>();
362
363
364

              // need call random pick
              pick_fn(off, et_offset, cur_et,
365
366
                      et_len, et_idx, et_eid,
                      eid, picked_idata);
367
368
              for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
                const IdxType picked = picked_idata[k];
369
370
                if (picked == -1)
                  continue;
371
372
                rows.push_back(rid);
                cols.push_back(indices[off+et_idx[et_offset+picked]]);
373
374
375
                if (eid) {
                  idx.push_back(eid[off+et_idx[et_offset+picked]]);
                } else {
376
                  idx.push_back(off+et_idx[et_offset+picked]);
377
                }
378
              }
379
            }
380
381
382
383
384
385
386

            if (j+1 == len)
              break;
            // next etype
            cur_et = et[et_idx[j+1]];
            et_offset = j+1;
            et_len = 1;
387
          } else {
388
            et_len++;
389
390
391
          }
        }

392
393
394
395
        picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
        picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
        picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
      }  // end processing one row
396

397
398
399
400
      CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
      CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
    }  // end processing all rows
  });
401
402
403
404
405
406
407
408

  IdArray picked_row = Concat(picked_rows);
  IdArray picked_col = Concat(picked_cols);
  IdArray picked_idx = Concat(picked_idxs);
  return COOMatrix(mat.num_rows, mat.num_cols,
                   picked_row, picked_col, picked_idx);
}

409
410
411
412
413
// Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts it to CSR format. It then performs
// row-wise pick on the CSR matrix and rectifies the returned results.
//
// Row k of the sliced sub-matrix corresponds to rows[k] of mat, so picking runs
// over the local row ids 0 .. rows->shape[0] - 1. The picked local rows are
// then mapped back through `rows`.
//
// \return COO matrix with mat's shape holding the picked (row, col, data) triples.
template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
                         int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
                         NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;
  const auto csr = COOToCSR(COOSliceRows(mat, rows));
  const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
  const auto picked = CSRRowWisePick<IdxType>(
      csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
  return COOMatrix(mat.num_rows, mat.num_cols,
                   IndexSelect(rows, picked.row),  // map the row index to the correct one
                   picked.col,
                   picked.data);
}

// Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts it to CSR format. It then performs
// row-wise pick on the CSR matrix and rectifies the returned results.
430
431
432
433
434
435
template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_picks, bool replace,
    EtypeRangePickFn<IdxType> pick_fn,
    const std::vector<NDArray>& prob_or_mask) {
436
437
438
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
  const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
439
440
  const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
    csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn, prob_or_mask);
441
442
443
444
445
446
  return COOMatrix(mat.num_rows, mat.num_cols,
                   IndexSelect(rows, picked.row),  // map the row index to the correct one
                   picked.col,
                   picked.data);
}

447
448
449
450
451
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_ROWWISE_PICK_H_