nodeflow.cc 3.4 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/nodeflow.cc
 * \brief DGL NodeFlow related functions.
 */

#include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>

#include <algorithm>
#include <string>
#include <vector>

#include "../c_api_common.h"

using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;

20
21
namespace dgl {

22
23
24
std::vector<IdArray> GetNodeFlowSlice(
    const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
    size_t layer1_start, size_t layer1_end, bool remap) {
25
  CHECK_GE(layer1_start, layer0_size);
26
  if (fmt == std::string("csr")) {
27
    dgl_id_t first_vid = layer1_start - layer0_size;
28
29
    auto csr = aten::CSRSliceRows(
        graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
30
    if (remap) {
31
      dgl_id_t *eid_data = static_cast<dgl_id_t *>(csr.data->data);
32
      const dgl_id_t first_eid = eid_data[0];
33
34
      IdArray new_indices = aten::Sub(csr.indices, first_vid);
      IdArray new_data = aten::Sub(csr.data, first_eid);
35
36
37
      return {csr.indptr, new_indices, new_data};
    } else {
      return {csr.indptr, csr.indices, csr.data};
38
    }
39
  } else if (fmt == std::string("coo")) {
40
    auto csr = graph.GetInCSR()->ToCSRMatrix();
41
42
43
    const dgl_id_t *indptr = static_cast<dgl_id_t *>(csr.indptr->data);
    const dgl_id_t *indices = static_cast<dgl_id_t *>(csr.indices->data);
    const dgl_id_t *edge_ids = static_cast<dgl_id_t *>(csr.data->data);
44
    int64_t nnz = indptr[layer1_end] - indptr[layer1_start];
45
46
    IdArray idx = aten::NewIdArray(2 * nnz);
    IdArray eid = aten::NewIdArray(nnz);
47
48
    int64_t *idx_data = static_cast<int64_t *>(idx->data);
    dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
49
50
    size_t num_edges = 0;
    for (size_t i = layer1_start; i < layer1_end; i++) {
51
      for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) {
52
53
54
55
56
57
58
59
        // These nodes are all in a layer. We need to remap them to the node id
        // local to the layer.
        idx_data[num_edges] = remap ? i - layer1_start : i;
        num_edges++;
      }
    }
    CHECK_EQ(num_edges, nnz);
    if (remap) {
60
61
      size_t edge_start = indptr[layer1_start];
      dgl_id_t first_eid = edge_ids[edge_start];
62
63
      dgl_id_t first_vid = layer1_start - layer0_size;
      for (int64_t i = 0; i < nnz; i++) {
64
65
66
        CHECK_GE(indices[edge_start + i], first_vid);
        idx_data[nnz + i] = indices[edge_start + i] - first_vid;
        eid_data[i] = edge_ids[edge_start + i] - first_eid;
67
68
      }
    } else {
69
70
71
72
73
74
      std::copy(
          indices + indptr[layer1_start], indices + indptr[layer1_end],
          idx_data + nnz);
      std::copy(
          edge_ids + indptr[layer1_start], edge_ids + indptr[layer1_end],
          eid_data);
75
76
77
78
    }
    return std::vector<IdArray>{idx, eid};
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format";
79
    return {};
80
81
82
  }
}

83
/*!
 * \brief Packed-function entry point wrapping GetNodeFlowSlice.
 *
 * Packed arguments, in order:
 *   [0] GraphRef  graph (must hold an ImmutableGraph),
 *   [1] string    format ("csr" or "coo"),
 *   [2] int64     layer0_size,
 *   [3] int64     layer1_start,
 *   [4] int64     layer1_end,
 *   [5] bool      remap.
 * The resulting vector of IdArrays is returned through
 * ConvertNDArrayVectorToPackedFunc.
 */
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef g = args[0];
      std::string format = args[1];
      int64_t layer0_size = args[2];
      int64_t start = args[3];
      int64_t end = args[4];
      const bool remap = args[5];
      // GetNodeFlowSlice takes size_t; reject negative values here instead of
      // letting the implicit conversion wrap them into huge node ids.
      CHECK_GE(layer0_size, 0) << "layer0_size must be non-negative";
      CHECK_GE(start, 0) << "start must be non-negative";
      CHECK_GE(end, 0) << "end must be non-negative";
      auto ig =
          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
      auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
      *rv = ConvertNDArrayVectorToPackedFunc(res);
    });
96

97
}  // namespace dgl