"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b7361cc4441d7843d4799da4bf78c3654a39422e"
graph_traversal.cc 3.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/traversal.cc
 * \brief Graph traversal implementation
 */
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"

using namespace dgl::runtime;

namespace dgl {
namespace traverse {

DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: (graph, source node id array, reversed).
    // Returns a packed list [ids, sections] built from the frontiers produced by
    // aten::BFSNodesFrontiers (sections presumably holds per-frontier sizes --
    // confirm against the aten implementation).
    HeteroGraphRef g = args[0];
    const IdArray src = args[1];
    const bool reversed = args[2];
    // Validate the seed array the same way _CAPI_DGLDFSEdges_v2 does, so a
    // malformed input fails here with a clear message instead of being handed
    // unchecked to the traversal kernel.
    CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
    // Only edge type 0 is read; a reversed traversal uses the CSC view.
    aten::CSRMatrix csr;
    if (reversed) {
      csr = g.sptr()->GetCSCMatrix(0);
    } else {
      csr = g.sptr()->GetCSRMatrix(0);
    }
    const auto& front = aten::BFSNodesFrontiers(csr, src);
    *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
  });

DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: (graph, source node id array, reversed).
    // Returns a packed list [ids, sections] built from the frontiers produced by
    // aten::BFSEdgesFrontiers (ids presumably are edge ids here -- confirm
    // against the aten implementation).
    HeteroGraphRef g = args[0];
    const IdArray src = args[1];
    const bool reversed = args[2];
    // Validate the seed array the same way _CAPI_DGLDFSEdges_v2 does, so a
    // malformed input fails here with a clear message instead of being handed
    // unchecked to the traversal kernel.
    CHECK(aten::IsValidIdArray(src)) << "Invalid source node id array.";
    // Only edge type 0 is read; a reversed traversal uses the CSC view.
    aten::CSRMatrix csr;
    if (reversed) {
      csr = g.sptr()->GetCSCMatrix(0);
    } else {
      csr = g.sptr()->GetCSRMatrix(0);
    }

    const auto& front = aten::BFSEdgesFrontiers(csr, src);
    *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
  });

DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: (graph, reversed). There are no seeds: the whole graph
    // (edge type 0) is ordered.
    HeteroGraphRef graph = args[0];
    const bool backwards = args[1];
    // Pick the CSC view for a reversed ordering, the CSR view otherwise.
    const aten::CSRMatrix adj = backwards
        ? graph.sptr()->GetCSCMatrix(0)
        : graph.sptr()->GetCSRMatrix(0);
    // Hand back [ids, sections] from the computed frontiers.
    const auto frontiers = aten::TopologicalNodesFrontiers(adj);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });


DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: (graph, source node id array, reversed).
    HeteroGraphRef graph = args[0];
    const IdArray roots = args[1];
    const bool backwards = args[2];
    // Reject malformed seed arrays before touching the graph structure.
    CHECK(aten::IsValidIdArray(roots)) << "Invalid source node id array.";
    // Edge type 0 only; the CSC view is used when traversing in reverse.
    const aten::CSRMatrix adj = backwards
        ? graph.sptr()->GetCSCMatrix(0)
        : graph.sptr()->GetCSRMatrix(0);
    // Hand back [ids, sections] from the DFS result.
    const auto frontiers = aten::DGLDFSEdges(adj, roots);
    *rv = ConvertNDArrayVectorToPackedFunc({frontiers.ids, frontiers.sections});
  });

DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Arguments: (graph, source node id array, reversed, has_reverse_edge,
    //             has_nontree_edge, return_labels).
    // Returns [ids, tags, sections] when return_labels is true, otherwise
    // [ids, sections].
    HeteroGraphRef g = args[0];
    const IdArray source = args[1];
    const bool reversed = args[2];
    const bool has_reverse_edge = args[3];
    const bool has_nontree_edge = args[4];
    const bool return_labels = args[5];
    // Validate the seed array the same way _CAPI_DGLDFSEdges_v2 does, so a
    // malformed input fails here with a clear message instead of being handed
    // unchecked to the traversal kernel.
    CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
    // Only edge type 0 is read; a reversed traversal uses the CSC view.
    aten::CSRMatrix csr;
    if (reversed) {
      csr = g.sptr()->GetCSCMatrix(0);
    } else {
      csr = g.sptr()->GetCSRMatrix(0);
    }

    const auto& front = aten::DGLDFSLabeledEdges(csr,
                                                 source,
                                                 has_reverse_edge,
                                                 has_nontree_edge,
                                                 return_labels);

    // The tags array is only included in the result when labels were requested.
    if (return_labels) {
      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.tags, front.sections});
    } else {
      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
    }
  });

}  // namespace traverse
}  // namespace dgl