/*!
 *  Copyright (c) 2019 by Contributors
 * @file graph/transform/remove_edges.cc
 * @brief Remove edges from the relations of a heterograph.
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/transform.h>

#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

25
26
std::pair<HeteroGraphPtr, std::vector<IdArray>> RemoveEdges(
    const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
27
28
29
30
31
  std::vector<IdArray> induced_eids;
  std::vector<HeteroGraphPtr> rel_graphs;
  const int64_t num_etypes = graph->NumEdgeTypes();

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
32
    const SparseFormat fmt = graph->SelectFormat(etype, COO_CODE);
33
34
35
36
37
38
39
40
41
42
43
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    const int num_ntypes_rel = (srctype == dsttype) ? 1 : 2;
    HeteroGraphPtr new_rel_graph;
    IdArray induced_eids_rel;

    if (fmt == SparseFormat::kCOO) {
      const COOMatrix &coo = graph->GetCOOMatrix(etype);
      const COOMatrix &result = COORemove(coo, eids[etype]);
      new_rel_graph = CreateFromCOO(
44
45
          num_ntypes_rel, result.num_rows, result.num_cols, result.row,
          result.col);
46
47
48
49
50
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSR) {
      const CSRMatrix &csr = graph->GetCSRMatrix(etype);
      const CSRMatrix &result = CSRRemove(csr, eids[etype]);
      new_rel_graph = CreateFromCSR(
51
52
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,
          result.indices,
53
          // TODO(BarclayII): make CSR support null eid array
54
55
56
          Range(
              0, result.indices->shape[0], result.indices->dtype.bits,
              result.indices->ctx));
57
58
59
60
61
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSC) {
      const CSRMatrix &csc = graph->GetCSCMatrix(etype);
      const CSRMatrix &result = CSRRemove(csc, eids[etype]);
      new_rel_graph = CreateFromCSC(
62
63
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,
          result.indices,
64
          // TODO(BarclayII): make CSR support null eid array
65
66
67
          Range(
              0, result.indices->shape[0], result.indices->dtype.bits,
              result.indices->ctx));
68
69
70
71
72
73
74
75
76
77
78
79
80
      induced_eids_rel = result.data;
    }

    rel_graphs.push_back(new_rel_graph);
    induced_eids.push_back(induced_eids_rel);
  }

  const HeteroGraphPtr new_graph = CreateHeteroGraph(
      graph->meta_graph(), rel_graphs, graph->NumVerticesPerType());
  return std::make_pair(new_graph, induced_eids);
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // args[0] is the graph handle; args[1] is a list holding one edge-ID
      // array per edge type.
      const HeteroGraphRef graph_ref = args[0];
      const std::vector<IdArray> eids_per_etype =
          ListValueToVector<IdArray>(args[1]);

      auto removed = RemoveEdges(graph_ref.sptr(), eids_per_etype);

      // Box each induced-edge-ID array so it can be stored in a runtime List.
      List<Value> induced_list;
      for (IdArray &ids : removed.second) {
        induced_list.push_back(Value(MakeValue(ids)));
      }

      // Hand back [new_graph, [induced_eids per edge type]].
      List<ObjectRef> packed;
      packed.push_back(HeteroGraphRef(removed.first));
      packed.push_back(induced_list);
      *rv = packed;
    });

};  // namespace transform

};  // namespace dgl