rowwise_pick.h 15.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/rowwise_pick.h
 * \brief Template implementation for rowwise pick operators.
 */
#ifndef DGL_ARRAY_CPU_ROWWISE_PICK_H_
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

namespace dgl {
namespace aten {
namespace impl {

// User-defined function for picking elements from one row.
//
// The column indices of the given row are stored in
//   [col + off, col + off + len)
//
// Similarly, the data indices are stored in
//   [data + off, data + off + len)
// Data index pointer could be NULL, which means data[i] == i
//
// The callback must write exactly `num_picks` entries to `out_idx`. The caller
// uses each entry as an absolute position into `col` / `data` (it evaluates
// col[out_idx[j]] directly), so every entry must lie in [off, off + len).
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param num_picks Number of picks on the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
// \param out_idx Output buffer of `num_picks` picked positions in [off, off + len).
template <typename IdxType>
using PickFn = std::function<void(
    IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
    const IdxType* col, const IdxType* data,
    IdxType* out_idx)>;

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Callback that decides how many elements will be picked from a single row.
//
// The row's column indices occupy
//   [col + off, col + off + len)
// and its data indices occupy
//   [data + off, data + off + len)
// A NULL data pointer means the data index equals the position (data[i] == i).
//
// *ATTENTION*: The callback is called concurrently from several threads, so it
// must be thread-safe.
//
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
// \return Number of elements to pick from this row.
template <typename IdxType>
using NumPicksFn = std::function<auto(
    IdxType /* rowid */, IdxType /* off */, IdxType /* len */,
    const IdxType* /* col */, const IdxType* /* data */) -> IdxType>;

69
70
71
72
73
74
75
76
77
78
79
80
81
82
// User-defined function for picking elements from a range within a row.
//
// The range is addressed indirectly through `et_idx`: its k-th element
// (k in [0, et_len)) sits at CSR position
//   off + et_idx[et_offset + k]
// so its column index is col[off + et_idx[et_offset + k]].
//
// Similarly, its data index is
//   data[off + et_idx[et_offset + k]]
// Data index pointer could be NULL, which means the data index is the CSR
// position itself, i.e. off + et_idx[et_offset + k].
//
// The number of elements to pick is NOT passed in. The caller
// (CSRRowWisePerEtypePick) provides a buffer of num_picks[cur_et] entries and
// reads all of them back, so the callback has to derive that count from
// `cur_et` itself.
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param off Starting offset of this row.
// \param et_offset Starting offset of this range within `et_idx`.
// \param cur_et The edge type.
// \param et_len Length of the range.
// \param et_idx Permutation of the row's local offsets [0, len), ordered by
//        edge type; entries [et_offset, et_offset + et_len) belong to `cur_et`.
// \param data Pointer of the data indices.
// \param out_idx Output buffer of picked local indices k in [0, et_len); the
//        caller maps each one back to CSR position off + et_idx[et_offset + k].
template <typename IdxType>
using RangePickFn = std::function<void(
    IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
    const std::vector<IdxType> &et_idx, const IdxType* data,
    IdxType* out_idx)>;

94
95
96
97
// Template for picking non-zero values row-wise.
//
// For each entry `rid` of `rows`, `num_picks_fn` decides how many non-zeros to
// pick from row `rid` of `mat`, and `pick_fn` chooses which ones. The result is
// a COO matrix with the shape of `mat`. It holds one (row, col, data index)
// triplet per picked non-zero, grouped by the position of the row in `rows`.
//
// The implementation parallelizes over `rows` with OpenMP because each row is
// processed independently. It runs in two passes so that the output arrays are
// allocated at their exact size (no padding, no compaction):
//   1. Every thread builds a prefix sum of the pick counts of its contiguous
//      chunk of rows. The master thread then turns the per-thread totals into
//      global offsets and allocates the output arrays.
//   2. Every thread calls `pick_fn` for each of its rows and writes the picked
//      triplets directly into its own slice of the output arrays.
//
// \param mat The CSR matrix to pick from.
// \param rows IDs of the rows to pick from.
// \param num_picks Not read here; the per-row count comes from `num_picks_fn`
//        (presumably the callbacks capture it -- confirm at the call sites).
// \param replace Not read here either (see `num_picks`).
// \param pick_fn Chooses the picked positions of one row; see PickFn.
// \param num_picks_fn Decides the number of picks of one row; see NumPicksFn.
// \return COO matrix of the picked non-zeros.
template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
                         int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
                         NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;
  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
  const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr;
  const IdxType* rows_data = static_cast<IdxType*>(rows->data);
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
  const auto& idtype = mat.indptr->dtype;

  // Do not use omp_get_max_threads() since that doesn't work for compiling without OpenMP.
  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
  // After pass 1, global_prefix[t] is the output offset of thread t's first pick
  // and global_prefix[num_threads] is the total number of picks.
  std::vector<int64_t> global_prefix(num_threads + 1, 0);

  // TODO(BarclayII) Using OMP parallel directly instead of using runtime::parallel_for
  // does not handle exceptions well (directly aborts when an exception pops up).
  // It runs faster though because there is less scheduling.  Need to handle
  // exceptions better.
  IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads)
  {
    const int thread_id = omp_get_thread_num();

    // Split [0, num_rows) into num_threads contiguous chunks whose sizes differ
    // by at most one; the first (num_rows % num_threads) chunks get the extra row.
    const int64_t start_i = thread_id * (num_rows/num_threads) +
        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
    const int64_t end_i = (thread_id + 1) * (num_rows/num_threads) +
        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
    assert(thread_id + 1 < num_threads || end_i == num_rows);

    const int64_t num_local = end_i - start_i;

    // Pass 1: prefix sum of the pick counts of this thread's rows.
    // A raw array (instead of std::vector) avoids paying for zero-initialization.
    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
    local_prefix[0] = 0;
    for (int64_t i = start_i; i < end_i; ++i) {
      const int64_t local_i = i - start_i;
      const IdxType rid = rows_data[i];
      const IdxType row_picks = num_picks_fn(
          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
      local_prefix[local_i + 1] = local_prefix[local_i] + row_picks;
    }
    global_prefix[thread_id + 1] = local_prefix[num_local];

    #pragma omp barrier
    #pragma omp master
    {
      for (int t = 0; t < num_threads; ++t) {
        global_prefix[t + 1] += global_prefix[t];
      }
      const int64_t total_picks = global_prefix[num_threads];
      picked_row = IdArray::Empty({total_picks}, idtype, ctx);
      picked_col = IdArray::Empty({total_picks}, idtype, ctx);
      picked_idx = IdArray::Empty({total_picks}, idtype, ctx);
    }
    // `omp master` has no implied barrier: wait until the arrays exist.
    #pragma omp barrier

    IdxType* picked_rdata = picked_row.Ptr<IdxType>();
    IdxType* picked_cdata = picked_col.Ptr<IdxType>();
    IdxType* picked_idata = picked_idx.Ptr<IdxType>();

    const int64_t thread_offset = global_prefix[thread_id];

    // Pass 2: pick and write the triplets into this thread's output slice.
    for (int64_t i = start_i; i < end_i; ++i) {
      const IdxType rid = rows_data[i];

      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;
      if (len == 0)
        continue;

      const int64_t local_i = i - start_i;
      const int64_t row_offset = thread_offset + local_prefix[local_i];
      const int64_t row_num_picks = local_prefix[local_i + 1] - local_prefix[local_i];

      // pick_fn writes absolute CSR positions into the output slice, which are
      // then replaced in place by the final data indices.
      pick_fn(rid, off, len, row_num_picks, indices, data, picked_idata + row_offset);
      for (int64_t j = 0; j < row_num_picks; ++j) {
        const IdxType picked = picked_idata[row_offset + j];
        picked_rdata[row_offset + j] = rid;
        picked_cdata[row_offset + j] = indices[picked];
        picked_idata[row_offset + j] = data ? data[picked] : picked;
      }
    }
  }

  // The arrays were allocated with exactly global_prefix.back() entries, so
  // they can be returned as they are.
  return COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}

209
210
211
212
// Template for picking non-zero values row-wise, separately for each edge type.
//
// For every entry `rid` of `rows`, the non-zeros of row `rid` are grouped by
// edge type. The edge type is looked up in `etypes`, by data index when `mat`
// has data and by CSR position otherwise. For each edge type `t`, one of two
// things happens:
//   - all of its elements are kept, when num_picks[t] == -1, or when it has at
//     most num_picks[t] elements and `replace` is false;
//   - otherwise `pick_fn` chooses num_picks[t] of them.
// If every entry of `num_picks` is equal, a row whose whole degree does not
// exceed that value (and `replace` is false) is kept in full without grouping.
//
// Rows are processed in parallel with runtime::parallel_for because each row is
// independent; per-row results are concatenated in the order of `rows`.
//
// \param mat The CSR matrix to pick from.
// \param rows IDs of the rows to pick from.
// \param etypes Edge type of each non-zero; must be int32 with values in
//        [0, num_picks.size()).
// \param num_picks Number of picks per edge type (-1 means keep all).
// \param replace If false, an edge type with at most num_picks[t] elements is
//        kept in full without calling `pick_fn`.
// \param etype_sorted Whether each row's non-zeros are already ordered by edge
//        type (verified with a CHECK); if false they are sorted here.
// \param pick_fn Picks elements from one edge-type range; see RangePickFn.
// \return COO matrix of the picked (row, col, data index) triplets.
template <typename IdxType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
                                 const std::vector<int64_t>& num_picks, bool replace,
                                 bool etype_sorted, RangePickFn<IdxType> pick_fn) {
  using namespace aten;
  CHECK_EQ(etypes->dtype.bits / 8, sizeof(int32_t)) << "etypes must be int32";

  const IdxType* indptr = mat.indptr.Ptr<IdxType>();
  const IdxType* indices = mat.indices.Ptr<IdxType>();
  const IdxType* data = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
  const IdxType* rows_data = rows.Ptr<IdxType>();
  const int32_t* etype_data = etypes.Ptr<int32_t>();
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
  const int64_t num_etypes = num_picks.size();
  const uint8_t nbits = sizeof(IdxType) * 8;

  std::vector<IdArray> picked_rows(num_rows);
  std::vector<IdArray> picked_cols(num_rows);
  std::vector<IdArray> picked_idxs(num_rows);

  // If every edge type has the same fanout, a row with no more neighbors than
  // that fanout can be taken in full (replace=false) without per-etype grouping.
  // An empty `num_picks` never takes that shortcut (and is not dereferenced).
  const int64_t num_pick_value = num_picks.empty() ? 0 : num_picks.front();
  const bool same_num_pick = !num_picks.empty() &&
      std::all_of(num_picks.begin(), num_picks.end(),
                  [num_pick_value](int64_t v) { return v == num_pick_value; });

  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
    for (size_t i = b; i < e; ++i) {
      const IdxType rid = rows_data[i];
      CHECK_GE(rid, 0);
      CHECK_LT(rid, mat.num_rows);
      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;

      if (len == 0) {
        picked_rows[i] = NewIdArray(0, ctx, nbits);
        picked_cols[i] = NewIdArray(0, ctx, nbits);
        picked_idxs[i] = NewIdArray(0, ctx, nbits);
        continue;
      }

      if (same_num_pick && len <= num_pick_value && !replace) {
        // Fast path: keep the whole row in CSR order.
        IdArray row_arr = Full(rid, len, nbits, ctx);
        // Every entry is overwritten below, so no fill value is needed.
        IdArray col_arr = NewIdArray(len, ctx, nbits);
        IdArray idx_arr = NewIdArray(len, ctx, nbits);
        IdxType* cdata = col_arr.Ptr<IdxType>();
        IdxType* idata = idx_arr.Ptr<IdxType>();
        for (int64_t j = 0; j < len; ++j) {
          cdata[j] = indices[off + j];
          idata[j] = data ? data[off + j] : off + j;
        }
        picked_rows[i] = row_arr;
        picked_cols[i] = col_arr;
        picked_idxs[i] = idx_arr;
      } else {
        // Slow path: group the row's elements by edge type and pick per group.
        std::vector<IdxType> row_vec;
        std::vector<IdxType> col_vec;
        std::vector<IdxType> idx_vec;
        // Append the element at row-local offset `pos` to the output.
        auto emit = [&](IdxType pos) {
          const IdxType csr_pos = off + pos;
          row_vec.push_back(rid);
          col_vec.push_back(indices[csr_pos]);
          idx_vec.push_back(data ? data[csr_pos] : csr_pos);
        };

        std::vector<IdxType> et(len);
        std::vector<IdxType> et_idx(len);
        std::iota(et_idx.begin(), et_idx.end(), 0);
        for (int64_t j = 0; j < len; ++j) {
          et[j] = data ? etype_data[data[off+j]] : etype_data[off+j];
          // Validate every value up front: each one may later index num_picks,
          // even when the row turns out not to be sorted.
          CHECK(et[j] >= 0 && et[j] < num_etypes) <<
            "etype values exceed the number of fanouts";
        }
        if (!etype_sorted)  // otherwise the caller guarantees etype order
          std::sort(et_idx.begin(), et_idx.end(),
                    [&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});

        // et_idx[et_offset, et_offset + et_len) holds the row-local offsets of
        // the elements of edge type cur_et.
        IdxType cur_et = et[et_idx[0]];
        int64_t et_offset = 0;
        int64_t et_len = 1;
        for (int64_t j = 0; j < len; ++j) {
          CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))
              << "Edge type is not sorted. Please sort in advance or specify "
                 "'etype_sorted' as false.";
          const bool range_ends = (j + 1 == len) || cur_et != et[et_idx[j + 1]];
          if (!range_ends) {
            ++et_len;
            continue;
          }

          // The range of cur_et ends here (end of the etype or end of the row).
          const int64_t fanout = num_picks[cur_et];
          if (fanout == -1 || (et_len <= fanout && !replace)) {
            // Select every element of this edge type.
            for (int64_t k = 0; k < et_len; ++k)
              emit(et_idx[et_offset + k]);
          } else {
            // pick_fn writes `fanout` local indices in [0, et_len).
            std::vector<IdxType> picked_local(fanout, -1);
            pick_fn(off, et_offset, cur_et,
                    et_len, et_idx,
                    data, picked_local.data());
            for (const IdxType picked : picked_local)
              emit(et_idx[et_offset + picked]);
          }

          if (j + 1 == len)
            break;
          // Start the range of the next edge type.
          cur_et = et[et_idx[j + 1]];
          et_offset = j + 1;
          et_len = 1;
        }

        picked_rows[i] = VecToIdArray(row_vec, nbits, ctx);
        picked_cols[i] = VecToIdArray(col_vec, nbits, ctx);
        picked_idxs[i] = VecToIdArray(idx_vec, nbits, ctx);
      }  // end processing one row

      CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
      CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
    }  // end processing all rows
  });

  IdArray picked_row = Concat(picked_rows);
  IdArray picked_col = Concat(picked_cols);
  IdArray picked_idx = Concat(picked_idxs);
  return COOMatrix(mat.num_rows, mat.num_cols,
                   picked_row, picked_col, picked_idx);
}

357
358
359
360
361
// Template for picking non-zero values row-wise from a COO matrix.
//
// The implementation first slices out the requested rows and converts them to
// CSR format. It then performs row-wise pick on that CSR matrix with
// CSRRowWisePick, using local row IDs 0..len(rows)-1. Finally it maps the
// returned local row IDs back to the original ones via `rows`. This relies on
// row i of the sliced matrix corresponding to rows[i].
//
// \param mat The COO matrix to pick from.
// \param rows IDs of the rows to pick from.
// \param num_picks Forwarded unchanged to CSRRowWisePick.
// \param replace Forwarded unchanged to CSRRowWisePick.
// \param pick_fn Forwarded unchanged to CSRRowWisePick.
// \param num_picks_fn Forwarded unchanged to CSRRowWisePick.
// \return COO matrix with the shape of `mat` holding the picked non-zeros.
template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
                         int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
                         NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;
  const CSRMatrix csr = COOToCSR(COOSliceRows(mat, rows));
  // Local row IDs address every row of the sliced matrix exactly once.
  const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
  const COOMatrix picked = CSRRowWisePick<IdxType>(
      csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
  return COOMatrix(mat.num_rows, mat.num_cols,
                   IndexSelect(rows, picked.row),  // map the row index to the correct one
                   picked.col,
                   picked.data);
}

375
376
377
378
379
// Template for picking non-zero values row-wise, per edge type, from a COO matrix.
//
// The implementation first slices out the requested rows and converts them to
// CSR format. It then performs per-edge-type row-wise pick on that CSR matrix
// with CSRRowWisePerEtypePick, using local row IDs 0..len(rows)-1. Finally it
// maps the returned local row IDs back to the original ones via `rows`.
// NOTE(review): `etypes` is indexed by the sliced matrix's data indices, so this
// assumes COOSliceRows/COOToCSR preserve the original data indices -- confirm.
//
// \param mat The COO matrix to pick from.
// \param rows IDs of the rows to pick from.
// \param etypes Forwarded unchanged to CSRRowWisePerEtypePick.
// \param num_picks Forwarded unchanged to CSRRowWisePerEtypePick.
// \param replace Forwarded unchanged to CSRRowWisePerEtypePick.
// \param etype_sorted Forwarded unchanged to CSRRowWisePerEtypePick.
// \param pick_fn Forwarded unchanged to CSRRowWisePerEtypePick.
// \return COO matrix with the shape of `mat` holding the picked non-zeros.
template <typename IdxType>
COOMatrix COORowWisePerEtypePick(COOMatrix mat, IdArray rows, IdArray etypes,
                                 const std::vector<int64_t>& num_picks, bool replace,
                                 bool etype_sorted, RangePickFn<IdxType> pick_fn) {
  using namespace aten;
  const CSRMatrix csr = COOToCSR(COOSliceRows(mat, rows));
  // Local row IDs address every row of the sliced matrix exactly once.
  const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
  const COOMatrix picked = CSRRowWisePerEtypePick<IdxType>(
      csr, new_rows, etypes, num_picks, replace, etype_sorted, pick_fn);
  return COOMatrix(mat.num_rows, mat.num_cols,
                   IndexSelect(rows, picked.row),  // map the row index to the correct one
                   picked.col,
                   picked.data);
}

393
394
395
396
397
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_ROWWISE_PICK_H_