graph_traversal.cc 3.2 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/traversal.cc
 * \brief Graph traversal implementation
 */
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
8

9
10
11
12
13
14
15
16
#include "../c_api_common.h"

using namespace dgl::runtime;

namespace dgl {
namespace traverse {

/*!
 * \brief Breadth-first traversal returning visited nodes grouped by frontier.
 *
 * Packed-function arguments:
 *   args[0] HeteroGraphRef  graph to traverse
 *   args[1] IdArray         source node ids
 *   args[2] bool            if true, traverse using the graph's CSC matrix
 *                           (presumably following edges in reverse) instead
 *                           of its CSR matrix
 *
 * Returns a packed list {ids, sections} taken from the result of
 * aten::BFSNodesFrontiers. NOTE(review): `sections` presumably holds the size
 * of each frontier so callers can split `ids` per BFS level — confirm against
 * aten::BFSNodesFrontiers.
 */
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef g = args[0];
      const IdArray src = args[1];
      const bool reversed = args[2];
      // Reject malformed source arrays up front, as _CAPI_DGLDFSEdges_v2 does.
      CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
      // Only the matrix for edge type 0 is read.
      // NOTE(review): presumably this assumes a single-relation graph.
      aten::CSRMatrix csr;
      if (reversed) {
        csr = g.sptr()->GetCSCMatrix(0);
      } else {
        csr = g.sptr()->GetCSRMatrix(0);
      }
      const auto& front = aten::BFSNodesFrontiers(csr, src);
      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
    });
30
31

/*!
 * \brief Breadth-first traversal returning traversed edges grouped by frontier.
 *
 * Packed-function arguments:
 *   args[0] HeteroGraphRef  graph to traverse
 *   args[1] IdArray         source node ids
 *   args[2] bool            if true, traverse using the graph's CSC matrix
 *                           (presumably following edges in reverse) instead
 *                           of its CSR matrix
 *
 * Returns a packed list {ids, sections} taken from the result of
 * aten::BFSEdgesFrontiers. NOTE(review): `ids` are presumably edge ids and
 * `sections` the per-frontier counts — confirm against aten::BFSEdgesFrontiers.
 */
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef g = args[0];
      const IdArray src = args[1];
      const bool reversed = args[2];
      // Reject malformed source arrays up front, as _CAPI_DGLDFSEdges_v2 does.
      CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
      // Only the matrix for edge type 0 is read.
      // NOTE(review): presumably this assumes a single-relation graph.
      aten::CSRMatrix csr;
      if (reversed) {
        csr = g.sptr()->GetCSCMatrix(0);
      } else {
        csr = g.sptr()->GetCSRMatrix(0);
      }
      const auto& front = aten::BFSEdgesFrontiers(csr, src);
      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
    });
46
47

// Topological-order traversal over the whole graph (no source argument).
//   args[0]  HeteroGraphRef  graph to traverse
//   args[1]  bool            true  -> walk the graph's CSC matrix,
//                            false -> walk its CSR matrix
// The {ids, sections} pair produced by aten::TopologicalNodesFrontiers is
// handed back to the caller as a packed function.
// NOTE(review): only edge type 0 is read; presumably a single-relation graph.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      bool use_csc = args[1];
      aten::CSRMatrix adj = use_csc ? graph.sptr()->GetCSCMatrix(0)
                                    : graph.sptr()->GetCSRMatrix(0);
      const auto& frontiers = aten::TopologicalNodesFrontiers(adj);
      *rv = ConvertNDArrayVectorToPackedFunc(
          {frontiers.ids, frontiers.sections});
    });
61
62

// Depth-first traversal returning the traversed edges.
//   args[0]  HeteroGraphRef  graph to traverse
//   args[1]  IdArray         source node ids (validated with IsValidIdArray)
//   args[2]  bool            true  -> walk the graph's CSC matrix,
//                            false -> walk its CSR matrix
// The {ids, sections} pair produced by aten::DGLDFSEdges is handed back to
// the caller as a packed function.
// NOTE(review): only edge type 0 is read; presumably a single-relation graph.
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const IdArray roots = args[1];
      const bool use_csc = args[2];
      CHECK(aten::IsValidIdArray(roots)) << "Invalid source node id array.";
      aten::CSRMatrix adj = use_csc ? graph.sptr()->GetCSCMatrix(0)
                                    : graph.sptr()->GetCSRMatrix(0);
      const auto& result = aten::DGLDFSEdges(adj, roots);
      *rv = ConvertNDArrayVectorToPackedFunc({result.ids, result.sections});
    });
77
78

/*!
 * \brief Depth-first traversal returning traversed edges, optionally labeled.
 *
 * Packed-function arguments:
 *   args[0] HeteroGraphRef  graph to traverse
 *   args[1] IdArray         source node ids
 *   args[2] bool            if true, traverse using the graph's CSC matrix
 *                           instead of its CSR matrix
 *   args[3] bool            has_reverse_edge, forwarded to
 *                           aten::DGLDFSLabeledEdges
 *   args[4] bool            has_nontree_edge, forwarded to
 *                           aten::DGLDFSLabeledEdges
 *   args[5] bool            return_labels: when true, `tags` is included
 *                           in the output
 *
 * Returns a packed list {ids, tags, sections} when return_labels is true,
 * otherwise {ids, sections}.
 */
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      const bool has_reverse_edge = args[3];
      const bool has_nontree_edge = args[4];
      const bool return_labels = args[5];
      // Reject malformed source arrays up front, as _CAPI_DGLDFSEdges_v2 does.
      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      // Only the matrix for edge type 0 is read.
      // NOTE(review): presumably this assumes a single-relation graph.
      aten::CSRMatrix csr;
      if (reversed) {
        csr = g.sptr()->GetCSCMatrix(0);
      } else {
        csr = g.sptr()->GetCSRMatrix(0);
      }
      const auto& front = aten::DGLDFSLabeledEdges(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
      // `tags` is only exposed to the caller when labels were requested.
      if (return_labels) {
        *rv = ConvertNDArrayVectorToPackedFunc(
            {front.ids, front.tags, front.sections});
      } else {
        *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
      }
    });
103
104
105

}  // namespace traverse
}  // namespace dgl