/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cpu/spmat_op_impl_coo.cc
 * @brief CPU implementation of COO sparse matrix operators
 */
#include <dgl/runtime/parallel_for.h>

#include <dmlc/omp.h>

#include <numeric>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "array_utils.h"

namespace dgl {

using runtime::NDArray;
using runtime::parallel_for;

namespace aten {
namespace impl {

/**
 * TODO(BarclayII):
 * For row-major sorted COOs, we have faster implementation with binary search,
 * sorted search, etc.  Later we should benchmark how much we can gain with
 * sorted COOs on hypersparse graphs.
 */

///////////////////////////// COOIsNonZero /////////////////////////////

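// Membership test by linear scan over all non-zeros; O(NNZ) per query.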
template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row && coo_col_data[i] == col) return true;
  }
  return false;
}

template bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);

template <DGLDeviceType XPU, typename IdType>
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  IdType *rst_data = static_cast<IdType *>(rst->data);
  const IdType *row_data = static_cast<IdType *>(row->data);
  const IdType *col_data = static_cast<IdType *>(col->data);
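  // Broadcasting: if one of the two index arrays has length 1, its stride is
  // 0, so its single value is paired with every element of the other array.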
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const int64_t kmax = std::max(rowlen, collen);
  parallel_for(0, kmax, [=](size_t b, size_t e) {
    for (auto k = b; k < e; ++k) {
      int64_t i = row_stride * k;
      int64_t j = col_stride * k;
      rst_data[k] =
          COOIsNonZero<XPU, IdType>(coo, row_data[i], col_data[j]) ? 1 : 0;
    }
  });
  return rst;
}

template NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);

///////////////////////////// COOHasDuplicate /////////////////////////////

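// Detects duplicate entries with a hash set of (row, col) pairs; O(NNZ)
// expected time and space.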
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) {
  std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
  const IdType *src_data = static_cast<IdType *>(coo.row->data);
  const IdType *dst_data = static_cast<IdType *>(coo.col->data);
  const auto nnz = coo.row->shape[0];
  for (IdType eid = 0; eid < nnz; ++eid) {
    const auto &p = std::make_pair(src_data[eid], dst_data[eid]);
    if (hashmap.count(p)) {
      return true;
    } else {
      hashmap.insert(p);
    }
  }
  return false;
}

template bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);

///////////////////////////// COOGetRowNNZ /////////////////////////////

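// Counts the non-zeros of a single row by linear scan; O(NNZ).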
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  int64_t result = 0;
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row) ++result;
  }
  return result;
}

template int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, int64_t);

template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
  CHECK_SAME_DTYPE(coo.col, rows);
  const auto len = rows->shape[0];
  const IdType *vid_data = static_cast<IdType *>(rows->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType *rst_data = static_cast<IdType *>(rst->data);
#pragma omp parallel for
  for (int64_t i = 0; i < len; ++i) {
    rst_data[i] = COOGetRowNNZ<XPU, IdType>(coo, vid_data[i]);
  }
  return rst;
}

template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);

///////////////////////////// COOGetRowDataAndIndices /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;

  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> indices;
  std::vector<IdType> data;

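  // When the COO has no explicit data array, the entry position i itself
  // serves as the edge id.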
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row) {
      indices.push_back(coo_col_data[i]);
      data.push_back(coo_data ? coo_data[i] : i);
    }
  }

  return std::make_pair(
      NDArray::FromVector(data), NDArray::FromVector(indices));
}

template std::pair<NDArray, NDArray> COOGetRowDataAndIndices<kDGLCPU, int32_t>(
    COOMatrix, int64_t);
template std::pair<NDArray, NDArray> COOGetRowDataAndIndices<kDGLCPU, int64_t>(
    COOMatrix, int64_t);

///////////////////////////// COOGetData /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];
  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
      << "Invalid row and col Id array:" << rows << " " << cols;
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType *row_data = rows.Ptr<IdType>();
  const IdType *col_data = cols.Ptr<IdType>();

  const IdType *coo_row = coo.row.Ptr<IdType>();
  const IdType *coo_col = coo.col.Ptr<IdType>();
  const IdType *data = COOHasData(coo) ? coo.data.Ptr<IdType>() : nullptr;
  const int64_t nnz = coo.row->shape[0];

  const int64_t retlen = std::max(rowlen, collen);
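  // Entries not found in the matrix keep the initial fill value -1.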
  IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);
  IdType *ret_data = ret.Ptr<IdType>();

  // TODO(minjie): We might need to consider sorting the COO beforehand
  // especially when the number of (row, col) pairs is large. Need more
  // benchmarks to justify the choice.

  if (coo.row_sorted) {
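    // Rows are sorted: binary-search the first entry of each queried row,
    // then scan forward within that row for the matching column.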
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
        const IdType row_id = row_data[p * row_stride],
                     col_id = col_data[p * col_stride];
        auto it = std::lower_bound(coo_row, coo_row + nnz, row_id);
        for (; it < coo_row + nnz && *it == row_id; ++it) {
          const auto idx = it - coo_row;
          if (coo_col[idx] == col_id) {
            ret_data[p] = data ? data[idx] : idx;
            break;
          }
        }
      }
    });
  } else {
#pragma omp parallel for
    for (int64_t p = 0; p < retlen; ++p) {
      const IdType row_id = row_data[p * row_stride],
                   col_id = col_data[p * col_stride];
      for (int64_t idx = 0; idx < nnz; ++idx) {
        if (coo_row[idx] == row_id && coo_col[idx] == col_id) {
          ret_data[p] = data ? data[idx] : idx;
          break;
        }
      }
    }
  }

  return ret;
}

template IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);

///////////////////////////// COOGetDataAndIndices /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(
    COOMatrix coo, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(coo.col, rows);
  CHECK_SAME_DTYPE(coo.col, cols);
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];
  const int64_t len = std::max(rowlen, collen);

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
      << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType *row_data = static_cast<IdType *>(rows->data);
  const IdType *col_data = static_cast<IdType *>(cols->data);

  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> ret_rows, ret_cols;
  std::vector<IdType> ret_data;
  ret_rows.reserve(len);
  ret_cols.reserve(len);
  ret_data.reserve(len);

  // NOTE(BarclayII): With a small number of lookups, linear scan is faster.
  // The threshold 200 comes from benchmarking both algorithms on a P3.8x
  // instance. I also tried sorting plus binary search.  The speed gain is only
  // significant for medium-sized graphs and lookups, so I didn't include it.
  if (len >= 200) {
    // TODO(BarclayII): Ideally we would want to cache this object.  However,
    // I'm not sure what the best way to do so is, since this object is valid
    // for CPU only.
    std::unordered_multimap<std::pair<IdType, IdType>, IdType, PairHash>
        pair_map;
    pair_map.reserve(coo.row->shape[0]);
    for (int64_t k = 0; k < coo.row->shape[0]; ++k)
      pair_map.emplace(
          std::make_pair(coo_row_data[k], coo_col_data[k]), data ? data[k] : k);

    for (int64_t i = 0, j = 0; i < rowlen && j < collen;
         i += row_stride, j += col_stride) {
      const IdType row_id = row_data[i], col_id = col_data[j];
      CHECK(row_id >= 0 && row_id < coo.num_rows)
          << "Invalid row index: " << row_id;
      CHECK(col_id >= 0 && col_id < coo.num_cols)
          << "Invalid col index: " << col_id;
      auto range = pair_map.equal_range({row_id, col_id});
      for (auto it = range.first; it != range.second; ++it) {
        ret_rows.push_back(row_id);
        ret_cols.push_back(col_id);
        ret_data.push_back(it->second);
      }
    }
  } else {
    for (int64_t i = 0, j = 0; i < rowlen && j < collen;
         i += row_stride, j += col_stride) {
      const IdType row_id = row_data[i], col_id = col_data[j];
      CHECK(row_id >= 0 && row_id < coo.num_rows)
          << "Invalid row index: " << row_id;
      CHECK(col_id >= 0 && col_id < coo.num_cols)
          << "Invalid col index: " << col_id;
      for (int64_t k = 0; k < coo.row->shape[0]; ++k) {
        if (coo_row_data[k] == row_id && coo_col_data[k] == col_id) {
          ret_rows.push_back(row_id);
          ret_cols.push_back(col_id);
          ret_data.push_back(data ? data[k] : k);
        }
      }
    }
  }

  return {
      NDArray::FromVector(ret_rows), NDArray::FromVector(ret_cols),
      NDArray::FromVector(ret_data)};
}

template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(
    COOMatrix coo, NDArray rows, NDArray cols);
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(
    COOMatrix coo, NDArray rows, NDArray cols);

///////////////////////////// COOTranspose /////////////////////////////

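// Transposition is O(1): it only swaps the row/col arrays and the dimensions.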
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo) {
  return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
}

template COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);

///////////////////////////// COOToCSR /////////////////////////////
namespace {

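// COO->CSR for row-sorted input: indptr is computed in parallel by detecting
// changes in the row array, while the existing col/data arrays are reused as
// CSR indices/data.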
template <class IdType>
CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *const row_data = static_cast<IdType *>(coo.row->data);
  const IdType *const data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = coo.col;
  NDArray ret_data = data == nullptr
                         ? NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx)
                         : coo.data;

  // compute indptr
  IdType *const Bp = static_cast<IdType *>(ret_indptr->data);
  Bp[0] = 0;

  IdType *const fill_data =
      data ? nullptr : static_cast<IdType *>(ret_data->data);

  if (NNZ > 0) {
    auto num_threads = omp_get_max_threads();
    parallel_for(0, num_threads, [&](int b, int e) {
      for (auto thread_id = b; thread_id < e; ++thread_id) {
        // We partition the set of non-zeros among the threads.
        const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;
        const int64_t nz_start = thread_id * nz_chunk;
        const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);

        // Each thread searches the row array for a change, and marks its
        // location in Bp. Threads other than the first start at the last
        // index covered by the previous thread, in order to detect changes
        // in the row array between thread partitions. This means that each
        // thread after the first searches the range [nz_start-1, nz_end).
        // That is,
        // if we had 10 non-zeros, and 4 threads, the indexes searched by each
        // thread would be:
        // 0: [0, 1, 2]
        // 1: [2, 3, 4, 5]
        // 2: [5, 6, 7, 8]
        // 3: [8, 9]
        //
        // That way, if the row array were [0, 0, 1, 2, 2, 2, 4, 5, 5, 6], each
        // change in row would be captured by one thread:
        //
        // 0: [0, 0, 1] - row 0
        // 1: [1, 2, 2, 2] - row 1
        // 2: [2, 4, 5, 5] - rows 2, 3, and 4
        // 3: [5, 6] - rows 5 and 6
        //
        int64_t row = 0;
        if (nz_start < nz_end) {
          row = nz_start == 0 ? 0 : row_data[nz_start - 1];
          for (int64_t i = nz_start; i < nz_end; ++i) {
            while (row != row_data[i]) {
              ++row;
              Bp[row] = i;
            }
          }

          // We will not detect the row change for the last row, nor any empty
          // rows at the end of the matrix, so the last active thread needs to
          // mark all remaining rows in Bp with NNZ.
          if (nz_end == NNZ) {
            while (row < N) {
              ++row;
              Bp[row] = NNZ;
            }
          }

          if (fill_data) {
            // TODO(minjie): Many of our current implementations assume that
            // CSR must have a data array. This is a temporary workaround.
            // Remove this after:
            //   - The old immutable graph implementation is deprecated.
            //   - The old binary reduce kernel is deprecated.
            std::iota(fill_data + nz_start, fill_data + nz_end, nz_start);
          }
        }
      }
    });
  } else {
    std::fill(Bp, Bp + N + 1, 0);
  }

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

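// COO->CSR for unsorted input with few non-zeros per row.  Phase 1: every
// thread counts how many of its non-zeros fall into each thread's row range
// (p_sum), and a prefix sum over those counts assigns each thread a
// contiguous output range.  Phase 2: the non-zeros are bucketed by
// destination thread (Sx/Si), after which every thread builds its own rows
// of the CSR independently.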
template <class IdType>
CSRMatrix UnSortedSparseCOOToCSR(const COOMatrix &coo) {
  // Unsigned version of the original integer index data type.
  // It avoids overflow in (N + num_threads) and (n_start + n_chunk) below.
  typedef typename std::make_unsigned<IdType>::type UIdType;

  const UIdType N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *const row_data = static_cast<IdType *>(coo.row->data);
  const IdType *const col_data = static_cast<IdType *>(coo.col->data);
  const IdType *const data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  NDArray ret_indptr = NDArray::Empty(
      {static_cast<int64_t>(N) + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  IdType *const Bp = static_cast<IdType *>(ret_indptr->data);
  Bp[N] = 0;
  IdType *const Bi = static_cast<IdType *>(ret_indices->data);
  IdType *const Bx = static_cast<IdType *>(ret_data->data);

  // store sorted data and original index.
  NDArray sorted_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray sorted_data_pos = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  IdType *const Sx = static_cast<IdType *>(sorted_data->data);
  IdType *const Si = static_cast<IdType *>(sorted_data_pos->data);

  // Lower the number of threads if the cost of parallelization is greater
  // than the gain from running the computation in parallel.
  const int64_t min_chunk_size = 1000;
  const int64_t num_threads_for_batch = 2 + (NNZ + N) / min_chunk_size;
  const int num_threads_required = std::min(
      static_cast<int64_t>(omp_get_max_threads()), num_threads_for_batch);

  // Per-thread counters: p_sum[i][j] is the number of non-zeros seen by
  // thread i that fall into thread j's row range.
  std::vector<std::vector<int64_t>> p_sum(
      num_threads_required, std::vector<int64_t>(num_threads_required));

#pragma omp parallel num_threads(num_threads_required)
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    CHECK_LT(thread_id, num_threads);

    const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;
    const int64_t nz_start = thread_id * nz_chunk;
    const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);

    const UIdType n_chunk = (N + num_threads - 1) / num_threads;
    const UIdType n_start = thread_id * n_chunk;
    const UIdType n_end = std::min(N, n_start + n_chunk);

    for (auto i = n_start; i < n_end; ++i) {
      Bp[i] = 0;
    }

    // Iterate over this thread's non-zeros and count how many fall into each
    // thread's row range.
    for (auto i = nz_start; i < nz_end; ++i) {
      const IdType row_idx = row_data[i];
      const IdType row_thread_id = row_idx / n_chunk;
      ++p_sum[thread_id][row_thread_id];
    }

#pragma omp barrier
#pragma omp master
    // accumulate row_idx.
    {
      int64_t cum = 0;
      for (int j = 0; j < num_threads; ++j) {
        for (int i = 0; i < num_threads; ++i) {
          auto tmp = p_sum[i][j];
          p_sum[i][j] = cum;
          cum += tmp;
        }
      }
      CHECK_EQ(cum, NNZ);
    }
#pragma omp barrier
    const int64_t i_start = p_sum[0][thread_id];
    const int64_t i_end =
        thread_id + 1 == num_threads ? NNZ : p_sum[0][thread_id + 1];
#pragma omp barrier

    // sort data by row_idx and place into Sx/Si.
    auto &data_pos = p_sum[thread_id];
    for (auto i = nz_start; i < nz_end; ++i) {
      const IdType row_idx = row_data[i];
      const IdType row_thread_id = row_idx / n_chunk;
      const int64_t pos = data_pos[row_thread_id]++;
      Sx[pos] = data == nullptr ? i : data[i];
      Si[pos] = i;
    }

#pragma omp barrier

    // Now each thread can run coo2csr on its bucket of sorted data in
    // parallel. First, count the number of entries in each row.
    for (auto i = i_start; i < i_end; ++i) {
      const UIdType row_idx = row_data[Si[i]];
      ++Bp[row_idx + 1];
    }

    // accumulate on each row
    IdType cumsum = i_start;
    for (auto i = n_start + 1; i <= n_end; ++i) {
      const auto tmp = Bp[i];
      Bp[i] = cumsum;
      cumsum += tmp;
    }

    // update Bi/Bp/Bx
    for (auto i = i_start; i < i_end; ++i) {
      const UIdType row_idx = row_data[Si[i]];
      const int64_t dest = (Bp[row_idx + 1]++);
      Bi[dest] = col_data[Si[i]];
      Bx[dest] = Sx[i];
    }
  }
  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

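// COO->CSR for unsorted input with medium/high average degree.  Each thread
// histograms its chunk of non-zeros into a private per-row counter
// (local_ptrs); the counters are combined into the global indptr by a
// two-level prefix sum, and each thread then scatters its non-zeros to their
// final positions.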
template <class IdType>
CSRMatrix UnSortedDenseCOOToCSR(const COOMatrix &coo) {
  // Unsigned version of the original integer index data type.
  // It avoids overflow in (N + num_threads) and (n_start + n_chunk) below.
  typedef typename std::make_unsigned<IdType>::type UIdType;

  const UIdType N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *const row_data = static_cast<IdType *>(coo.row->data);
  const IdType *const col_data = static_cast<IdType *>(coo.col->data);
  const IdType *const data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  NDArray ret_indptr = NDArray::Empty(
      {static_cast<int64_t>(N) + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  IdType *const Bp = static_cast<IdType *>(ret_indptr->data);
  Bp[0] = 0;
  IdType *const Bi = static_cast<IdType *>(ret_indices->data);
  IdType *const Bx = static_cast<IdType *>(ret_data->data);

  // The offset within each row at which each thread will write.
  std::vector<std::vector<IdType>> local_ptrs;
  std::vector<int64_t> thread_prefixsum;

#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    CHECK_LT(thread_id, num_threads);

    const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;
    const int64_t nz_start = thread_id * nz_chunk;
    const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);

    const UIdType n_chunk = (N + num_threads - 1) / num_threads;
    const UIdType n_start = thread_id * n_chunk;
    const UIdType n_end = std::min(N, n_start + n_chunk);

#pragma omp master
    {
      local_ptrs.resize(num_threads);
      thread_prefixsum.resize(num_threads + 1);
    }

#pragma omp barrier
    local_ptrs[thread_id].resize(N, 0);

    for (int64_t i = nz_start; i < nz_end; ++i) {
      ++local_ptrs[thread_id][row_data[i]];
    }

#pragma omp barrier
    // compute prefixsum in parallel
    int64_t sum = 0;
    for (UIdType i = n_start; i < n_end; ++i) {
      IdType tmp = 0;
      for (int j = 0; j < num_threads; ++j) {
        auto previous = local_ptrs[j][i];
        local_ptrs[j][i] = tmp;
        tmp += previous;
      }
      sum += tmp;
      Bp[i + 1] = sum;
    }
    thread_prefixsum[thread_id + 1] = sum;

#pragma omp barrier
#pragma omp master
    {
      for (int i = 0; i < num_threads; ++i) {
        thread_prefixsum[i + 1] += thread_prefixsum[i];
      }
      CHECK_EQ(thread_prefixsum[num_threads], NNZ);
    }
#pragma omp barrier

    sum = thread_prefixsum[thread_id];
    for (UIdType i = n_start; i < n_end; ++i) {
      Bp[i + 1] += sum;
    }

#pragma omp barrier
    for (int64_t i = nz_start; i < nz_end; ++i) {
      const IdType r = row_data[i];
      const int64_t index = Bp[r] + local_ptrs[thread_id][r]++;
      Bi[index] = col_data[i];
      Bx[index] = data ? data[i] : i;
    }
  }
  CHECK_EQ(Bp[N], NNZ);

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

// complexity: time O(NNZ), space O(1)
template <typename IdType>
CSRMatrix UnSortedSmallCOOToCSR(COOMatrix coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *row_data = static_cast<IdType *>(coo.row->data);
  const IdType *col_data = static_cast<IdType *>(coo.col->data);
  const IdType *data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;
  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  IdType *Bp = static_cast<IdType *>(ret_indptr->data);
  IdType *Bi = static_cast<IdType *>(ret_indices->data);
  IdType *Bx = static_cast<IdType *>(ret_data->data);

  // Count elements in each row
  std::fill(Bp, Bp + N, 0);
  for (int64_t i = 0; i < NNZ; ++i) {
    Bp[row_data[i]]++;
  }

  // Convert to indexes
  for (IdType i = 0, cumsum = 0; i < N; ++i) {
    const IdType temp = Bp[i];
    Bp[i] = cumsum;
    cumsum += temp;
  }

  for (int64_t i = 0; i < NNZ; ++i) {
    const IdType r = row_data[i];
    Bi[Bp[r]] = col_data[i];
    Bx[Bp[r]] = data ? data[i] : i;
    Bp[r]++;
  }

  // Restore the indptr
  for (int64_t i = N; i > 0; --i) {
    Bp[i] = Bp[i - 1];
  }
  Bp[0] = 0;

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

enum class COOToCSRAlg {
  sorted = 0,
  unsortedSmall,
  unsortedSparse,
  unsortedDense
};

/**
 * Choose the COO to CSR format conversion algorithm for a given COO matrix,
 * according to a heuristic based on measured performance.
 *
 * Implementation and complexity details. N: num_nodes, NNZ: num_edges, P:
 * num_threads.
 *   1. If row is sorted in COO, SortedCOOToCSR<> is applied. Time: O(NNZ/P),
 * space: O(1).
 *   2. If row is NOT sorted in COO and graph is small (small number of NNZ),
 * UnSortedSmallCOOToCSR<> is applied. Time: O(NNZ), space O(N).
 *   3. If row is NOT sorted in COO and graph is sparse (low average degree),
 * UnSortedSparseCOOToCSR<> is applied. Time: O(NNZ/P + N/P + P^2),
 * space O(NNZ + P^2).
 *   4. If row is NOT sorted in COO and graph is dense (medium/high average
 * degree), UnSortedDenseCOOToCSR<> is applied. Time: O(NNZ/P + N/P),
 * space O(NNZ + N*P).
 *
 * Note:
 *   If you change this function, change also _TestCOOToCSRAlgs in
 * tests/cpp/test_spmat_coo.cc
 */
template <typename IdType>
inline COOToCSRAlg WhichCOOToCSR(const COOMatrix &coo) {
  if (coo.row_sorted) {
    return COOToCSRAlg::sorted;
  } else {
#ifdef _WIN32
    // On Windows omp_get_max_threads() gives larger value than later OMP can
    // spawn.
    int64_t num_threads;
#pragma omp parallel
#pragma omp master
    { num_threads = omp_get_num_threads(); }
#else
    const int64_t num_threads = omp_get_max_threads();
#endif
    const int64_t N = coo.num_rows;
    const int64_t NNZ = coo.row->shape[0];
    // Parameters below are heuristically chosen according to measured
    // performance.
    const int64_t type_scale = sizeof(IdType) >> 1;
    const int64_t small = 50 * num_threads * type_scale * type_scale;
    if (NNZ < small || num_threads == 1) {
      // For a relatively small number of non-zero elements, the cost of
      // spreading the algorithm across threads is bigger than the improvement
      // from using many cores.
      return COOToCSRAlg::unsortedSmall;
    } else if (type_scale * NNZ < num_threads * N) {
      // When the number of non-zero elements is small relative to the matrix
      // size, the sparse parallel version of the algorithm is more efficient
      // than the dense one.
      return COOToCSRAlg::unsortedSparse;
    }
    return COOToCSRAlg::unsortedDense;
  }
}

}  // namespace

template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  switch (WhichCOOToCSR<IdType>(coo)) {
    case COOToCSRAlg::sorted:
      return SortedCOOToCSR<IdType>(coo);
    case COOToCSRAlg::unsortedSmall:
    default:
      return UnSortedSmallCOOToCSR<IdType>(coo);
    case COOToCSRAlg::unsortedSparse:
      return UnSortedSparseCOOToCSR<IdType>(coo);
    case COOToCSRAlg::unsortedDense:
      return UnSortedDenseCOOToCSR<IdType>(coo);
  }
}

template CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);

///////////////////////////// COOSliceRows /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
  // TODO(minjie): use binary search when coo.row_sorted is true
  CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
  CHECK(end > 0 && end <= coo.num_rows) << "Invalid end row " << end;

  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType row_id = coo_row_data[i];
    const IdType col_id = coo_col_data[i];
    if (row_id < end && row_id >= start) {
      ret_row.push_back(row_id - start);
      ret_col.push_back(col_id);
      ret_data.push_back(coo_data ? coo_data[i] : i);
    }
  }
  return COOMatrix(
      end - start, coo.num_cols, NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),
      coo.row_sorted, coo.col_sorted);
}

template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;

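  // IdHashMap maps each id in `rows` to its position in that array; Map()
  // returns the supplied default (-1) for ids that are not present.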
  IdHashMap<IdType> hashmap(rows);

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType row_id = coo_row_data[i];
    const IdType col_id = coo_col_data[i];
    const IdType mapped_row_id = hashmap.Map(row_id, -1);
    if (mapped_row_id != -1) {
      ret_row.push_back(mapped_row_id);
      ret_col.push_back(col_id);
      ret_data.push_back(coo_data ? coo_data[i] : i);
    }
  }

  return COOMatrix{
      rows->shape[0],
      coo.num_cols,
      NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col),
      NDArray::FromVector(ret_data),
      coo.row_sorted,
      coo.col_sorted};
}

template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, NDArray);

///////////////////////////// COOSliceMatrix /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  IdHashMap<IdType> row_map(rows), col_map(cols);

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType row_id = coo_row_data[i];
    const IdType col_id = coo_col_data[i];
    const IdType mapped_row_id = row_map.Map(row_id, -1);
    if (mapped_row_id != -1) {
      const IdType mapped_col_id = col_map.Map(col_id, -1);
      if (mapped_col_id != -1) {
        ret_row.push_back(mapped_row_id);
        ret_col.push_back(mapped_col_id);
        ret_data.push_back(coo_data ? coo_data[i] : i);
      }
    }
  }

  return COOMatrix(
      rows->shape[0], cols->shape[0], NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),
      coo.row_sorted, coo.col_sorted);
}

template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

///////////////////////////// COOReorder /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_id_arr,
    runtime::NDArray new_col_id_arr) {
  CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
  CHECK_SAME_DTYPE(coo.col, new_col_id_arr);

  // Input COO
  const IdType *in_rows = static_cast<IdType *>(coo.row->data);
  const IdType *in_cols = static_cast<IdType *>(coo.col->data);
  int64_t num_rows = coo.num_rows;
  int64_t num_cols = coo.num_cols;
  int64_t nnz = coo.row->shape[0];
  CHECK_EQ(num_rows, new_row_id_arr->shape[0])
      << "The new row Id array needs to be the same as the number of rows of "
         "COO";
  CHECK_EQ(num_cols, new_col_id_arr->shape[0])
      << "The new col Id array needs to be the same as the number of cols of "
         "COO";

  // New row/col Ids.
  const IdType *new_row_ids = static_cast<IdType *>(new_row_id_arr->data);
  const IdType *new_col_ids = static_cast<IdType *>(new_col_id_arr->data);

  // Output COO
  NDArray out_row_arr = NDArray::Empty({nnz}, coo.row->dtype, coo.row->ctx);
  NDArray out_col_arr = NDArray::Empty({nnz}, coo.col->dtype, coo.col->ctx);
  NDArray out_data_arr = COOHasData(coo) ? coo.data : NullArray();
  IdType *out_row = static_cast<IdType *>(out_row_arr->data);
  IdType *out_col = static_cast<IdType *>(out_col_arr->data);

  parallel_for(0, nnz, [=](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      out_row[i] = new_row_ids[in_rows[i]];
      out_col[i] = new_col_ids[in_cols[i]];
    }
  });
  return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
}

template COOMatrix COOReorder<kDGLCPU, int64_t>(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template COOMatrix COOReorder<kDGLCPU, int32_t>(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);

}  // namespace impl
}  // namespace aten
}  // namespace dgl