remove_edges.cc 3.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/transform/remove_edges.cc
 * \brief Remove edges.
 */

#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/container.h>
#include <vector>
#include <utility>
#include <tuple>

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

std::pair<HeteroGraphPtr, std::vector<IdArray>>
RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
  std::vector<IdArray> induced_eids;
  std::vector<HeteroGraphPtr> rel_graphs;
  const int64_t num_etypes = graph->NumEdgeTypes();

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
31
    const SparseFormat fmt = graph->SelectFormat(etype, coo_code);
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    const int num_ntypes_rel = (srctype == dsttype) ? 1 : 2;
    HeteroGraphPtr new_rel_graph;
    IdArray induced_eids_rel;

    if (fmt == SparseFormat::kCOO) {
      const COOMatrix &coo = graph->GetCOOMatrix(etype);
      const COOMatrix &result = COORemove(coo, eids[etype]);
      new_rel_graph = CreateFromCOO(
          num_ntypes_rel, result.num_rows, result.num_cols, result.row, result.col);
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSR) {
      const CSRMatrix &csr = graph->GetCSRMatrix(etype);
      const CSRMatrix &result = CSRRemove(csr, eids[etype]);
      new_rel_graph = CreateFromCSR(
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr, result.indices,
          // TODO(BarclayII): make CSR support null eid array
          Range(0, result.indices->shape[0], result.indices->dtype.bits, result.indices->ctx));
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSC) {
      const CSRMatrix &csc = graph->GetCSCMatrix(etype);
      const CSRMatrix &result = CSRRemove(csc, eids[etype]);
      new_rel_graph = CreateFromCSC(
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr, result.indices,
          // TODO(BarclayII): make CSR support null eid array
          Range(0, result.indices->shape[0], result.indices->dtype.bits, result.indices->ctx));
      induced_eids_rel = result.data;
    }

    rel_graphs.push_back(new_rel_graph);
    induced_eids.push_back(induced_eids_rel);
  }

  const HeteroGraphPtr new_graph = CreateHeteroGraph(
      graph->meta_graph(), rel_graphs, graph->NumVerticesPerType());
  return std::make_pair(new_graph, induced_eids);
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    // Arguments:
    //   args[0] is the heterograph.
    //   args[1] is a list holding one edge ID array per edge type.
    const HeteroGraphRef input_graph = args[0];
    const std::vector<IdArray> &edges_to_remove = ListValueToVector<IdArray>(args[1]);

    // result.first is the new graph; result.second holds the per-type induced edge IDs.
    std::pair<HeteroGraphPtr, std::vector<IdArray>> result =
        RemoveEdges(input_graph.sptr(), edges_to_remove);

    // Wrap each per-type ID array so it can travel back as a list of values.
    List<Value> eid_list;
    for (auto &eid_array : result.second) {
      eid_list.push_back(Value(MakeValue(eid_array)));
    }

    // Return value is the two-element list [new graph, induced edge ID list].
    List<ObjectRef> packed;
    packed.push_back(HeteroGraphRef(result.first));
    packed.push_back(eid_list);
    *rv = packed;
  });

};  // namespace transform

};  // namespace dgl