/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/coo_remove.cc
 * \brief COO matrix remove entries CPU implementation
 */
#include <dgl/array.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "array_utils.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

namespace {

20
21
/*! \brief COORemove implementation for COOMatrix with default consecutive edge
 * IDs */
22
template <DGLDeviceType XPU, typename IdType>
23
void COORemoveConsecutive(
24
25
    COOMatrix coo, IdArray entries, std::vector<IdType> *new_rows,
    std::vector<IdType> *new_cols, std::vector<IdType> *new_eids) {
26
27
28
29
30
31
32
33
34
35
36
37
38
  const int64_t nnz = coo.row->shape[0];
  const int64_t n_entries = entries->shape[0];
  const IdType *row_data = static_cast<IdType *>(coo.row->data);
  const IdType *col_data = static_cast<IdType *>(coo.col->data);
  const IdType *entry_data = static_cast<IdType *>(entries->data);

  std::vector<IdType> entry_data_sorted(entry_data, entry_data + n_entries);
  std::sort(entry_data_sorted.begin(), entry_data_sorted.end());

  int64_t j = 0;
  for (int64_t i = 0; i < nnz; ++i) {
    if (j < n_entries && entry_data_sorted[j] == i) {
      // Move on to the next different entry
39
      while (j < n_entries && entry_data_sorted[j] == i) ++j;
40
41
42
43
44
45
46
47
48
      continue;
    }
    new_rows->push_back(row_data[i]);
    new_cols->push_back(col_data[i]);
    new_eids->push_back(i);
  }
}

/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
49
template <DGLDeviceType XPU, typename IdType>
50
void COORemoveShuffled(
51
52
    COOMatrix coo, IdArray entries, std::vector<IdType> *new_rows,
    std::vector<IdType> *new_cols, std::vector<IdType> *new_eids) {
53
54
55
56
57
58
59
60
61
  const int64_t nnz = coo.row->shape[0];
  const IdType *row_data = static_cast<IdType *>(coo.row->data);
  const IdType *col_data = static_cast<IdType *>(coo.col->data);
  const IdType *eid_data = static_cast<IdType *>(coo.data->data);

  IdHashMap<IdType> eid_map(entries);

  for (int64_t i = 0; i < nnz; ++i) {
    const IdType eid = eid_data[i];
62
    if (eid_map.Contains(eid)) continue;
63
64
65
66
67
68
69
70
    new_rows->push_back(row_data[i]);
    new_cols->push_back(col_data[i]);
    new_eids->push_back(eid);
  }
}

};  // namespace

71
template <DGLDeviceType XPU, typename IdType>
72
73
74
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
  const int64_t nnz = coo.row->shape[0];
  const int64_t n_entries = entries->shape[0];
75
  if (n_entries == 0) return coo;
76
77
78
79
80
81
82

  std::vector<IdType> new_rows, new_cols, new_eids;
  new_rows.reserve(nnz - n_entries);
  new_cols.reserve(nnz - n_entries);
  new_eids.reserve(nnz - n_entries);

  if (COOHasData(coo))
83
84
    COORemoveShuffled<XPU, IdType>(
        coo, entries, &new_rows, &new_cols, &new_eids);
85
86
  else
    // Removing from COO ordered by eid has more efficient implementation.
87
88
    COORemoveConsecutive<XPU, IdType>(
        coo, entries, &new_rows, &new_cols, &new_eids);
89
90

  return COOMatrix(
91
92
      coo.num_rows, coo.num_cols, IdArray::FromVector(new_rows),
      IdArray::FromVector(new_cols), IdArray::FromVector(new_eids));
93
94
}

95
96
// Explicit instantiations of COORemove for the CPU device, one per supported
// ID width (32-bit and 64-bit).
template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix, IdArray);
template COOMatrix COORemove<kDGLCPU, int64_t>(COOMatrix, IdArray);
97
98
99
100

};  // namespace impl
};  // namespace aten
};  // namespace dgl