"vscode:/vscode.git/clone" did not exist on "9a3456cd7be76cbee3d47af2c9aeeb3fd4f431d4"
csr_get_data.cc 5.17 KB
Newer Older
/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cpu/csr_get_data.cc
 * @brief Retrieve entries of a CSR matrix
 */
#include <dgl/array.h>
7
#include <dgl/runtime/parallel_for.h>
8

9
#include <numeric>
10
11
12
#include <unordered_set>
#include <vector>

13
14
15
16
17
#include "array_utils.h"

namespace dgl {

using runtime::NDArray;
18
using runtime::parallel_for;
19
20
21
namespace aten {
namespace impl {

22
/**
 * @brief Collect the data entries of every occurrence of a column id inside
 *        one slice of a column-index array.
 *
 * The slice [start, end) of `indices_data` is searched with
 * std::lower_bound, so it is expected to be sorted in ascending order. Since
 * the matrix may describe a multi-graph, the same column can appear several
 * times in a row; every consecutive match is appended to `ret_vec`.
 *
 * @tparam XPU Device type tag; not referenced by the body.
 * @tparam IdType Integer type of the ids.
 * @param indices_data Base pointer of the column-index array.
 * @param data Optional array mapping a position in `indices_data` to its data
 *        entry; when null, the position itself is used as the entry.
 * @param start Offset of the first element of the slice.
 * @param end Offset one past the last element of the slice.
 * @param col Column id to look for.
 * @param ret_vec Output vector that receives the matching entries.
 */
template <DGLDeviceType XPU, typename IdType>
void CollectDataFromSorted(
    const IdType* indices_data, const IdType* data, const IdType start,
    const IdType end, const IdType col, std::vector<IdType>* ret_vec) {
  const IdType* const slice_end = indices_data + end;
  // Binary search for the first element that is not less than `col`.
  const IdType* cursor = std::lower_bound(indices_data + start, slice_end, col);
  // Equal columns sit next to each other in a sorted slice, so walk forward
  // and stop at the first element that differs (or at the end of the slice).
  while (cursor != slice_end && *cursor == col) {
    const IdType pos = cursor - indices_data;
    const IdType entry = data ? data[pos] : pos;
    ret_vec->push_back(entry);
    ++cursor;
  }
}

43
template <DGLDeviceType XPU, typename IdType, typename DType>
44
NDArray CSRGetData(
45
46
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, DType filler) {
47
48
49
50
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
51
      << "Invalid row and col id array.";
52
53
54
55
56
57
58
59

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
60
61
  const IdType* data =
      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
62
63
64
65

  const int64_t retlen = std::max(rowlen, collen);
  const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
  if (return_eids)
66
67
    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
        << "DType does not match row's dtype.";
68
69
70
71

  NDArray ret = Full(filler, retlen, rows->ctx);
  DType* ret_data = ret.Ptr<DType>();

72
73
74
75
  // NOTE: In most cases, the input csr is already sorted. If not, we might need
  // to
  //   consider sorting it especially when the number of (row, col) pairs is
  //   large. Need more benchmarks to justify the choice.
76
77
78

  if (csr.sorted) {
    // use binary search on each row
79
80
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
81
82
83
84
85
86
87
88
        const IdType row_id = row_data[p * row_stride],
                     col_id = col_data[p * col_stride];
        CHECK(row_id >= 0 && row_id < csr.num_rows)
            << "Invalid row index: " << row_id;
        CHECK(col_id >= 0 && col_id < csr.num_cols)
            << "Invalid col index: " << col_id;
        const IdType* start_ptr = indices_data + indptr_data[row_id];
        const IdType* end_ptr = indices_data + indptr_data[row_id + 1];
89
90
91
92
93
94
        auto it = std::lower_bound(start_ptr, end_ptr, col_id);
        if (it != end_ptr && *it == col_id) {
          const IdType idx = it - indices_data;
          IdType eid = data ? data[idx] : idx;
          ret_data[p] = return_eids ? eid : weight_data[eid];
        }
95
      }
96
    });
97
98
  } else {
    // linear search on each row
99
100
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
101
102
103
104
105
106
107
108
        const IdType row_id = row_data[p * row_stride],
                     col_id = col_data[p * col_stride];
        CHECK(row_id >= 0 && row_id < csr.num_rows)
            << "Invalid row index: " << row_id;
        CHECK(col_id >= 0 && col_id < csr.num_cols)
            << "Invalid col index: " << col_id;
        for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1];
             ++idx) {
109
110
111
112
113
          if (indices_data[idx] == col_id) {
            IdType eid = data ? data[idx] : idx;
            ret_data[p] = return_eids ? eid : weight_data[eid];
            break;
          }
114
115
        }
      }
116
    });
117
118
119
120
  }
  return ret;
}

121
// Explicit CPU instantiations of CSRGetData. The argument list of each one is
// (csr, rows, cols, return_eids, weights, filler).

// Floating-point result types, for both id widths.
template NDArray CSRGetData<kDGLCPU, int32_t, float>(
    CSRMatrix, NDArray, NDArray, bool, NDArray, float);
template NDArray CSRGetData<kDGLCPU, int64_t, float>(
    CSRMatrix, NDArray, NDArray, bool, NDArray, float);
template NDArray CSRGetData<kDGLCPU, int32_t, double>(
    CSRMatrix, NDArray, NDArray, bool, NDArray, double);
template NDArray CSRGetData<kDGLCPU, int64_t, double>(
    CSRMatrix, NDArray, NDArray, bool, NDArray, double);

// Result type equal to the id type.
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
    CSRMatrix, NDArray, NDArray, bool, NDArray, int32_t);
template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
    CSRMatrix, NDArray, NDArray, bool, NDArray, int64_t);
141
142
143
144

}  // namespace impl
}  // namespace aten
}  // namespace dgl