"src/runtime/vscode:/vscode.git/clone" did not exist on "1dddaad4f025b9f32ac796b9821f09036de95238"
csr_remove.cc 3.41 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/csr_remove.cc
 * @brief CSR matrix remove entries CPU implementation
 */
#include <dgl/array.h>
7

8
9
#include <utility>
#include <vector>
10

11
12
13
14
15
16
17
18
19
#include "array_utils.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

namespace {

20
template <DGLDeviceType XPU, typename IdType>
21
void CSRRemoveConsecutive(
22
23
    CSRMatrix csr, IdArray entries, std::vector<IdType> *new_indptr,
    std::vector<IdType> *new_indices, std::vector<IdType> *new_eids) {
24
  CHECK_SAME_DTYPE(csr.indices, entries);
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  const int64_t n_entries = entries->shape[0];
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const IdType *entry_data = static_cast<IdType *>(entries->data);

  std::vector<IdType> entry_data_sorted(entry_data, entry_data + n_entries);
  std::sort(entry_data_sorted.begin(), entry_data_sorted.end());

  int64_t k = 0;
  new_indptr->push_back(0);
  for (int64_t i = 0; i < csr.num_rows; ++i) {
    for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
      if (k < n_entries && entry_data_sorted[k] == j) {
        // Move on to the next different entry
39
        while (k < n_entries && entry_data_sorted[k] == j) ++k;
40
41
42
43
44
45
46
47
48
        continue;
      }
      new_indices->push_back(indices_data[j]);
      new_eids->push_back(k);
    }
    new_indptr->push_back(new_indices->size());
  }
}

49
template <DGLDeviceType XPU, typename IdType>
50
void CSRRemoveShuffled(
51
52
    CSRMatrix csr, IdArray entries, std::vector<IdType> *new_indptr,
    std::vector<IdType> *new_indices, std::vector<IdType> *new_eids) {
53
  CHECK_SAME_DTYPE(csr.indices, entries);
54
55
56
57
58
59
60
61
62
63
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const IdType *eid_data = static_cast<IdType *>(csr.data->data);

  IdHashMap<IdType> eid_map(entries);

  new_indptr->push_back(0);
  for (int64_t i = 0; i < csr.num_rows; ++i) {
    for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
      const IdType eid = eid_data ? eid_data[j] : j;
64
      if (eid_map.Contains(eid)) continue;
65
66
67
68
69
70
71
72
73
      new_indices->push_back(indices_data[j]);
      new_eids->push_back(eid);
    }
    new_indptr->push_back(new_indices->size());
  }
}

};  // namespace

74
template <DGLDeviceType XPU, typename IdType>
75
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
76
  CHECK_SAME_DTYPE(csr.indices, entries);
77
78
  const int64_t nnz = csr.indices->shape[0];
  const int64_t n_entries = entries->shape[0];
79
  if (n_entries == 0) return csr;
80
81
82
83
84
85
86

  std::vector<IdType> new_indptr, new_indices, new_eids;
  new_indptr.reserve(nnz - n_entries);
  new_indices.reserve(nnz - n_entries);
  new_eids.reserve(nnz - n_entries);

  if (CSRHasData(csr))
87
88
    CSRRemoveShuffled<XPU, IdType>(
        csr, entries, &new_indptr, &new_indices, &new_eids);
89
90
  else
    // Removing from CSR ordered by eid has more efficient implementation
91
92
    CSRRemoveConsecutive<XPU, IdType>(
        csr, entries, &new_indptr, &new_indices, &new_eids);
93
94

  return CSRMatrix(
95
96
      csr.num_rows, csr.num_cols, IdArray::FromVector(new_indptr),
      IdArray::FromVector(new_indices), IdArray::FromVector(new_eids));
97
98
}

99
100
// Explicit instantiations of CSRRemove for the kDGLCPU device type with
// 32-bit and 64-bit ID types.
template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDGLCPU, int64_t>(CSRMatrix csr, IdArray entries);
101
102
103
104

};  // namespace impl
};  // namespace aten
};  // namespace dgl