/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cpu/spmat_op_impl_csr.cc
 * @brief CSR matrix operator CPU implementation
 */
#include <dgl/array.h>
7
#include <dgl/runtime/parallel_for.h>
8

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
9
#include <atomic>
10
11
12
13
#include <numeric>
#include <unordered_set>
#include <vector>

14
#include "array_utils.h"
15
16
17
18

namespace dgl {

using runtime::NDArray;
19
using runtime::parallel_for;
20
21
22
23
24
25

namespace aten {
namespace impl {

///////////////////////////// CSRIsNonZero /////////////////////////////

26
template <DGLDeviceType XPU, typename IdType>
27
28
29
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
Da Zheng's avatar
Da Zheng committed
30
  if (csr.sorted) {
31
32
    const IdType* start = indices_data + indptr_data[row];
    const IdType* end = indices_data + indptr_data[row + 1];
Da Zheng's avatar
Da Zheng committed
33
34
35
36
37
38
    return std::binary_search(start, end, col);
  } else {
    for (IdType i = indptr_data[row]; i < indptr_data[row + 1]; ++i) {
      if (indices_data[i] == col) {
        return true;
      }
39
40
41
42
43
    }
  }
  return false;
}

44
45
template bool CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
46

47
template <DGLDeviceType XPU, typename IdType>
48
49
50
51
52
53
54
55
56
57
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const IdType* row_data = static_cast<IdType*>(row->data);
  const IdType* col_data = static_cast<IdType*>(col->data);
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
58
59
60
61
62
63
64
65
66
  runtime::parallel_for(
      0, std::max(rowlen, collen), 1, [=](int64_t b, int64_t e) {
        int64_t i = (row_stride == 0) ? 0 : b;
        int64_t j = (col_stride == 0) ? 0 : b;
        for (int64_t k = b; i < e && j < e;
             i += row_stride, j += col_stride, ++k)
          rst_data[k] =
              CSRIsNonZero<XPU, IdType>(csr, row_data[i], col_data[j]) ? 1 : 0;
      });
67
68
69
  return rst;
}

70
71
template NDArray CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
72
73
74

///////////////////////////// CSRHasDuplicate /////////////////////////////

75
template <DGLDeviceType XPU, typename IdType>
76
77
78
79
80
bool CSRHasDuplicate(CSRMatrix csr) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  for (IdType src = 0; src < csr.num_rows; ++src) {
    std::unordered_set<IdType> hashmap;
81
    for (IdType eid = indptr_data[src]; eid < indptr_data[src + 1]; ++eid) {
82
83
84
85
86
87
88
89
90
91
92
      const IdType dst = indices_data[eid];
      if (hashmap.count(dst)) {
        return true;
      } else {
        hashmap.insert(dst);
      }
    }
  }
  return false;
}

93
94
template bool CSRHasDuplicate<kDGLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCPU, int64_t>(CSRMatrix csr);
95
96
97

///////////////////////////// CSRGetRowNNZ /////////////////////////////

98
template <DGLDeviceType XPU, typename IdType>
99
100
101
102
103
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  return indptr_data[row + 1] - indptr_data[row];
}

104
105
template int64_t CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, int64_t);
106

107
template <DGLDeviceType XPU, typename IdType>
108
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
109
  CHECK_SAME_DTYPE(csr.indices, rows);
110
111
112
113
114
115
116
117
118
119
120
121
  const auto len = rows->shape[0];
  const IdType* vid_data = static_cast<IdType*>(rows->data);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  for (int64_t i = 0; i < len; ++i) {
    const auto vid = vid_data[i];
    rst_data[i] = indptr_data[vid + 1] - indptr_data[vid];
  }
  return rst;
}

122
123
template NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);
124

125
/////////////////////////// CSRGetRowColumnIndices /////////////////////////////
126

127
template <DGLDeviceType XPU, typename IdType>
128
129
130
131
132
133
134
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const int64_t offset = indptr_data[row] * sizeof(IdType);
  return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}

135
136
template NDArray CSRGetRowColumnIndices<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCPU, int64_t>(CSRMatrix, int64_t);
137
138
139

///////////////////////////// CSRGetRowData /////////////////////////////

140
template <DGLDeviceType XPU, typename IdType>
141
142
143
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
144
145
146
147
  const int64_t offset = indptr_data[row] * sizeof(IdType);
  if (CSRHasData(csr))
    return csr.data.CreateView({len}, csr.data->dtype, offset);
  else
148
149
    return aten::Range(
        offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
150
151
}

152
153
template NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);
154
155
156
157

///////////////////////////// CSRGetData /////////////////////////////
///////////////////////////// CSRGetDataAndIndices /////////////////////////////

158
template <DGLDeviceType XPU, typename IdType>
159
160
161
162
163
164
void CollectDataIndicesFromSorted(
    const IdType* indices_data, const IdType* data, const IdType start,
    const IdType end, const IdType col, std::vector<IdType>* col_vec,
    std::vector<IdType>* ret_vec) {
  const IdType* start_ptr = indices_data + start;
  const IdType* end_ptr = indices_data + end;
Da Zheng's avatar
Da Zheng committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
  auto it = std::lower_bound(start_ptr, end_ptr, col);
  // This might be a multi-graph. We need to collect all of the matched
  // columns.
  for (; it != end_ptr; it++) {
    // If the col exist
    if (*it == col) {
      IdType idx = it - indices_data;
      col_vec->push_back(indices_data[idx]);
      ret_vec->push_back(data[idx]);
    } else {
      // If we find a column that is different, we can stop searching now.
      break;
    }
  }
}

181
template <DGLDeviceType XPU, typename IdType>
182
183
184
185
std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray rows, NDArray cols) {
  // TODO(minjie): more efficient implementation for matrix without duplicate
  // entries
186
187
188
189
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
190
      << "Invalid row and col id array.";
191
192
193
194
195
196
197
198

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
199
200
  const IdType* data =
      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
201
202

  std::vector<IdType> ret_rows, ret_cols;
203
  std::vector<IdType> ret_data;
204

205
206
  for (int64_t i = 0, j = 0; i < rowlen && j < collen;
       i += row_stride, j += col_stride) {
207
    const IdType row_id = row_data[i], col_id = col_data[j];
208
209
210
211
    CHECK(row_id >= 0 && row_id < csr.num_rows)
        << "Invalid row index: " << row_id;
    CHECK(col_id >= 0 && col_id < csr.num_cols)
        << "Invalid col index: " << col_id;
Da Zheng's avatar
Da Zheng committed
212
213
    if (csr.sorted) {
      // Here we collect col indices and data.
214
215
216
      CollectDataIndicesFromSorted<XPU, IdType>(
          indices_data, data, indptr_data[row_id], indptr_data[row_id + 1],
          col_id, &ret_cols, &ret_data);
Da Zheng's avatar
Da Zheng committed
217
218
219
220
221
      // We need to add row Ids.
      while (ret_rows.size() < ret_data.size()) {
        ret_rows.push_back(row_id);
      }
    } else {
222
      for (IdType i = indptr_data[row_id]; i < indptr_data[row_id + 1]; ++i) {
Da Zheng's avatar
Da Zheng committed
223
        if (indices_data[i] == col_id) {
224
225
          ret_rows.push_back(row_id);
          ret_cols.push_back(col_id);
226
          ret_data.push_back(data ? data[i] : i);
Da Zheng's avatar
Da Zheng committed
227
        }
228
229
230
231
      }
    }
  }

232
233
234
235
  return {
      NDArray::FromVector(ret_rows, csr.indptr->ctx),
      NDArray::FromVector(ret_cols, csr.indptr->ctx),
      NDArray::FromVector(ret_data, csr.data->ctx)};
236
237
}

238
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int32_t>(
239
    CSRMatrix csr, NDArray rows, NDArray cols);
240
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int64_t>(
241
242
243
244
245
246
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRTranspose /////////////////////////////

// for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1)
247
template <DGLDeviceType XPU, typename IdType>
248
249
250
251
252
253
CSRMatrix CSRTranspose(CSRMatrix csr) {
  const int64_t N = csr.num_rows;
  const int64_t M = csr.num_cols;
  const int64_t nnz = csr.indices->shape[0];
  const IdType* Ap = static_cast<IdType*>(csr.indptr->data);
  const IdType* Aj = static_cast<IdType*>(csr.indices->data);
254
255
256
257
258
259
  const IdType* Ax =
      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
  NDArray ret_indptr =
      NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx);
  NDArray ret_indices =
      NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
260
  NDArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
261
262
  IdType* Bp = static_cast<IdType*>(ret_indptr->data);
  IdType* Bi = static_cast<IdType*>(ret_indices->data);
263
  IdType* Bx = static_cast<IdType*>(ret_data->data);
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279

  std::fill(Bp, Bp + M, 0);

  for (int64_t j = 0; j < nnz; ++j) {
    Bp[Aj[j]]++;
  }

  // cumsum
  for (int64_t i = 0, cumsum = 0; i < M; ++i) {
    const IdType temp = Bp[i];
    Bp[i] = cumsum;
    cumsum += temp;
  }
  Bp[M] = nnz;

  for (int64_t i = 0; i < N; ++i) {
280
    for (IdType j = Ap[i]; j < Ap[i + 1]; ++j) {
281
282
      const IdType dst = Aj[j];
      Bi[Bp[dst]] = i;
283
      Bx[Bp[dst]] = Ax ? Ax[j] : j;
284
285
286
287
288
289
290
291
292
293
294
      Bp[dst]++;
    }
  }

  // correct the indptr
  for (int64_t i = 0, last = 0; i <= M; ++i) {
    IdType temp = Bp[i];
    Bp[i] = last;
    last = temp;
  }

295
296
  return CSRMatrix{
      csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
297
298
}

299
300
template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCPU, int64_t>(CSRMatrix csr);
301
302

///////////////////////////// CSRToCOO /////////////////////////////
303
template <DGLDeviceType XPU, typename IdType>
304
305
306
307
308
COOMatrix CSRToCOO(CSRMatrix csr) {
  const int64_t nnz = csr.indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
309
310
  parallel_for(0, csr.indptr->shape[0] - 1, 10000, [=](int64_t b, int64_t e) {
    for (auto i = b; i < e; ++i) {
311
312
      std::fill(
          ret_row_data + indptr_data[i], ret_row_data + indptr_data[i + 1], i);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
313
314
    }
  });
315
316
317
  return COOMatrix(
      csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,
      csr.sorted);
318
319
}

320
321
template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCPU, int64_t>(CSRMatrix csr);
322
323

// complexity: time O(NNZ), space O(1)
324
template <DGLDeviceType XPU, typename IdType>
325
326
327
328
329
330
331
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
  const int64_t N = csr.num_rows;
  const int64_t M = csr.num_cols;
  const int64_t nnz = csr.indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  // data array should have the same type as the indices arrays
332
333
  const IdType* data =
      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
334
335
336
337
338
  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
  IdType* ret_col_data = static_cast<IdType*>(ret_col->data);
  // scatter using the indices in the data array
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
339
340
341
342
343
344
345
  parallel_for(0, N, 10000, [=](int64_t b, int64_t e) {
    for (auto row = b; row < e; ++row) {
      for (IdType j = indptr_data[row]; j < indptr_data[row + 1]; ++j) {
        const IdType col = indices_data[j];
        ret_row_data[data ? data[j] : j] = row;
        ret_col_data[data ? data[j] : j] = col;
      }
346
    }
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
347
  });
348
  return COOMatrix(N, M, ret_row, ret_col);
349
350
}

351
352
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int64_t>(CSRMatrix csr);
353
354
355

///////////////////////////// CSRSliceRows /////////////////////////////

356
template <DGLDeviceType XPU, typename IdType>
357
358
359
360
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const int64_t num_rows = end - start;
  const int64_t nnz = indptr[end] - indptr[start];
361
362
  IdArray ret_indptr =
      IdArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indices->ctx);
363
  IdType* r_indptr = static_cast<IdType*>(ret_indptr->data);
364
365
366
367
  for (int64_t i = start; i < end + 1; ++i) {
    r_indptr[i - start] = indptr[i] - indptr[start];
  }
  // indices and data can be view arrays
368
369
370
371
  IdArray ret_indices = csr.indices.CreateView(
      {nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType));
  IdArray ret_data;
  if (CSRHasData(csr))
372
373
    ret_data = csr.data.CreateView(
        {nnz}, csr.data->dtype, indptr[start] * sizeof(IdType));
374
  else
375
376
377
378
    ret_data = aten::Range(
        indptr[start], indptr[end], csr.indptr->dtype.bits, csr.indptr->ctx);
  return CSRMatrix(
      num_rows, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
379
380
}

381
382
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
383

384
template <DGLDeviceType XPU, typename IdType>
385
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
386
  CHECK_SAME_DTYPE(csr.indices, rows);
387
388
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
389
390
  const IdType* data =
      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
391
392
393
394
395
396
397
398
  const auto len = rows->shape[0];
  const IdType* rows_data = static_cast<IdType*>(rows->data);
  int64_t nnz = 0;

  CSRMatrix ret;
  ret.num_rows = len;
  ret.num_cols = csr.num_cols;
  ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indices->ctx);
399
400
401
402
403
404

  IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);
  ret_indptr_data[0] = 0;

  std::vector<IdType> sums;

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
405
406
407
408
  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  bool err = false;
  std::stringstream err_msg_stream;

409
410
// Perform two-round parallel prefix sum using OpenMP
#pragma omp parallel
411
412
413
414
  {
    int64_t tid = omp_get_thread_num();
    int64_t num_threads = omp_get_num_threads();

415
#pragma omp single
416
    {
417
418
      sums.resize(num_threads + 1);
      sums[0] = 0;
419
420
421
422
    }

    int64_t sum = 0;

423
424
// First round of parallel prefix sum. All threads perform local prefix sums.
#pragma omp for schedule(static) nowait
425
426
    for (int64_t i = 0; i < len; ++i) {
      int64_t rid = rows_data[i];
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
427
428
      if (rid >= csr.num_rows) {
        if (!err_flag.test_and_set()) {
429
430
          err_msg_stream << "expect row ID " << rid
                         << " to be less than number of rows " << csr.num_rows;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
431
432
433
434
435
436
          err = true;
        }
      } else {
        sum += indptr_data[rid + 1] - indptr_data[rid];
        ret_indptr_data[i + 1] = sum;
      }
437
438
    }
    sums[tid + 1] = sum;
439
#pragma omp barrier
440

441
#pragma omp single
442
    {
443
      for (int64_t i = 1; i < num_threads; ++i) sums[i] += sums[i - 1];
444
445
446
447
    }

    int64_t offset = sums[tid];

448
449
450
// Second round of parallel prefix sum. Update the local prefix sums.
#pragma omp for schedule(static)
    for (int64_t i = 0; i < len; ++i) ret_indptr_data[i + 1] += offset;
451
  }
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
452
453
454
455
  if (err) {
    LOG(FATAL) << err_msg_stream.str();
    return ret;
  }
456
457
458
459
460

  // After the prefix sum, the last element of ret_indptr_data holds the
  // sum of all elements
  nnz = ret_indptr_data[len];

461
  ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
462
463
  ret.data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
  ret.sorted = csr.sorted;
464
465

  IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data);
466
  IdType* ret_data = static_cast<IdType*>(ret.data->data);
467
468
469
470
471

  parallel_for(0, len, [=](int64_t b, int64_t e) {
    for (auto i = b; i < e; ++i) {
      const IdType rid = rows_data[i];
      // note: zero is allowed
472
473
474
      std::copy(
          indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],
          ret_indices_data + ret_indptr_data[i]);
475
      if (data)
476
477
478
        std::copy(
            data + indptr_data[rid], data + indptr_data[rid + 1],
            ret_data + ret_indptr_data[i]);
479
      else
480
481
482
        std::iota(
            ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1],
            indptr_data[rid]);
483
484
    }
  });
485
486
487
  return ret;
}

488
489
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, NDArray);
490
491
492

///////////////////////////// CSRSliceMatrix /////////////////////////////

493
template <DGLDeviceType XPU, typename IdType>
494
495
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
496
497
498
499
  IdHashMap<IdType> hashmap(cols);
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];
  const IdType* rows_data = static_cast<IdType*>(rows->data);
500
  const bool has_data = CSRHasData(csr);
501
502
503

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
504
505
  const IdType* data =
      has_data ? static_cast<IdType*>(csr.data->data) : nullptr;
506
507

  std::vector<IdType> sub_indptr, sub_indices;
508
  std::vector<IdType> sub_data;
509
510
511
512
513
514
515
516
517
518
519
520
  sub_indptr.resize(new_nrows + 1, 0);
  const IdType kInvalidId = new_ncols + 1;
  for (int64_t i = 0; i < new_nrows; ++i) {
    // NOTE: newi == i
    const IdType oldi = rows_data[i];
    CHECK(oldi >= 0 && oldi < csr.num_rows) << "Invalid row index: " << oldi;
    for (IdType p = indptr_data[oldi]; p < indptr_data[oldi + 1]; ++p) {
      const IdType oldj = indices_data[p];
      const IdType newj = hashmap.Map(oldj, kInvalidId);
      if (newj != kInvalidId) {
        ++sub_indptr[i];
        sub_indices.push_back(newj);
521
        sub_data.push_back(has_data ? data[p] : p);
522
523
524
525
526
527
528
529
530
531
532
533
534
      }
    }
  }

  // cumsum sub_indptr
  for (int64_t i = 0, cumsum = 0; i < new_nrows; ++i) {
    const IdType temp = sub_indptr[i];
    sub_indptr[i] = cumsum;
    cumsum += temp;
  }
  sub_indptr[new_nrows] = sub_indices.size();

  const int64_t nnz = sub_data.size();
535
536
  NDArray sub_data_arr =
      NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
537
  IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
538
  std::copy(sub_data.begin(), sub_data.end(), ptr);
539
540
541
  return CSRMatrix{
      new_nrows, new_ncols, NDArray::FromVector(sub_indptr, csr.indptr->ctx),
      NDArray::FromVector(sub_indices, csr.indptr->ctx), sub_data_arr};
542
543
}

544
template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(
545
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
546
template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(
547
548
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

Da Zheng's avatar
Da Zheng committed
549
550
///////////////////////////// CSRReorder /////////////////////////////

551
template <DGLDeviceType XPU, typename IdType>
552
553
554
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_id_arr,
    runtime::NDArray new_col_id_arr) {
Da Zheng's avatar
Da Zheng committed
555
556
557
558
559
560
561
562
563
564
565
566
  CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);
  CHECK_SAME_DTYPE(csr.indices, new_col_id_arr);

  // Input CSR
  const IdType* in_indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* in_indices = static_cast<IdType*>(csr.indices->data);
  const IdType* in_data = static_cast<IdType*>(csr.data->data);
  int64_t num_rows = csr.num_rows;
  int64_t num_cols = csr.num_cols;
  int64_t nnz = csr.indices->shape[0];
  CHECK_EQ(nnz, in_indptr[num_rows]);
  CHECK_EQ(num_rows, new_row_id_arr->shape[0])
567
568
      << "The new row Id array needs to be the same as the number of rows of "
         "CSR";
Da Zheng's avatar
Da Zheng committed
569
  CHECK_EQ(num_cols, new_col_id_arr->shape[0])
570
571
      << "The new col Id array needs to be the same as the number of cols of "
         "CSR";
Da Zheng's avatar
Da Zheng committed
572
573
574
575
576
577

  // New row/col Ids.
  const IdType* new_row_ids = static_cast<IdType*>(new_row_id_arr->data);
  const IdType* new_col_ids = static_cast<IdType*>(new_col_id_arr->data);

  // Output CSR
578
579
580
581
  NDArray out_indptr_arr =
      NDArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indptr->ctx);
  NDArray out_indices_arr =
      NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
Da Zheng's avatar
Da Zheng committed
582
  NDArray out_data_arr = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx);
583
584
585
  IdType* out_indptr = static_cast<IdType*>(out_indptr_arr->data);
  IdType* out_indices = static_cast<IdType*>(out_indices_arr->data);
  IdType* out_data = static_cast<IdType*>(out_data_arr->data);
Da Zheng's avatar
Da Zheng committed
586
587
588

  // Compute the length of rows for the new matrix.
  std::vector<IdType> new_row_lens(num_rows, -1);
589
590
591
592
593
594
  parallel_for(0, num_rows, [=, &new_row_lens](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      int64_t new_row_id = new_row_ids[i];
      new_row_lens[new_row_id] = in_indptr[i + 1] - in_indptr[i];
    }
  });
Da Zheng's avatar
Da Zheng committed
595
596
597
598
599
600
601
602
603
604
  // Compute the starting location of each row in the new matrix.
  out_indptr[0] = 0;
  // This is sequential. It should be pretty fast.
  for (int64_t i = 0; i < num_rows; i++) {
    CHECK_GE(new_row_lens[i], 0);
    out_indptr[i + 1] = out_indptr[i] + new_row_lens[i];
  }
  CHECK_EQ(out_indptr[num_rows], nnz);
  // Copy indieces and data with the new order.
  // Here I iterate rows in the order of the old matrix.
605
606
  parallel_for(0, num_rows, [=](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
607
608
      const IdType* in_row = in_indices + in_indptr[i];
      const IdType* in_row_data = in_data + in_indptr[i];
609
610

      int64_t new_row_id = new_row_ids[i];
611
612
      IdType* out_row = out_indices + out_indptr[new_row_id];
      IdType* out_row_data = out_data + out_indptr[new_row_id];
613
614
615
616
617
618
619
620

      int64_t row_len = new_row_lens[new_row_id];
      // Here I iterate col indices in a row in the order of the old matrix.
      for (int64_t j = 0; j < row_len; j++) {
        out_row[j] = new_col_ids[in_row[j]];
        out_row_data[j] = in_row_data[j];
      }
      // TODO(zhengda) maybe we should sort the column indices.
Da Zheng's avatar
Da Zheng committed
621
    }
622
  });
623
624
  return CSRMatrix(
      num_rows, num_cols, out_indptr_arr, out_indices_arr, out_data_arr);
Da Zheng's avatar
Da Zheng committed
625
626
}

627
628
629
630
template CSRMatrix CSRReorder<kDGLCPU, int64_t>(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template CSRMatrix CSRReorder<kDGLCPU, int32_t>(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
Da Zheng's avatar
Da Zheng committed
631

632
633
634
}  // namespace impl
}  // namespace aten
}  // namespace dgl