"docs/source/api/vscode:/vscode.git/clone" did not exist on "a7e941c379f0f9ab472c844b6a9f1d05d687b4e1"
spmat_op_impl.cc 23.3 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cpu/spmat_op_impl.cc
 * \brief Sparse matrix operator CPU implementation
 */
#include <dgl/array.h>
#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>
#include "array_utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

///////////////////////////// CSRIsNonZero /////////////////////////////

/*!
 * \brief Check whether the entry (row, col) is stored in the CSR matrix.
 *
 * Aborts through CHECK when row is outside [0, csr.num_rows) or col is
 * outside [0, csr.num_cols).
 *
 * \param csr The CSR matrix to query.
 * \param row Row index.
 * \param col Column index.
 * \return true if col occurs among the column indices stored for the row.
 */
template <DLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  // The row's column indices occupy indices[indptr[row], indptr[row + 1]).
  const IdType* row_begin = indices_data + indptr_data[row];
  const IdType* row_end = indices_data + indptr_data[row + 1];
  if (csr.sorted) {
    // Column indices are flagged as sorted, so a binary search is enough.
    return std::binary_search(row_begin, row_end, col);
  }
  // No ordering flagged: fall back to a linear scan of the row.
  return std::find(row_begin, row_end, col) != row_end;
}

template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);

/*!
 * \brief Batched version of CSRIsNonZero over arrays of row and column ids.
 *
 * The two id arrays must have the same length, or one of them must have
 * length 1, in which case its single id is paired with every id of the other
 * array. Any other combination aborts through CHECK.
 *
 * \param csr The CSR matrix to query.
 * \param row Row ids, read as IdType.
 * \param col Column ids, read as IdType.
 * \return Array with the dtype and context of \a row holding 1 where the
 *         queried entry is stored and 0 otherwise, one element per pair.
 */
template <DLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const int64_t rowlen = row->shape[0];
  const int64_t collen = col->shape[0];
  // Same shape contract as CSRGetData; without it part of the output would
  // never be written.
  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  // Exactly the number of pairs the loop below visits (0 when a length-1
  // array is paired with an empty one), so every output element is written.
  const int64_t rstlen = (rowlen == 1) ? collen : rowlen;
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const IdType* row_data = static_cast<IdType*>(row->data);
  const IdType* col_data = static_cast<IdType*>(col->data);
  for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
    *(rst_data++) = CSRIsNonZero<XPU, IdType>(csr, row_data[i], col_data[j])? 1 : 0;
  }
  return rst;
}

template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray);

///////////////////////////// CSRHasDuplicate /////////////////////////////

/*!
 * \brief Test whether any row stores the same column index more than once.
 * \param csr The CSR matrix to inspect.
 * \return true as soon as a repeated column index is found within one row.
 */
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
  const IdType* row_offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType* col_ids = static_cast<IdType*>(csr.indices->data);
  for (IdType r = 0; r < csr.num_rows; ++r) {
    // Column ids met so far in the current row only.
    std::unordered_set<IdType> seen;
    for (IdType pos = row_offsets[r]; pos < row_offsets[r + 1]; ++pos) {
      // insert() reports through .second whether the id was newly added.
      const bool is_new = seen.insert(col_ids[pos]).second;
      if (!is_new) {
        return true;
      }
    }
  }
  return false;
}

template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr);

///////////////////////////// CSRGetRowNNZ /////////////////////////////

/*!
 * \brief Number of entries stored in one row.
 * \param csr The CSR matrix.
 * \param row Row index; aborts through CHECK if outside [0, csr.num_rows).
 * \return indptr[row + 1] - indptr[row].
 */
template <DLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType row_begin = offsets[row];
  const IdType row_end = offsets[row + 1];
  return row_end - row_begin;
}

template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);

/*!
 * \brief Number of entries stored in each of the given rows.
 *
 * Every row id is validated against [0, csr.num_rows) before it is used to
 * index indptr, matching the scalar overload; an invalid id aborts through
 * CHECK instead of reading out of bounds.
 *
 * \param csr The CSR matrix.
 * \param rows Row ids, read as IdType.
 * \return Array with the dtype and context of \a rows; element i is the
 *         entry count of row rows[i].
 */
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
  const auto len = rows->shape[0];
  const IdType* vid_data = static_cast<IdType*>(rows->data);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  for (int64_t i = 0; i < len; ++i) {
    const IdType vid = vid_data[i];
    CHECK(vid >= 0 && vid < csr.num_rows) << "Invalid row index: " << vid;
    rst_data[i] = indptr_data[vid + 1] - indptr_data[vid];
  }
  return rst;
}

template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetRowColumnIndices /////////////////////////////

/*!
 * \brief Column indices stored for one row, returned as a view of csr.indices.
 * \param csr The CSR matrix.
 * \param row Row index; aborts through CHECK if outside [0, csr.num_rows).
 * \return View of csr.indices covering the row's entries.
 */
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  const int64_t num_entries = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  // The start of the view is scaled by the element size, i.e. given in bytes.
  const int64_t byte_offset = offsets[row] * sizeof(IdType);
  return csr.indices.CreateView({num_entries}, csr.indices->dtype, byte_offset);
}

template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRGetRowData /////////////////////////////

132
template <DLDeviceType XPU, typename IdType>
133
134
135
136
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
137
138
139
140
141
  const int64_t offset = indptr_data[row] * sizeof(IdType);
  if (CSRHasData(csr))
    return csr.data.CreateView({len}, csr.data->dtype, offset);
  else
    return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
142
143
}

144
145
template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t);
146
147
148

///////////////////////////// CSRGetData /////////////////////////////

149
150
template <DLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
Da Zheng's avatar
Da Zheng committed
151
                           const IdType start, const IdType end, const IdType col,
152
                           std::vector<IdType> *ret_vec) {
Da Zheng's avatar
Da Zheng committed
153
154
155
156
157
158
159
160
161
  const IdType *start_ptr = indices_data + start;
  const IdType *end_ptr = indices_data + end;
  auto it = std::lower_bound(start_ptr, end_ptr, col);
  // This might be a multi-graph. We need to collect all of the matched
  // columns.
  for (; it != end_ptr; it++) {
    // If the col exist
    if (*it == col) {
      IdType idx = it - indices_data;
162
      ret_vec->push_back(data? data[idx] : idx);
Da Zheng's avatar
Da Zheng committed
163
164
165
166
167
168
169
    } else {
      // If we find a column that is different, we can stop searching now.
      break;
    }
  }
}

170
template <DLDeviceType XPU, typename IdType>
171
172
173
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
174
  std::vector<IdType> ret_vec;
175
176
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
177
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
Da Zheng's avatar
Da Zheng committed
178
  if (csr.sorted) {
179
180
181
    CollectDataFromSorted<XPU, IdType>(indices_data, data,
                                       indptr_data[row], indptr_data[row + 1],
                                       col, &ret_vec);
Da Zheng's avatar
Da Zheng committed
182
183
184
  } else {
    for (IdType i = indptr_data[row]; i < indptr_data[row+1]; ++i) {
      if (indices_data[i] == col) {
185
        ret_vec.push_back(data? data[i] : i);
Da Zheng's avatar
Da Zheng committed
186
      }
187
188
    }
  }
189
  return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
190
191
}

192
193
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
194

195
template <DLDeviceType XPU, typename IdType>
196
197
198
199
200
201
202
203
204
205
206
207
208
209
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
210
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
211

212
  std::vector<IdType> ret_vec;
213
214
215
216
217

  for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
    const IdType row_id = row_data[i], col_id = col_data[j];
    CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
    CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
Da Zheng's avatar
Da Zheng committed
218
    if (csr.sorted) {
219
220
221
      CollectDataFromSorted<XPU, IdType>(indices_data, data,
                                         indptr_data[row_id], indptr_data[row_id + 1],
                                         col_id, &ret_vec);
Da Zheng's avatar
Da Zheng committed
222
223
224
    } else {
      for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
        if (indices_data[i] == col_id) {
225
          ret_vec.push_back(data? data[i] : i);
Da Zheng's avatar
Da Zheng committed
226
        }
227
228
229
230
      }
    }
  }

231
  return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
232
233
}

234
235
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix csr, NDArray rows, NDArray cols);
236
237
238

///////////////////////////// CSRGetDataAndIndices /////////////////////////////

239
240
template <DLDeviceType XPU, typename IdType>
void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data,
Da Zheng's avatar
Da Zheng committed
241
242
                                  const IdType start, const IdType end, const IdType col,
                                  std::vector<IdType> *col_vec,
243
                                  std::vector<IdType> *ret_vec) {
Da Zheng's avatar
Da Zheng committed
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
  const IdType *start_ptr = indices_data + start;
  const IdType *end_ptr = indices_data + end;
  auto it = std::lower_bound(start_ptr, end_ptr, col);
  // This might be a multi-graph. We need to collect all of the matched
  // columns.
  for (; it != end_ptr; it++) {
    // If the col exist
    if (*it == col) {
      IdType idx = it - indices_data;
      col_vec->push_back(indices_data[idx]);
      ret_vec->push_back(data[idx]);
    } else {
      // If we find a column that is different, we can stop searching now.
      break;
    }
  }
}

262
template <DLDeviceType XPU, typename IdType>
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
  // TODO(minjie): more efficient implementation for matrix without duplicate entries
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
278
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
279
280

  std::vector<IdType> ret_rows, ret_cols;
281
  std::vector<IdType> ret_data;
282
283
284
285
286

  for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
    const IdType row_id = row_data[i], col_id = col_data[j];
    CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
    CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
Da Zheng's avatar
Da Zheng committed
287
288
    if (csr.sorted) {
      // Here we collect col indices and data.
289
290
291
292
293
      CollectDataIndicesFromSorted<XPU, IdType>(indices_data, data,
                                                indptr_data[row_id],
                                                indptr_data[row_id + 1],
                                                col_id, &ret_cols,
                                                &ret_data);
Da Zheng's avatar
Da Zheng committed
294
295
296
297
298
299
300
      // We need to add row Ids.
      while (ret_rows.size() < ret_data.size()) {
        ret_rows.push_back(row_id);
      }
    } else {
      for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
        if (indices_data[i] == col_id) {
301
302
          ret_rows.push_back(row_id);
          ret_cols.push_back(col_id);
303
          ret_data.push_back(data? data[i] : i);
Da Zheng's avatar
Da Zheng committed
304
        }
305
306
307
308
      }
    }
  }

309
310
311
  return {NDArray::FromVector(ret_rows, csr.indptr->dtype, csr.indptr->ctx),
          NDArray::FromVector(ret_cols, csr.indptr->dtype, csr.indptr->ctx),
          NDArray::FromVector(ret_data, csr.data->dtype, csr.data->ctx)};
312
313
}

314
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
315
    CSRMatrix csr, NDArray rows, NDArray cols);
316
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t>(
317
318
319
320
321
322
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRTranspose /////////////////////////////

// for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1)
323
template <DLDeviceType XPU, typename IdType>
324
325
326
327
328
329
CSRMatrix CSRTranspose(CSRMatrix csr) {
  const int64_t N = csr.num_rows;
  const int64_t M = csr.num_cols;
  const int64_t nnz = csr.indices->shape[0];
  const IdType* Ap = static_cast<IdType*>(csr.indptr->data);
  const IdType* Aj = static_cast<IdType*>(csr.indices->data);
330
  const IdType* Ax = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
331
332
  NDArray ret_indptr = NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx);
  NDArray ret_indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
333
  NDArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
334
335
  IdType* Bp = static_cast<IdType*>(ret_indptr->data);
  IdType* Bi = static_cast<IdType*>(ret_indices->data);
336
  IdType* Bx = static_cast<IdType*>(ret_data->data);
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355

  std::fill(Bp, Bp + M, 0);

  for (int64_t j = 0; j < nnz; ++j) {
    Bp[Aj[j]]++;
  }

  // cumsum
  for (int64_t i = 0, cumsum = 0; i < M; ++i) {
    const IdType temp = Bp[i];
    Bp[i] = cumsum;
    cumsum += temp;
  }
  Bp[M] = nnz;

  for (int64_t i = 0; i < N; ++i) {
    for (IdType j = Ap[i]; j < Ap[i+1]; ++j) {
      const IdType dst = Aj[j];
      Bi[Bp[dst]] = i;
356
      Bx[Bp[dst]] = Ax? Ax[j] : j;
357
358
359
360
361
362
363
364
365
366
367
368
369
370
      Bp[dst]++;
    }
  }

  // correct the indptr
  for (int64_t i = 0, last = 0; i <= M; ++i) {
    IdType temp = Bp[i];
    Bp[i] = last;
    last = temp;
  }

  return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
}

371
372
template CSRMatrix CSRTranspose<kDLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLCPU, int64_t>(CSRMatrix csr);
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400

///////////////////////////// CSRToCOO /////////////////////////////
/*!
 * \brief Convert a CSR matrix to COO, keeping the entry order of the CSR.
 *
 * Only the row array is newly built; csr.indices and csr.data are passed to
 * the result unchanged.
 *
 * \param csr The CSR matrix to convert.
 * \return COO matrix with the same shape as \a csr.
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  const int64_t num_entries = csr.indices->shape[0];
  NDArray row_ids = NDArray::Empty({num_entries}, csr.indices->dtype, csr.indices->ctx);
  IdType* row_ids_data = static_cast<IdType*>(row_ids->data);
  // Each consecutive pair of indptr values delimits the entries of one row;
  // stamp that whole segment with the row id.
  for (IdType r = 0; r < csr.indptr->shape[0] - 1; ++r) {
    IdType* seg_begin = row_ids_data + offsets[r];
    IdType* seg_end = row_ids_data + offsets[r + 1];
    std::fill(seg_begin, seg_end, r);
  }
  return COOMatrix{csr.num_rows, csr.num_cols, row_ids, csr.indices, csr.data};
}

template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);

/*!
 * \brief Convert a CSR matrix to COO, placing each entry at the position
 *        given by its data value.
 *
 * Entry j of the CSR is written to slot data[j] of the result, or to slot j
 * when the matrix has no data array. Time complexity is O(NNZ); no storage is
 * allocated besides the two returned arrays.
 *
 * \param csr The CSR matrix to convert.
 * \return COO matrix with the same shape as \a csr, built from the row and
 *         column arrays only.
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
  const int64_t N = csr.num_rows;
  const int64_t M = csr.num_cols;
  const int64_t nnz = csr.indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  // data array should have the same type as the indices arrays
  const IdType* data = CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
  IdType* ret_col_data = static_cast<IdType*>(ret_col->data);
  // scatter using the indices in the data array
  for (IdType row = 0; row < N; ++row) {
    for (IdType j = indptr_data[row]; j < indptr_data[row + 1]; ++j) {
      // Destination slot of this entry in the output arrays.
      const IdType slot = data ? data[j] : j;
      ret_row_data[slot] = row;
      ret_col_data[slot] = indices_data[j];
    }
  }
  return COOMatrix(N, M, ret_row, ret_col);
}

template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr);

///////////////////////////// CSRSliceRows /////////////////////////////

422
template <DLDeviceType XPU, typename IdType>
423
424
425
426
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const int64_t num_rows = end - start;
  const int64_t nnz = indptr[end] - indptr[start];
427
428
  IdArray ret_indptr = IdArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indices->ctx);
  IdType* r_indptr = static_cast<IdType*>(ret_indptr->data);
429
430
431
432
  for (int64_t i = start; i < end + 1; ++i) {
    r_indptr[i - start] = indptr[i] - indptr[start];
  }
  // indices and data can be view arrays
433
434
435
436
437
438
439
440
441
442
443
  IdArray ret_indices = csr.indices.CreateView(
      {nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType));
  IdArray ret_data;
  if (CSRHasData(csr))
    ret_data = csr.data.CreateView({nnz}, csr.data->dtype, indptr[start] * sizeof(IdType));
  else
    ret_data = aten::Range(indptr[start], indptr[end],
                           csr.indptr->dtype.bits, csr.indptr->ctx);
  return CSRMatrix(num_rows, csr.num_cols,
                   ret_indptr, ret_indices, ret_data,
                   csr.sorted);
444
445
}

446
447
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
448

449
template <DLDeviceType XPU, typename IdType>
450
451
452
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
453
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
454
455
456
457
458
459
460
461
462
463
464
465
466
  const auto len = rows->shape[0];
  const IdType* rows_data = static_cast<IdType*>(rows->data);
  int64_t nnz = 0;
  for (int64_t i = 0; i < len; ++i) {
    IdType vid = rows_data[i];
    nnz += impl::CSRGetRowNNZ<XPU, IdType>(csr, vid);
  }

  CSRMatrix ret;
  ret.num_rows = len;
  ret.num_cols = csr.num_cols;
  ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indices->ctx);
  ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
467
468
  ret.data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
  ret.sorted = csr.sorted;
469
470
471

  IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);
  IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data);
472
  IdType* ret_data = static_cast<IdType*>(ret.data->data);
473
474
475
476
477
478
479
  ret_indptr_data[0] = 0;
  for (int64_t i = 0; i < len; ++i) {
    const IdType rid = rows_data[i];
    // note: zero is allowed
    ret_indptr_data[i + 1] = ret_indptr_data[i] + indptr_data[rid + 1] - indptr_data[rid];
    std::copy(indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],
              ret_indices_data + ret_indptr_data[i]);
480
481
482
483
484
485
    if (data)
      std::copy(data + indptr_data[rid], data + indptr_data[rid + 1],
                ret_data + ret_indptr_data[i]);
    else
      std::iota(ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1],
                indptr_data[rid]);
486
487
488
489
  }
  return ret;
}

490
491
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
492
493
494

///////////////////////////// CSRSliceMatrix /////////////////////////////

495
template <DLDeviceType XPU, typename IdType>
496
497
498
499
500
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
  IdHashMap<IdType> hashmap(cols);
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];
  const IdType* rows_data = static_cast<IdType*>(rows->data);
501
  const bool has_data = CSRHasData(csr);
502
503
504

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
505
  const IdType* data = has_data? static_cast<IdType*>(csr.data->data) : nullptr;
506
507

  std::vector<IdType> sub_indptr, sub_indices;
508
  std::vector<IdType> sub_data;
509
510
511
512
513
514
515
516
517
518
519
520
  sub_indptr.resize(new_nrows + 1, 0);
  const IdType kInvalidId = new_ncols + 1;
  for (int64_t i = 0; i < new_nrows; ++i) {
    // NOTE: newi == i
    const IdType oldi = rows_data[i];
    CHECK(oldi >= 0 && oldi < csr.num_rows) << "Invalid row index: " << oldi;
    for (IdType p = indptr_data[oldi]; p < indptr_data[oldi + 1]; ++p) {
      const IdType oldj = indices_data[p];
      const IdType newj = hashmap.Map(oldj, kInvalidId);
      if (newj != kInvalidId) {
        ++sub_indptr[i];
        sub_indices.push_back(newj);
521
        sub_data.push_back(has_data? data[p] : p);
522
523
524
525
526
527
528
529
530
531
532
533
534
      }
    }
  }

  // cumsum sub_indptr
  for (int64_t i = 0, cumsum = 0; i < new_nrows; ++i) {
    const IdType temp = sub_indptr[i];
    sub_indptr[i] = cumsum;
    cumsum += temp;
  }
  sub_indptr[new_nrows] = sub_indices.size();

  const int64_t nnz = sub_data.size();
535
536
  NDArray sub_data_arr = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
  IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
537
538
  std::copy(sub_data.begin(), sub_data.end(), ptr);
  return CSRMatrix{new_nrows, new_ncols,
539
540
    NDArray::FromVector(sub_indptr, csr.indptr->dtype, csr.indptr->ctx),
    NDArray::FromVector(sub_indices, csr.indptr->dtype, csr.indptr->ctx),
541
542
543
    sub_data_arr};
}

544
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t>(
545
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
546
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t>(
547
548
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

549
550
551
552
553
554
555
556
557
558
559
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  typedef std::pair<IdType, IdType> ShufflePair;
  const int64_t num_rows = csr->num_rows;
  const int64_t nnz = csr->indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr->indptr->data);
  IdType* indices_data = static_cast<IdType*>(csr->indices->data);
  if (!CSRHasData(*csr)) {
    csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);
  }
  IdType* eid_data = static_cast<IdType*>(csr->data->data);
Da Zheng's avatar
Da Zheng committed
560
561
#pragma omp parallel
  {
562
    std::vector<ShufflePair> reorder_vec;
Da Zheng's avatar
Da Zheng committed
563
564
#pragma omp for
    for (int64_t row = 0; row < num_rows; row++) {
565
      const int64_t num_cols = indptr_data[row + 1] - indptr_data[row];
Da Zheng's avatar
Da Zheng committed
566
      IdType *col = indices_data + indptr_data[row];
567
      IdType *eid = eid_data + indptr_data[row];
Da Zheng's avatar
Da Zheng committed
568
569
570
571
572
573
574

      reorder_vec.resize(num_cols);
      for (int64_t i = 0; i < num_cols; i++) {
        reorder_vec[i].first = col[i];
        reorder_vec[i].second = eid[i];
      }
      std::sort(reorder_vec.begin(), reorder_vec.end(),
575
                [](const ShufflePair &e1, const ShufflePair &e2) {
Da Zheng's avatar
Da Zheng committed
576
577
578
579
580
581
582
583
                  return e1.first < e2.first;
                });
      for (int64_t i = 0; i < num_cols; i++) {
        col[i] = reorder_vec[i].first;
        eid[i] = reorder_vec[i].second;
      }
    }
  }
584
  csr->sorted = true;
Da Zheng's avatar
Da Zheng committed
585
586
}

587
588
template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr);
Da Zheng's avatar
Da Zheng committed
589

590
591
592
}  // namespace impl
}  // namespace aten
}  // namespace dgl