"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "6578cf27de8b1ec622ddb825dc2857423493b6ef"
/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cpu/spmat_op_impl.cc
 * \brief Sparse matrix operator CPU implementation
 */
#include <dgl/array.h>
#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>
#include "array_utils.h"
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

///////////////////////////// CSRIsNonZero /////////////////////////////

/*!
 * \brief Test whether the CSR matrix stores an entry at (row, col).
 *
 * Both indices are validated with CHECK. When the matrix is flagged as sorted
 * the column indices of the row are binary-searched; otherwise they are
 * scanned linearly.
 *
 * \param csr The CSR matrix to query.
 * \param row Row index; must lie in [0, csr.num_rows).
 * \param col Column index; must lie in [0, csr.num_cols).
 * \return true if at least one stored entry of the row has column \a col.
 */
template <DLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType* columns = static_cast<IdType*>(csr.indices->data);
  // Half-open range of column indices that belong to the requested row.
  const IdType* row_begin = columns + offsets[row];
  const IdType* row_end = columns + offsets[row + 1];
  if (csr.sorted)
    return std::binary_search(row_begin, row_end, col);
  for (const IdType* it = row_begin; it < row_end; ++it) {
    if (*it == col)
      return true;
  }
  return false;
}

template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);

/*!
 * \brief Batched version of CSRIsNonZero with scalar broadcasting.
 *
 * The row and column id arrays must either have the same length or one of
 * them must have length one, in which case that single id is paired with
 * every id of the other array.
 *
 * \param csr The CSR matrix to query.
 * \param row Array of row ids.
 * \param col Array of column ids.
 * \return An array with the dtype and context of \a row and length
 *         max(len(row), len(col)); element k is 1 if the k-th (row, col)
 *         pair is stored in the matrix and 0 otherwise.
 */
template <DLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  // Without this check, incompatible lengths would leave the tail of the
  // result (allocated with max(rowlen, collen) elements) uninitialized.
  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const IdType* row_data = static_cast<IdType*>(row->data);
  const IdType* col_data = static_cast<IdType*>(col->data);
  // A stride of zero keeps re-reading the single id of a length-one array.
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
    *(rst_data++) = CSRIsNonZero<XPU, IdType>(csr, row_data[i], col_data[j])? 1 : 0;
  }
  return rst;
}

template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray);

///////////////////////////// CSRHasDuplicate /////////////////////////////

/*!
 * \brief Check whether any row of the CSR matrix stores the same column
 *        index more than once.
 *
 * \param csr The CSR matrix to inspect.
 * \return true as soon as a repeated column index is found within a row,
 *         false if every row has distinct column indices.
 */
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  for (IdType src = 0; src < csr.num_rows; ++src) {
    // Column indices seen so far in the current row.
    std::unordered_set<IdType> seen;
    for (IdType eid = indptr_data[src]; eid < indptr_data[src + 1]; ++eid) {
      // insert() reports whether the key was newly added, so membership test
      // and insertion cost a single hash lookup.
      if (!seen.insert(indices_data[eid]).second) {
        return true;
      }
    }
  }
  return false;
}

template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr);

///////////////////////////// CSRGetRowNNZ /////////////////////////////

/*!
 * \brief Number of entries stored in one row of the CSR matrix.
 *
 * \param csr The CSR matrix to query.
 * \param row Row index; must lie in [0, csr.num_rows).
 * \return The difference between the row's end and start offsets in indptr.
 */
template <DLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType row_begin = offsets[row];
  const IdType row_end = offsets[row + 1];
  return row_end - row_begin;
}

template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);

/*!
 * \brief Number of stored entries for each of the given rows.
 *
 * \param csr The CSR matrix to query.
 * \param rows Array of row ids; every id must lie in [0, csr.num_rows).
 * \return An array with the dtype, context and length of \a rows whose k-th
 *         element is the number of entries stored in row rows[k].
 */
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
  const auto len = rows->shape[0];
  const IdType* vid_data = static_cast<IdType*>(rows->data);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  for (int64_t i = 0; i < len; ++i) {
    const auto vid = vid_data[i];
    // Reject ids that would index indptr out of bounds, as the scalar
    // overload does.
    CHECK(vid >= 0 && vid < csr.num_rows) << "Invalid row index: " << vid;
    rst_data[i] = indptr_data[vid + 1] - indptr_data[vid];
  }
  return rst;
}

template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetRowColumnIndices /////////////////////////////

/*!
 * \brief Column indices of the entries stored in one row.
 *
 * The result is created with CreateView on csr.indices rather than by copying
 * into a new array.
 *
 * \param csr The CSR matrix to query.
 * \param row Row index; must lie in [0, csr.num_rows).
 * \return An array over the slice of csr.indices that belongs to the row.
 */
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  // The view offset is scaled by the element size, i.e. expressed in bytes.
  const int64_t byte_offset = offsets[row] * sizeof(IdType);
  const int64_t num_entries = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  return csr.indices.CreateView({num_entries}, csr.indices->dtype, byte_offset);
}

template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRGetRowData /////////////////////////////

132
template <DLDeviceType XPU, typename IdType>
133
134
135
136
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
137
138
139
140
141
  const int64_t offset = indptr_data[row] * sizeof(IdType);
  if (CSRHasData(csr))
    return csr.data.CreateView({len}, csr.data->dtype, offset);
  else
    return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
142
143
}

144
145
template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t);
146
147
148

///////////////////////////// CSRGetData /////////////////////////////

149
150
/*!
 * \brief Append the data of every entry with column \a col found in the
 *        sorted index range [start, end) to \a ret_vec.
 *
 * \param indices_data Column index array of the CSR matrix.
 * \param data Data array of the CSR matrix, or nullptr; when null the
 *        position of the entry in \a indices_data is used as its data.
 * \param start Offset of the first entry of the row in \a indices_data.
 * \param end Offset one past the last entry of the row.
 * \param col Column index to look for.
 * \param ret_vec Output vector that receives the matching data values.
 */
template <DLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
                           const IdType start, const IdType end, const IdType col,
                           std::vector<IdType> *ret_vec) {
  const IdType *row_begin = indices_data + start;
  const IdType *row_end = indices_data + end;
  const IdType *cur = std::lower_bound(row_begin, row_end, col);
  // A multi-graph may store the same column several times; in a sorted row
  // the matches are contiguous, so stop at the first different column.
  while (cur != row_end && *cur == col) {
    const IdType pos = cur - indices_data;
    if (data)
      ret_vec->push_back(data[pos]);
    else
      ret_vec->push_back(pos);
    cur++;
  }
}

170
template <DLDeviceType XPU, typename IdType>
171
172
173
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
174
  std::vector<IdType> ret_vec;
175
176
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
177
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
Da Zheng's avatar
Da Zheng committed
178
  if (csr.sorted) {
179
180
181
    CollectDataFromSorted<XPU, IdType>(indices_data, data,
                                       indptr_data[row], indptr_data[row + 1],
                                       col, &ret_vec);
Da Zheng's avatar
Da Zheng committed
182
183
184
  } else {
    for (IdType i = indptr_data[row]; i < indptr_data[row+1]; ++i) {
      if (indices_data[i] == col) {
185
        ret_vec.push_back(data? data[i] : i);
Da Zheng's avatar
Da Zheng committed
186
      }
187
188
    }
  }
189
  return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
190
191
}

192
193
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
194

195
template <DLDeviceType XPU, typename IdType>
196
197
198
199
200
201
202
203
204
205
206
207
208
209
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
210
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
211

212
  std::vector<IdType> ret_vec;
213
214
215
216
217

  for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
    const IdType row_id = row_data[i], col_id = col_data[j];
    CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
    CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
Da Zheng's avatar
Da Zheng committed
218
    if (csr.sorted) {
219
220
221
      CollectDataFromSorted<XPU, IdType>(indices_data, data,
                                         indptr_data[row_id], indptr_data[row_id + 1],
                                         col_id, &ret_vec);
Da Zheng's avatar
Da Zheng committed
222
223
224
    } else {
      for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
        if (indices_data[i] == col_id) {
225
          ret_vec.push_back(data? data[i] : i);
Da Zheng's avatar
Da Zheng committed
226
        }
227
228
229
230
      }
    }
  }

231
  return NDArray::FromVector(ret_vec, csr.data->dtype, csr.data->ctx);
232
233
}

234
235
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix csr, NDArray rows, NDArray cols);
236
237
238

///////////////////////////// CSRGetDataAndIndices /////////////////////////////

239
240
/*!
 * \brief For every entry with column \a col found in the sorted index range
 *        [start, end), append its column index to \a col_vec and its data to
 *        \a ret_vec.
 *
 * \param indices_data Column index array of the CSR matrix.
 * \param data Data array of the CSR matrix, or nullptr; when null the
 *        position of the entry in \a indices_data is used as its data.
 * \param start Offset of the first entry of the row in \a indices_data.
 * \param end Offset one past the last entry of the row.
 * \param col Column index to look for.
 * \param col_vec Output vector that receives the matching column indices.
 * \param ret_vec Output vector that receives the matching data values.
 */
template <DLDeviceType XPU, typename IdType>
void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data,
                                  const IdType start, const IdType end, const IdType col,
                                  std::vector<IdType> *col_vec,
                                  std::vector<IdType> *ret_vec) {
  const IdType *start_ptr = indices_data + start;
  const IdType *end_ptr = indices_data + end;
  auto it = std::lower_bound(start_ptr, end_ptr, col);
  // This might be a multi-graph. We need to collect all of the matched
  // columns.
  for (; it != end_ptr; it++) {
    // If the col exist
    if (*it == col) {
      IdType idx = it - indices_data;
      col_vec->push_back(indices_data[idx]);
      // `data` may be null (callers pass nullptr for matrices without a data
      // array); fall back to the entry position instead of dereferencing it.
      ret_vec->push_back(data ? data[idx] : idx);
    } else {
      // If we find a column that is different, we can stop searching now.
      break;
    }
  }
}

262
template <DLDeviceType XPU, typename IdType>
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
  // TODO(minjie): more efficient implementation for matrix without duplicate entries
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
278
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
279
280

  std::vector<IdType> ret_rows, ret_cols;
281
  std::vector<IdType> ret_data;
282
283
284
285
286

  for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
    const IdType row_id = row_data[i], col_id = col_data[j];
    CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
    CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
Da Zheng's avatar
Da Zheng committed
287
288
    if (csr.sorted) {
      // Here we collect col indices and data.
289
290
291
292
293
      CollectDataIndicesFromSorted<XPU, IdType>(indices_data, data,
                                                indptr_data[row_id],
                                                indptr_data[row_id + 1],
                                                col_id, &ret_cols,
                                                &ret_data);
Da Zheng's avatar
Da Zheng committed
294
295
296
297
298
299
300
      // We need to add row Ids.
      while (ret_rows.size() < ret_data.size()) {
        ret_rows.push_back(row_id);
      }
    } else {
      for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
        if (indices_data[i] == col_id) {
301
302
          ret_rows.push_back(row_id);
          ret_cols.push_back(col_id);
303
          ret_data.push_back(data? data[i] : i);
Da Zheng's avatar
Da Zheng committed
304
        }
305
306
307
308
      }
    }
  }

309
310
311
  return {NDArray::FromVector(ret_rows, csr.indptr->dtype, csr.indptr->ctx),
          NDArray::FromVector(ret_cols, csr.indptr->dtype, csr.indptr->ctx),
          NDArray::FromVector(ret_data, csr.data->dtype, csr.data->ctx)};
312
313
}

314
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
315
    CSRMatrix csr, NDArray rows, NDArray cols);
316
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t>(
317
318
319
320
321
322
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRTranspose /////////////////////////////

// for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1)
323
template <DLDeviceType XPU, typename IdType>
324
325
326
327
328
329
CSRMatrix CSRTranspose(CSRMatrix csr) {
  const int64_t N = csr.num_rows;
  const int64_t M = csr.num_cols;
  const int64_t nnz = csr.indices->shape[0];
  const IdType* Ap = static_cast<IdType*>(csr.indptr->data);
  const IdType* Aj = static_cast<IdType*>(csr.indices->data);
330
  const IdType* Ax = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
331
332
  NDArray ret_indptr = NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx);
  NDArray ret_indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
333
  NDArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
334
335
  IdType* Bp = static_cast<IdType*>(ret_indptr->data);
  IdType* Bi = static_cast<IdType*>(ret_indices->data);
336
  IdType* Bx = static_cast<IdType*>(ret_data->data);
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355

  std::fill(Bp, Bp + M, 0);

  for (int64_t j = 0; j < nnz; ++j) {
    Bp[Aj[j]]++;
  }

  // cumsum
  for (int64_t i = 0, cumsum = 0; i < M; ++i) {
    const IdType temp = Bp[i];
    Bp[i] = cumsum;
    cumsum += temp;
  }
  Bp[M] = nnz;

  for (int64_t i = 0; i < N; ++i) {
    for (IdType j = Ap[i]; j < Ap[i+1]; ++j) {
      const IdType dst = Aj[j];
      Bi[Bp[dst]] = i;
356
      Bx[Bp[dst]] = Ax? Ax[j] : j;
357
358
359
360
361
362
363
364
365
366
367
368
369
370
      Bp[dst]++;
    }
  }

  // correct the indptr
  for (int64_t i = 0, last = 0; i <= M; ++i) {
    IdType temp = Bp[i];
    Bp[i] = last;
    last = temp;
  }

  return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
}

371
372
template CSRMatrix CSRTranspose<kDLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLCPU, int64_t>(CSRMatrix csr);
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414

///////////////////////////// CSRToCOO /////////////////////////////
/*!
 * \brief Convert a CSR matrix to COO format.
 *
 * A new row array is produced by expanding indptr; the column and data arrays
 * of the input are passed to the result as-is.
 *
 * \param csr The CSR matrix to convert.
 * \return A COO matrix with the same shape as \a csr.
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
  const int64_t nnz = csr.indices->shape[0];
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  IdType* row_ids = static_cast<IdType*>(ret_row->data);
  // indptr has one more element than there are rows.
  const int64_t row_count = csr.indptr->shape[0] - 1;
  for (IdType r = 0; r < row_count; ++r) {
    // Every entry position owned by row r gets r as its row id.
    for (IdType e = offsets[r]; e < offsets[r + 1]; ++e) {
      row_ids[e] = r;
    }
  }
  return COOMatrix{csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data};
}

template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);

// complexity: time O(NNZ), space O(1)
/*!
 * \brief Convert a CSR matrix to COO format, placing each entry at the
 *        position given by its data value.
 *
 * Entry j of the CSR matrix is written to slot data[j] of the output row and
 * column arrays. The data array is required and is read as IdType.
 *
 * \param csr The CSR matrix to convert; must have a data array.
 * \return A COO matrix built from the reordered row and column arrays.
 */
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
  CHECK(CSRHasData(csr)) << "missing data array.";
  const int64_t num_rows = csr.num_rows;
  const int64_t num_cols = csr.num_cols;
  const int64_t nnz = csr.indices->shape[0];
  const IdType* offsets = static_cast<IdType*>(csr.indptr->data);
  const IdType* columns = static_cast<IdType*>(csr.indices->data);
  // The data array is read with the same element type as the index arrays.
  const IdType* order = static_cast<IdType*>(csr.data->data);
  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
  IdType* out_rows = static_cast<IdType*>(ret_row->data);
  IdType* out_cols = static_cast<IdType*>(ret_col->data);
  for (IdType r = 0; r < num_rows; ++r) {
    const IdType row_end = offsets[r + 1];
    for (IdType e = offsets[r]; e < row_end; ++e) {
      // Scatter: the data value selects the destination slot.
      const IdType slot = order[e];
      out_rows[slot] = r;
      out_cols[slot] = columns[e];
    }
  }
  return COOMatrix(num_rows, num_cols, ret_row, ret_col);
}

template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr);

///////////////////////////// CSRSliceRows /////////////////////////////

423
template <DLDeviceType XPU, typename IdType>
424
425
426
427
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const int64_t num_rows = end - start;
  const int64_t nnz = indptr[end] - indptr[start];
428
429
  IdArray ret_indptr = IdArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indices->ctx);
  IdType* r_indptr = static_cast<IdType*>(ret_indptr->data);
430
431
432
433
  for (int64_t i = start; i < end + 1; ++i) {
    r_indptr[i - start] = indptr[i] - indptr[start];
  }
  // indices and data can be view arrays
434
435
436
437
438
439
440
441
442
443
444
  IdArray ret_indices = csr.indices.CreateView(
      {nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType));
  IdArray ret_data;
  if (CSRHasData(csr))
    ret_data = csr.data.CreateView({nnz}, csr.data->dtype, indptr[start] * sizeof(IdType));
  else
    ret_data = aten::Range(indptr[start], indptr[end],
                           csr.indptr->dtype.bits, csr.indptr->ctx);
  return CSRMatrix(num_rows, csr.num_cols,
                   ret_indptr, ret_indices, ret_data,
                   csr.sorted);
445
446
}

447
448
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
449

450
template <DLDeviceType XPU, typename IdType>
451
452
453
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
454
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
455
456
457
458
459
460
461
462
463
464
465
466
467
  const auto len = rows->shape[0];
  const IdType* rows_data = static_cast<IdType*>(rows->data);
  int64_t nnz = 0;
  for (int64_t i = 0; i < len; ++i) {
    IdType vid = rows_data[i];
    nnz += impl::CSRGetRowNNZ<XPU, IdType>(csr, vid);
  }

  CSRMatrix ret;
  ret.num_rows = len;
  ret.num_cols = csr.num_cols;
  ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indices->ctx);
  ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
468
469
  ret.data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
  ret.sorted = csr.sorted;
470
471
472

  IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);
  IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data);
473
  IdType* ret_data = static_cast<IdType*>(ret.data->data);
474
475
476
477
478
479
480
  ret_indptr_data[0] = 0;
  for (int64_t i = 0; i < len; ++i) {
    const IdType rid = rows_data[i];
    // note: zero is allowed
    ret_indptr_data[i + 1] = ret_indptr_data[i] + indptr_data[rid + 1] - indptr_data[rid];
    std::copy(indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],
              ret_indices_data + ret_indptr_data[i]);
481
482
483
484
485
486
    if (data)
      std::copy(data + indptr_data[rid], data + indptr_data[rid + 1],
                ret_data + ret_indptr_data[i]);
    else
      std::iota(ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1],
                indptr_data[rid]);
487
488
489
490
  }
  return ret;
}

491
492
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
493
494
495

///////////////////////////// CSRSliceMatrix /////////////////////////////

496
template <DLDeviceType XPU, typename IdType>
497
498
499
500
501
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
  IdHashMap<IdType> hashmap(cols);
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];
  const IdType* rows_data = static_cast<IdType*>(rows->data);
502
  const bool has_data = CSRHasData(csr);
503
504
505

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
506
  const IdType* data = has_data? static_cast<IdType*>(csr.data->data) : nullptr;
507
508

  std::vector<IdType> sub_indptr, sub_indices;
509
  std::vector<IdType> sub_data;
510
511
512
513
514
515
516
517
518
519
520
521
  sub_indptr.resize(new_nrows + 1, 0);
  const IdType kInvalidId = new_ncols + 1;
  for (int64_t i = 0; i < new_nrows; ++i) {
    // NOTE: newi == i
    const IdType oldi = rows_data[i];
    CHECK(oldi >= 0 && oldi < csr.num_rows) << "Invalid row index: " << oldi;
    for (IdType p = indptr_data[oldi]; p < indptr_data[oldi + 1]; ++p) {
      const IdType oldj = indices_data[p];
      const IdType newj = hashmap.Map(oldj, kInvalidId);
      if (newj != kInvalidId) {
        ++sub_indptr[i];
        sub_indices.push_back(newj);
522
        sub_data.push_back(has_data? data[p] : p);
523
524
525
526
527
528
529
530
531
532
533
534
535
      }
    }
  }

  // cumsum sub_indptr
  for (int64_t i = 0, cumsum = 0; i < new_nrows; ++i) {
    const IdType temp = sub_indptr[i];
    sub_indptr[i] = cumsum;
    cumsum += temp;
  }
  sub_indptr[new_nrows] = sub_indices.size();

  const int64_t nnz = sub_data.size();
536
537
  NDArray sub_data_arr = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
  IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
538
539
  std::copy(sub_data.begin(), sub_data.end(), ptr);
  return CSRMatrix{new_nrows, new_ncols,
540
541
    NDArray::FromVector(sub_indptr, csr.indptr->dtype, csr.indptr->ctx),
    NDArray::FromVector(sub_indices, csr.indptr->dtype, csr.indptr->ctx),
542
543
544
    sub_data_arr};
}

545
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t>(
546
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
547
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t>(
548
549
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

550
551
552
553
554
555
556
557
558
559
560
/*!
 * \brief Sort the column indices of every row in ascending order, in place.
 *
 * The data array is permuted together with the indices. If the matrix has no
 * data array, one holding the range [0, nnz) is created first. Rows are
 * distributed over threads with an OpenMP worksharing loop. Entries are
 * compared by column index only, and std::sort gives no ordering guarantee
 * among entries with equal columns. The sorted flag is set on return.
 *
 * \param csr The CSR matrix to sort; modified in place.
 */
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  using ShufflePair = std::pair<IdType, IdType>;
  const int64_t num_rows = csr->num_rows;
  const int64_t nnz = csr->indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr->indptr->data);
  IdType* indices_data = static_cast<IdType*>(csr->indices->data);
  if (!CSRHasData(*csr)) {
    csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);
  }
  IdType* eid_data = static_cast<IdType*>(csr->data->data);
#pragma omp parallel
  {
    // Declared inside the parallel region, so each thread has its own
    // scratch buffer that is reused across the rows it processes.
    std::vector<ShufflePair> scratch;
#pragma omp for
    for (int64_t row = 0; row < num_rows; row++) {
      const int64_t row_len = indptr_data[row + 1] - indptr_data[row];
      IdType* row_cols = indices_data + indptr_data[row];
      IdType* row_eids = eid_data + indptr_data[row];

      // Pair every column index with its data value so both move together.
      scratch.resize(row_len);
      for (int64_t k = 0; k < row_len; k++) {
        scratch[k] = ShufflePair(row_cols[k], row_eids[k]);
      }
      std::sort(scratch.begin(), scratch.end(),
                [](const ShufflePair& lhs, const ShufflePair& rhs) {
                  return lhs.first < rhs.first;
                });
      // Write the sorted pairs back into the matrix arrays.
      for (int64_t k = 0; k < row_len; k++) {
        row_cols[k] = scratch[k].first;
        row_eids[k] = scratch[k].second;
      }
    }
  }
  csr->sorted = true;
}

template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr);
Da Zheng's avatar
Da Zheng committed
590

591
592
593
}  // namespace impl
}  // namespace aten
}  // namespace dgl