/*!
 *  Copyright (c) 2021 by Contributors
 * \file array/cpu/csr_get_data.cc
 * \brief Retrieve entries of a CSR matrix
 */
#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>

#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <vector>

#include "array_utils.h"

namespace dgl {

using runtime::NDArray;
using runtime::parallel_for;

namespace aten {
namespace impl {

20
template <DGLDeviceType XPU, typename IdType>
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
                           const IdType start, const IdType end, const IdType col,
                           std::vector<IdType> *ret_vec) {
  const IdType *start_ptr = indices_data + start;
  const IdType *end_ptr = indices_data + end;
  auto it = std::lower_bound(start_ptr, end_ptr, col);
  // This might be a multi-graph. We need to collect all of the matched
  // columns.
  for (; it != end_ptr; it++) {
    // If the col exist
    if (*it == col) {
      IdType idx = it - indices_data;
      ret_vec->push_back(data? data[idx] : idx);
    } else {
      // If we find a column that is different, we can stop searching now.
      break;
    }
  }
}

41
template <DGLDeviceType XPU, typename IdType, typename DType>
42
NDArray CSRGetData(
43
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
    << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;

  const int64_t retlen = std::max(rowlen, collen);
  const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
  if (return_eids)
62
    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
63
64
65
66
67
68
69
70
71
72
73
      "DType does not match row's dtype.";

  NDArray ret = Full(filler, retlen, rows->ctx);
  DType* ret_data = ret.Ptr<DType>();

  // NOTE: In most cases, the input csr is already sorted. If not, we might need to
  //   consider sorting it especially when the number of (row, col) pairs is large.
  //   Need more benchmarks to justify the choice.

  if (csr.sorted) {
    // use binary search on each row
74
75
76
77
78
79
80
81
82
83
84
85
86
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
        const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
        CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
        CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
        const IdType *start_ptr = indices_data + indptr_data[row_id];
        const IdType *end_ptr = indices_data + indptr_data[row_id + 1];
        auto it = std::lower_bound(start_ptr, end_ptr, col_id);
        if (it != end_ptr && *it == col_id) {
          const IdType idx = it - indices_data;
          IdType eid = data ? data[idx] : idx;
          ret_data[p] = return_eids ? eid : weight_data[eid];
        }
87
      }
88
    });
89
90
  } else {
    // linear search on each row
91
92
93
94
95
96
97
98
99
100
101
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
        const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
        CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
        CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
        for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1]; ++idx) {
          if (indices_data[idx] == col_id) {
            IdType eid = data ? data[idx] : idx;
            ret_data[p] = return_eids ? eid : weight_data[eid];
            break;
          }
102
103
        }
      }
104
    });
105
106
107
108
  }
  return ret;
}

109
template NDArray CSRGetData<kDGLCPU, int32_t, float>(
110
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
111
template NDArray CSRGetData<kDGLCPU, int64_t, float>(
112
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
113
template NDArray CSRGetData<kDGLCPU, int32_t, double>(
114
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
115
template NDArray CSRGetData<kDGLCPU, int64_t, double>(
116
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
117
118

// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
119
template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
120
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
121
template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
122
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
123
124
125
126

}  // namespace impl
}  // namespace aten
}  // namespace dgl