Unverified Commit 8f0df39e authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4810)



* [Misc] clang-format auto fix.

* manual

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 401e1278
...@@ -4,15 +4,17 @@ ...@@ -4,15 +4,17 @@
* \brief Convert multigraphs to simple graphs * \brief Convert multigraphs to simple graphs
*/ */
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <vector> #include <dgl/transform.h>
#include <utility> #include <utility>
#include <vector>
#include "../../c_api_common.h"
#include "../heterograph.h" #include "../heterograph.h"
#include "../unit_graph.h" #include "../unit_graph.h"
#include "../../c_api_common.h"
namespace dgl { namespace dgl {
...@@ -25,7 +27,8 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ...@@ -25,7 +27,8 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToSimpleGraph(const HeteroGraphPtr graph) { ToSimpleGraph(const HeteroGraphPtr graph) {
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
const auto metagraph = graph->meta_graph(); const auto metagraph = graph->meta_graph();
const auto &ugs = std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs(); const auto &ugs =
std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes); std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);
std::vector<HeteroGraphPtr> rel_graphs(num_etypes); std::vector<HeteroGraphPtr> rel_graphs(num_etypes);
...@@ -35,14 +38,14 @@ ToSimpleGraph(const HeteroGraphPtr graph) { ...@@ -35,14 +38,14 @@ ToSimpleGraph(const HeteroGraphPtr graph) {
std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result; std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;
} }
const HeteroGraphPtr result = CreateHeteroGraph( const HeteroGraphPtr result =
metagraph, rel_graphs, graph->NumVerticesPerType()); CreateHeteroGraph(metagraph, rel_graphs, graph->NumVerticesPerType());
return std::make_tuple(result, counts, edge_maps); return std::make_tuple(result, counts, edge_maps);
} }
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero") DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0]; const HeteroGraphRef graph_ref = args[0];
const auto result = ToSimpleGraph(graph_ref.sptr()); const auto result = ToSimpleGraph(graph_ref.sptr());
......
...@@ -10,7 +10,8 @@ namespace dgl { ...@@ -10,7 +10,8 @@ namespace dgl {
HeteroGraphPtr JointUnionHeteroGraph( HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list has at least two graphs"; CHECK_GT(component_graphs.size(), 0)
<< "Input graph list has at least two graphs";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges()); std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0); std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
...@@ -24,18 +25,21 @@ HeteroGraphPtr JointUnionHeteroGraph( ...@@ -24,18 +25,21 @@ HeteroGraphPtr JointUnionHeteroGraph(
HeteroGraphPtr rgptr = nullptr; HeteroGraphPtr rgptr = nullptr;
// ALL = CSC | CSR | COO // ALL = CSC | CSR | COO
const dgl_format_code_t code =\ const dgl_format_code_t code =
component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats(); component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
// get common format // get common format
for (size_t i = 0; i < component_graphs.size(); ++i) { for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i]; const auto& cg = component_graphs[i];
CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype)) << "Input graph[" << i << CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype))
"] should have same number of src vertices as input graph[0]"; << "Input graph[" << i
CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype)) << "Input graph[" << i << << "] should have same number of src vertices as input graph[0]";
"] should have same number of dst vertices as input graph[0]"; CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype))
<< "Input graph[" << i
const dgl_format_code_t curr_code = cg->GetRelationGraph(etype)->GetAllowedFormats(); << "] should have same number of dst vertices as input graph[0]";
const dgl_format_code_t curr_code =
cg->GetRelationGraph(etype)->GetAllowedFormats();
if (curr_code != code) if (curr_code != code)
LOG(FATAL) << "All components should have the same formats"; LOG(FATAL) << "All components should have the same formats";
} }
...@@ -50,8 +54,8 @@ HeteroGraphPtr JointUnionHeteroGraph( ...@@ -50,8 +54,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
} }
aten::COOMatrix res = aten::UnionCoo(coos); aten::COOMatrix res = aten::UnionCoo(coos);
rgptr = UnitGraph::CreateFromCOO( rgptr =
(src_vtype == dst_vtype) ? 1 : 2, res, code); UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) { } else if (FORMAT_HAS_CSR(code)) {
std::vector<aten::CSRMatrix> csrs; std::vector<aten::CSRMatrix> csrs;
for (size_t i = 0; i < component_graphs.size(); ++i) { for (size_t i = 0; i < component_graphs.size(); ++i) {
...@@ -61,8 +65,8 @@ HeteroGraphPtr JointUnionHeteroGraph( ...@@ -61,8 +65,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
} }
aten::CSRMatrix res = aten::UnionCsr(csrs); aten::CSRMatrix res = aten::UnionCsr(csrs);
rgptr = UnitGraph::CreateFromCSR( rgptr =
(src_vtype == dst_vtype) ? 1 : 2, res, code); UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSC(code)) { } else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix // CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs; std::vector<aten::CSRMatrix> cscs;
...@@ -73,8 +77,8 @@ HeteroGraphPtr JointUnionHeteroGraph( ...@@ -73,8 +77,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
} }
aten::CSRMatrix res = aten::UnionCsr(cscs); aten::CSRMatrix res = aten::UnionCsr(cscs);
rgptr = UnitGraph::CreateFromCSC( rgptr =
(src_vtype == dst_vtype) ? 1 : 2, res, code); UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
} }
rel_graphs[etype] = rgptr; rel_graphs[etype] = rgptr;
...@@ -82,7 +86,8 @@ HeteroGraphPtr JointUnionHeteroGraph( ...@@ -82,7 +86,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
num_nodes_per_type[dst_vtype] = num_dst_v; num_nodes_per_type[dst_vtype] = num_dst_v;
} }
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type)); return CreateHeteroGraph(
meta_graph, rel_graphs, std::move(num_nodes_per_type));
} }
HeteroGraphPtr DisjointUnionHeteroGraph2( HeteroGraphPtr DisjointUnionHeteroGraph2(
...@@ -94,8 +99,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -94,8 +99,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
// Loop over all ntypes // Loop over all ntypes
for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) { for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
uint64_t offset = 0; uint64_t offset = 0;
for (const auto &cg : component_graphs) for (const auto& cg : component_graphs) offset += cg->NumVertices(vtype);
offset += cg->NumVertices(vtype);
num_nodes_per_type[vtype] = offset; num_nodes_per_type[vtype] = offset;
} }
...@@ -106,11 +110,12 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -106,11 +110,12 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
HeteroGraphPtr rgptr = nullptr; HeteroGraphPtr rgptr = nullptr;
const dgl_format_code_t code =\ const dgl_format_code_t code =
component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats(); component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
// do some preprocess // do some preprocess
for (const auto &cg : component_graphs) { for (const auto& cg : component_graphs) {
const dgl_format_code_t cur_code = cg->GetRelationGraph(etype)->GetAllowedFormats(); const dgl_format_code_t cur_code =
cg->GetRelationGraph(etype)->GetAllowedFormats();
if (cur_code != code) if (cur_code != code)
LOG(FATAL) << "All components should have the same formats"; LOG(FATAL) << "All components should have the same formats";
} }
...@@ -118,51 +123,55 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -118,51 +123,55 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
// prefer COO // prefer COO
if (FORMAT_HAS_COO(code)) { if (FORMAT_HAS_COO(code)) {
std::vector<aten::COOMatrix> coos; std::vector<aten::COOMatrix> coos;
for (const auto &cg : component_graphs) { for (const auto& cg : component_graphs) {
aten::COOMatrix coo = cg->GetCOOMatrix(etype); aten::COOMatrix coo = cg->GetCOOMatrix(etype);
coos.push_back(coo); coos.push_back(coo);
} }
aten::COOMatrix res = aten::DisjointUnionCoo(coos); aten::COOMatrix res = aten::DisjointUnionCoo(coos);
rgptr = UnitGraph::CreateFromCOO( rgptr =
(src_vtype == dst_vtype) ? 1 : 2, res, code); UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) { } else if (FORMAT_HAS_CSR(code)) {
std::vector<aten::CSRMatrix> csrs; std::vector<aten::CSRMatrix> csrs;
for (const auto &cg : component_graphs) { for (const auto& cg : component_graphs) {
aten::CSRMatrix csr = cg->GetCSRMatrix(etype); aten::CSRMatrix csr = cg->GetCSRMatrix(etype);
csrs.push_back(csr); csrs.push_back(csr);
} }
aten::CSRMatrix res = aten::DisjointUnionCsr(csrs); aten::CSRMatrix res = aten::DisjointUnionCsr(csrs);
rgptr = UnitGraph::CreateFromCSR( rgptr =
(src_vtype == dst_vtype) ? 1 : 2, res, code); UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSC(code)) { } else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix // CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs; std::vector<aten::CSRMatrix> cscs;
for (const auto &cg : component_graphs) { for (const auto& cg : component_graphs) {
aten::CSRMatrix csc = cg->GetCSCMatrix(etype); aten::CSRMatrix csc = cg->GetCSCMatrix(etype);
cscs.push_back(csc); cscs.push_back(csc);
} }
aten::CSRMatrix res = aten::DisjointUnionCsr(cscs); aten::CSRMatrix res = aten::DisjointUnionCsr(cscs);
rgptr = UnitGraph::CreateFromCSC( rgptr =
(src_vtype == dst_vtype) ? 1 : 2, res, code); UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
} }
rel_graphs[etype] = rgptr; rel_graphs[etype] = rgptr;
} }
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type)); return CreateHeteroGraph(
meta_graph, rel_graphs, std::move(num_nodes_per_type));
} }
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) { GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes) {
// Sanity check for vertex sizes // Sanity check for vertex sizes
CHECK_EQ(vertex_sizes->dtype.bits, 64) << "dtype of vertex_sizes should be int64"; CHECK_EQ(vertex_sizes->dtype.bits, 64)
<< "dtype of vertex_sizes should be int64";
CHECK_EQ(edge_sizes->dtype.bits, 64) << "dtype of edge_sizes should be int64"; CHECK_EQ(edge_sizes->dtype.bits, 64) << "dtype of edge_sizes should be int64";
const uint64_t len_vertex_sizes = vertex_sizes->shape[0]; const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data); const uint64_t* vertex_sizes_data =
static_cast<uint64_t*>(vertex_sizes->data);
const uint64_t num_vertex_types = meta_graph->NumVertices(); const uint64_t num_vertex_types = meta_graph->NumVertices();
const uint64_t batch_size = len_vertex_sizes / num_vertex_types; const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
...@@ -177,8 +186,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -177,8 +186,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
vertex_cumsum[vtype].push_back( vertex_cumsum[vtype].push_back(
vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]); vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
} }
CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype)) CHECK_EQ(
<< "Sum of the given sizes must equal to the number of nodes for type " << vtype; vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
<< "Sum of the given sizes must equal to the number of nodes for type "
<< vtype;
} }
// Sanity check for edge sizes // Sanity check for edge sizes
...@@ -196,7 +207,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -196,7 +207,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]); edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
} }
CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype)) CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
<< "Sum of the given sizes must equal to the number of edges for type " << etype; << "Sum of the given sizes must equal to the number of edges for type "
<< etype;
} }
// Construct relation graphs for unbatched graphs // Construct relation graphs for unbatched graphs
...@@ -211,10 +223,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -211,10 +223,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype); aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
auto res = aten::DisjointPartitionCooBySizes(coo, auto res = aten::DisjointPartitionCooBySizes(
batch_size, coo, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
edge_cumsum[etype],
vertex_cumsum[src_vtype],
vertex_cumsum[dst_vtype]); vertex_cumsum[dst_vtype]);
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO( HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
...@@ -228,10 +238,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -228,10 +238,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype); aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
auto res = aten::DisjointPartitionCsrBySizes(csr, auto res = aten::DisjointPartitionCsrBySizes(
batch_size, csr, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
edge_cumsum[etype],
vertex_cumsum[src_vtype],
vertex_cumsum[dst_vtype]); vertex_cumsum[dst_vtype]);
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR( HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR(
...@@ -246,10 +254,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -246,10 +254,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
// CSR and CSC have the same storage format, i.e. CSRMatrix // CSR and CSC have the same storage format, i.e. CSRMatrix
aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype); aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
auto res = aten::DisjointPartitionCsrBySizes(csc, auto res = aten::DisjointPartitionCsrBySizes(
batch_size, csc, batch_size, edge_cumsum[etype], vertex_cumsum[dst_vtype],
edge_cumsum[etype],
vertex_cumsum[dst_vtype],
vertex_cumsum[src_vtype]); vertex_cumsum[src_vtype]);
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC( HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC(
...@@ -264,20 +270,26 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -264,20 +270,26 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
for (uint64_t i = 0; i < num_vertex_types; ++i) for (uint64_t i = 0; i < num_vertex_types; ++i)
num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g]; num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type)); rst.push_back(
CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
} }
return rst; return rst;
} }
HeteroGraphPtr SliceHeteroGraph( HeteroGraphPtr SliceHeteroGraph(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray num_nodes_per_type, GraphPtr meta_graph, HeteroGraphPtr batched_graph,
IdArray start_nid_per_type, IdArray num_edges_per_type, IdArray start_eid_per_type) { IdArray num_nodes_per_type, IdArray start_nid_per_type,
IdArray num_edges_per_type, IdArray start_eid_per_type) {
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges()); std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
const uint64_t* start_nid_per_type_data = static_cast<uint64_t*>(start_nid_per_type->data); const uint64_t* start_nid_per_type_data =
const uint64_t* num_nodes_per_type_data = static_cast<uint64_t*>(num_nodes_per_type->data); static_cast<uint64_t*>(start_nid_per_type->data);
const uint64_t* start_eid_per_type_data = static_cast<uint64_t*>(start_eid_per_type->data); const uint64_t* num_nodes_per_type_data =
const uint64_t* num_edges_per_type_data = static_cast<uint64_t*>(num_edges_per_type->data); static_cast<uint64_t*>(num_nodes_per_type->data);
const uint64_t* start_eid_per_type_data =
static_cast<uint64_t*>(start_eid_per_type->data);
const uint64_t* num_edges_per_type_data =
static_cast<uint64_t*>(num_edges_per_type->data);
// Map vertex type to the corresponding node range // Map vertex type to the corresponding node range
const uint64_t num_vertex_types = meta_graph->NumVertices(); const uint64_t num_vertex_types = meta_graph->NumVertices();
...@@ -296,50 +308,52 @@ HeteroGraphPtr SliceHeteroGraph( ...@@ -296,50 +308,52 @@ HeteroGraphPtr SliceHeteroGraph(
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
HeteroGraphPtr rgptr = nullptr; HeteroGraphPtr rgptr = nullptr;
const dgl_format_code_t code = batched_graph->GetRelationGraph(etype)->GetAllowedFormats(); const dgl_format_code_t code =
batched_graph->GetRelationGraph(etype)->GetAllowedFormats();
// handle graph without edges // handle graph without edges
std::vector<uint64_t> edge_range; std::vector<uint64_t> edge_range;
edge_range.push_back(start_eid_per_type_data[etype]); edge_range.push_back(start_eid_per_type_data[etype]);
edge_range.push_back(start_eid_per_type_data[etype] + num_edges_per_type_data[etype]); edge_range.push_back(
start_eid_per_type_data[etype] + num_edges_per_type_data[etype]);
// prefer COO // prefer COO
if (FORMAT_HAS_COO(code)) { if (FORMAT_HAS_COO(code)) {
aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype); aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
aten::COOMatrix res = aten::COOSliceContiguousChunk(coo, aten::COOMatrix res = aten::COOSliceContiguousChunk(
edge_range, coo, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
vertex_range[src_vtype], rgptr =
vertex_range[dst_vtype]); UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
rgptr = UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) { } else if (FORMAT_HAS_CSR(code)) {
aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype); aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
aten::CSRMatrix res = aten::CSRSliceContiguousChunk(csr, aten::CSRMatrix res = aten::CSRSliceContiguousChunk(
edge_range, csr, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
vertex_range[src_vtype], rgptr =
vertex_range[dst_vtype]); UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
rgptr = UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSC(code)) { } else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix // CSR and CSC have the same storage format, i.e. CSRMatrix
aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype); aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
aten::CSRMatrix res = aten::CSRSliceContiguousChunk(csc, aten::CSRMatrix res = aten::CSRSliceContiguousChunk(
edge_range, csc, edge_range, vertex_range[dst_vtype], vertex_range[src_vtype]);
vertex_range[dst_vtype], rgptr =
vertex_range[src_vtype]); UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
rgptr = UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
} }
rel_graphs[etype] = rgptr; rel_graphs[etype] = rgptr;
} }
return CreateHeteroGraph(meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>()); return CreateHeteroGraph(
meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());
} }
template <class IdType> template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) { GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes) {
// Sanity check for vertex sizes // Sanity check for vertex sizes
const uint64_t len_vertex_sizes = vertex_sizes->shape[0]; const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data); const uint64_t* vertex_sizes_data =
static_cast<uint64_t*>(vertex_sizes->data);
const uint64_t num_vertex_types = meta_graph->NumVertices(); const uint64_t num_vertex_types = meta_graph->NumVertices();
const uint64_t batch_size = len_vertex_sizes / num_vertex_types; const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
// Map vertex type to the corresponding node cum sum // Map vertex type to the corresponding node cum sum
...@@ -353,8 +367,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -353,8 +367,10 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
vertex_cumsum[vtype].push_back( vertex_cumsum[vtype].push_back(
vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]); vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
} }
CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype)) CHECK_EQ(
<< "Sum of the given sizes must equal to the number of nodes for type " << vtype; vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
<< "Sum of the given sizes must equal to the number of nodes for type "
<< vtype;
} }
// Sanity check for edge sizes // Sanity check for edge sizes
...@@ -372,7 +388,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -372,7 +388,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]); edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
} }
CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype)) CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
<< "Sum of the given sizes must equal to the number of edges for type " << etype; << "Sum of the given sizes must equal to the number of edges for type "
<< etype;
} }
// Construct relation graphs for unbatched graphs // Construct relation graphs for unbatched graphs
...@@ -390,7 +407,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -390,7 +407,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
std::vector<IdType> result_src, result_dst; std::vector<IdType> result_src, result_dst;
// Loop over the chunk of edges for the specified graph and edge type // Loop over the chunk of edges for the specified graph and edge type
for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) { for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1];
++e) {
// TODO(mufei): Should use array operations to implement this. // TODO(mufei): Should use array operations to implement this.
result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]); result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);
result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]); result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
...@@ -410,15 +428,18 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -410,15 +428,18 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
for (uint64_t g = 0; g < batch_size; ++g) { for (uint64_t g = 0; g < batch_size; ++g) {
for (uint64_t i = 0; i < num_vertex_types; ++i) for (uint64_t i = 0; i < num_vertex_types; ++i)
num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g]; num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type)); rst.push_back(
CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
} }
return rst; return rst;
} }
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>( template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes); GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes);
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>( template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes); GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes);
} // namespace dgl } // namespace dgl
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
* \file graph/traversal.cc * \file graph/traversal.cc
* \brief Graph traversal implementation * \brief Graph traversal implementation
*/ */
#include "./traversal.h"
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <algorithm> #include <algorithm>
#include <queue> #include <queue>
#include "./traversal.h"
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -15,46 +18,36 @@ namespace dgl { ...@@ -15,46 +18,36 @@ namespace dgl {
namespace traverse { namespace traverse {
namespace { namespace {
// A utility view class to wrap a vector into a queue. // A utility view class to wrap a vector into a queue.
template<typename DType> template <typename DType>
struct VectorQueueWrapper { struct VectorQueueWrapper {
std::vector<DType>* vec; std::vector<DType>* vec;
size_t head = 0; size_t head = 0;
explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {} explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}
void push(const DType& elem) { void push(const DType& elem) { vec->push_back(elem); }
vec->push_back(elem);
}
DType top() const { DType top() const { return vec->operator[](head); }
return vec->operator[](head);
}
void pop() { void pop() { ++head; }
++head;
}
bool empty() const { bool empty() const { return head == vec->size(); }
return head == vec->size();
}
size_t size() const { size_t size() const { return vec->size() - head; }
return vec->size() - head;
}
}; };
// Internal function to merge multiple traversal traces into one ndarray. // Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together. // It is similar to zip the vectors together.
template<typename DType> template <typename DType>
IdArray MergeMultipleTraversals( IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0, total_len = 0; int64_t max_len = 0, total_len = 0;
for (size_t i = 0; i < traces.size(); ++i) { for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
total_len += traces[i].size(); total_len += traces[i].size();
} }
IdArray ret = IdArray::Empty({total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty(
{total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) { for (size_t j = 0; j < traces.size(); ++j) {
...@@ -70,15 +63,15 @@ IdArray MergeMultipleTraversals( ...@@ -70,15 +63,15 @@ IdArray MergeMultipleTraversals(
// Internal function to compute sections if multiple traversal traces // Internal function to compute sections if multiple traversal traces
// are merged into one ndarray. // are merged into one ndarray.
template<typename DType> template <typename DType>
IdArray ComputeMergedSections( IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0; int64_t max_len = 0;
for (size_t i = 0; i < traces.size(); ++i) { for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
} }
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty(
{max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0; int64_t sec_len = 0;
...@@ -99,7 +92,8 @@ IdArray ComputeMergedSections( ...@@ -99,7 +92,8 @@ IdArray ComputeMergedSections(
* \brief Class for representing frontiers. * \brief Class for representing frontiers.
* *
* Each frontier is a list of nodes/edges (specified by their ids). * Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int value). * An optional tag can be specified on each node/edge (represented by an int
* value).
*/ */
struct Frontiers { struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */ /*!\brief a vector store for the nodes/edges in all the frontiers */
...@@ -112,11 +106,12 @@ struct Frontiers { ...@@ -112,11 +106,12 @@ struct Frontiers {
std::vector<int64_t> sections; std::vector<int64_t> sections;
}; };
Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) { Frontiers BFSNodesFrontiers(
const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front; Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids); VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { }; auto visit = [&](const dgl_id_t v) {};
auto make_frontier = [&] () { auto make_frontier = [&]() {
if (!queue.empty()) { if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
front.sections.push_back(queue.size()); front.sections.push_back(queue.size());
...@@ -127,7 +122,7 @@ Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool re ...@@ -127,7 +122,7 @@ Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool re
} }
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
...@@ -137,12 +132,13 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes") ...@@ -137,12 +132,13 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) { Frontiers BFSEdgesFrontiers(
const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front; Frontiers front;
// NOTE: std::queue has no top() method. // NOTE: std::queue has no top() method.
std::vector<dgl_id_t> nodes; std::vector<dgl_id_t> nodes;
VectorQueueWrapper<dgl_id_t> queue(&nodes); VectorQueueWrapper<dgl_id_t> queue(&nodes);
auto visit = [&] (const dgl_id_t e) { front.ids.push_back(e); }; auto visit = [&](const dgl_id_t e) { front.ids.push_back(e); };
bool first_frontier = true; bool first_frontier = true;
auto make_frontier = [&] { auto make_frontier = [&] {
if (first_frontier) { if (first_frontier) {
...@@ -157,7 +153,7 @@ Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool re ...@@ -157,7 +153,7 @@ Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool re
} }
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
...@@ -167,11 +163,12 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges") ...@@ -167,11 +163,12 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
}); });
Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) { Frontiers TopologicalNodesFrontiers(
const GraphInterface& graph, bool reversed) {
Frontiers front; Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids); VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { }; auto visit = [&](const dgl_id_t v) {};
auto make_frontier = [&] () { auto make_frontier = [&]() {
if (!queue.empty()) { if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
front.sections.push_back(queue.size()); front.sections.push_back(queue.size());
...@@ -182,7 +179,7 @@ Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) ...@@ -182,7 +179,7 @@ Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed)
} }
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
bool reversed = args[1]; bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed); const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
...@@ -191,9 +188,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") ...@@ -191,9 +188,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
...@@ -202,7 +198,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -202,7 +198,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len); std::vector<std::vector<dgl_id_t>> edges(len);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); }; auto visit = [&](dgl_id_t e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit); DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
} }
IdArray ids = MergeMultipleTraversals(edges); IdArray ids = MergeMultipleTraversals(edges);
...@@ -211,7 +207,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") ...@@ -211,7 +207,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
...@@ -229,14 +225,15 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") ...@@ -229,14 +225,15 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
tags.resize(len); tags.resize(len);
} }
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { auto visit = [&](dgl_id_t e, int tag) {
edges[i].push_back(e); edges[i].push_back(e);
if (return_labels) { if (return_labels) {
tags[i].push_back(tag); tags[i].push_back(tag);
} }
}; };
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, DFSLabeledEdges(
has_reverse_edge, has_nontree_edge, visit); *g.sptr(), src_data[i], reversed, has_reverse_edge,
has_nontree_edge, visit);
} }
IdArray ids = MergeMultipleTraversals(edges); IdArray ids = MergeMultipleTraversals(edges);
......
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
* \file graph/traversal.h * \file graph/traversal.h
* \brief Graph traversal routines. * \brief Graph traversal routines.
* *
* Traversal routines generate frontiers. Frontiers can be node frontiers or edge * Traversal routines generate frontiers. Frontiers can be node frontiers or
* frontiers depending on the traversal function. Each frontier is a * edge frontiers depending on the traversal function. Each frontier is a list
* list of nodes/edges (specified by their ids). An optional tag can be specified * of nodes/edges (specified by their ids). An optional tag can be specified for
* for each node/edge (represented by an int value). * each node/edge (represented by an int value).
*/ */
#ifndef DGL_GRAPH_TRAVERSAL_H_ #ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_ #define DGL_GRAPH_TRAVERSAL_H_
#include <dgl/graph_interface.h> #include <dgl/graph_interface.h>
#include <stack> #include <stack>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -39,18 +40,16 @@ namespace traverse { ...@@ -39,18 +40,16 @@ namespace traverse {
* *
* \param graph The graph. * \param graph The graph.
* \param sources Source nodes. * \param sources Source nodes.
* \param reversed If true, BFS follows the in-edge direction * \param reversed If true, BFS follows the in-edge direction.
* \param queue The queue used to do bfs. * \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited. * \param visit The function to call when a node is visited.
* \param make_frontier The function to indicate that a new froniter can be made; * \param make_frontier The function to indicate that a new froniter can be
* made.
*/ */
template<typename Queue, typename VisitFn, typename FrontierFn> template <typename Queue, typename VisitFn, typename FrontierFn>
void BFSNodes(const GraphInterface& graph, void BFSNodes(
IdArray source, const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,
bool reversed, VisitFn visit, FrontierFn make_frontier) {
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) {
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
...@@ -63,7 +62,8 @@ void BFSNodes(const GraphInterface& graph, ...@@ -63,7 +62,8 @@ void BFSNodes(const GraphInterface& graph,
} }
make_frontier(); make_frontier();
const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec; const auto neighbor_iter =
reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
while (!queue->empty()) { while (!queue->empty()) {
const size_t size = queue->size(); const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -102,19 +102,17 @@ void BFSNodes(const GraphInterface& graph, ...@@ -102,19 +102,17 @@ void BFSNodes(const GraphInterface& graph,
* *
* \param graph The graph. * \param graph The graph.
* \param sources Source nodes. * \param sources Source nodes.
* \param reversed If true, BFS follows the in-edge direction * \param reversed If true, BFS follows the in-edge direction.
* \param queue The queue used to do bfs. * \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited. * \param visit The function to call when a node is visited.
* The argument would be edge ID. * The argument would be edge ID.
* \param make_frontier The function to indicate that a new frontier can be made; * \param make_frontier The function to indicate that a new frontier can be
* made.
*/ */
template<typename Queue, typename VisitFn, typename FrontierFn> template <typename Queue, typename VisitFn, typename FrontierFn>
void BFSEdges(const GraphInterface& graph, void BFSEdges(
IdArray source, const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,
bool reversed, VisitFn visit, FrontierFn make_frontier) {
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) {
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
...@@ -126,7 +124,8 @@ void BFSEdges(const GraphInterface& graph, ...@@ -126,7 +124,8 @@ void BFSEdges(const GraphInterface& graph,
} }
make_frontier(); make_frontier();
const auto neighbor_iter = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec; const auto neighbor_iter =
reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
while (!queue->empty()) { while (!queue->empty()) {
const size_t size = queue->size(); const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -165,19 +164,20 @@ void BFSEdges(const GraphInterface& graph, ...@@ -165,19 +164,20 @@ void BFSEdges(const GraphInterface& graph,
* void (*make_frontier)(void); * void (*make_frontier)(void);
* *
* \param graph The graph. * \param graph The graph.
* \param reversed If true, follows the in-edge direction * \param reversed If true, follows the in-edge direction.
* \param queue The queue used to do bfs. * \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited. * \param visit The function to call when a node is visited.
* \param make_frontier The function to indicate that a new froniter can be made; * \param make_frontier The function to indicate that a new froniter can be
* made.
*/ */
template<typename Queue, typename VisitFn, typename FrontierFn> template <typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(const GraphInterface& graph, void TopologicalNodes(
bool reversed, const GraphInterface& graph, bool reversed, Queue* queue, VisitFn visit,
Queue* queue,
VisitFn visit,
FrontierFn make_frontier) { FrontierFn make_frontier) {
const auto get_degree = reversed? &GraphInterface::OutDegree : &GraphInterface::InDegree; const auto get_degree =
const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec; reversed ? &GraphInterface::OutDegree : &GraphInterface::InDegree;
const auto neighbor_iter =
reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
uint64_t num_visited_nodes = 0; uint64_t num_visited_nodes = 0;
std::vector<uint64_t> degrees(graph.NumVertices(), 0); std::vector<uint64_t> degrees(graph.NumVertices(), 0);
for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) { for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {
...@@ -207,7 +207,8 @@ void TopologicalNodes(const GraphInterface& graph, ...@@ -207,7 +207,8 @@ void TopologicalNodes(const GraphInterface& graph,
} }
if (num_visited_nodes != graph.NumVertices()) { if (num_visited_nodes != graph.NumVertices()) {
LOG(FATAL) << "Error in topological traversal: loop detected in the given graph."; LOG(FATAL)
<< "Error in topological traversal: loop detected in the given graph.";
} }
} }
...@@ -221,30 +222,28 @@ enum DFSEdgeTag { ...@@ -221,30 +222,28 @@ enum DFSEdgeTag {
* \brief Traverse the graph in a depth-first-search (DFS) order. * \brief Traverse the graph in a depth-first-search (DFS) order.
* *
* The traversal visit edges in its DFS order. Edges have three tags: * The traversal visit edges in its DFS order. Edges have three tags:
* FORWARD(0), REVERSE(1), NONTREE(2) * FORWARD(0), REVERSE(1), NONTREE(2).
* *
* A FORWARD edge is one in which `u` has been visisted but `v` has not. * A FORWARD edge is one in which `u` has been visisted but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the edge * A REVERSE edge is one in which both `u` and `v` have been visisted and the
* is in the DFS tree. * edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have
* A NONTREE edge is one in which both `u` and `v` have been visisted but the edge * been visisted but the edge is NOT in the DFS tree.
* is NOT in the DFS tree.
* *
* \param source Source node. * \param source Source node.
* \param reversed If true, DFS follows the in-edge direction * \param reversed If true, DFS follows the in-edge direction.
* \param has_reverse_edge If true, REVERSE edges are included * \param has_reverse_edge If true, REVERSE edges are included.
* \param has_nontree_edge If true, NONTREE edges are included * \param has_nontree_edge If true, NONTREE edges are included.
* \param visit The function to call when an edge is visited; the edge id and its * \param visit The function to call when an edge is visited; the edge id and
* tag will be given as the arguments. * its tag will be given as the arguments.
*/ */
template<typename VisitFn> template <typename VisitFn>
void DFSLabeledEdges(const GraphInterface& graph, void DFSLabeledEdges(
dgl_id_t source, const GraphInterface& graph, dgl_id_t source, bool reversed,
bool reversed, bool has_reverse_edge, bool has_nontree_edge, VisitFn visit) {
bool has_reverse_edge, const auto succ =
bool has_nontree_edge, reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
VisitFn visit) { const auto out_edge =
const auto succ = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec; reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
const auto out_edge = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
if ((graph.*succ)(source).size() == 0) { if ((graph.*succ)(source).size() == 0) {
// no out-going edges from the source node // no out-going edges from the source node
...@@ -273,7 +272,7 @@ void DFSLabeledEdges(const GraphInterface& graph, ...@@ -273,7 +272,7 @@ void DFSLabeledEdges(const GraphInterface& graph,
stack.pop(); stack.pop();
// find next one. // find next one.
if (i < (graph.*succ)(u).size() - 1) { if (i < (graph.*succ)(u).size() - 1) {
stack.push(std::make_tuple(u, i+1, false)); stack.push(std::make_tuple(u, i + 1, false));
} }
} else { } else {
visited[v] = true; visited[v] = true;
...@@ -291,4 +290,3 @@ void DFSLabeledEdges(const GraphInterface& graph, ...@@ -291,4 +290,3 @@ void DFSLabeledEdges(const GraphInterface& graph,
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_ #endif // DGL_GRAPH_TRAVERSAL_H_
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -17,7 +17,8 @@ namespace network { ...@@ -17,7 +17,8 @@ namespace network {
// In most cases, delim contains only one character. In this case, we // In most cases, delim contains only one character. In this case, we
// use CalculateReserveForVector to count the number of elements should // use CalculateReserveForVector to count the number of elements should
// be reserved in result vector, and thus optimize SplitStringUsing. // be reserved in result vector, and thus optimize SplitStringUsing.
static int CalculateReserveForVector(const std::string& full, const char* delim) { static int CalculateReserveForVector(
const std::string& full, const char* delim) {
int count = 0; int count = 0;
if (delim[0] != '\0' && delim[1] == '\0') { if (delim[0] != '\0' && delim[1] == '\0') {
// Optimize the common case where delim is a single character. // Optimize the common case where delim is a single character.
...@@ -38,19 +39,18 @@ static int CalculateReserveForVector(const std::string& full, const char* delim) ...@@ -38,19 +39,18 @@ static int CalculateReserveForVector(const std::string& full, const char* delim)
return count; return count;
} }
void SplitStringUsing(const std::string& full, void SplitStringUsing(
const char* delim, const std::string& full, const char* delim,
std::vector<std::string>* result) { std::vector<std::string>* result) {
CHECK(delim != NULL); CHECK(delim != NULL);
CHECK(result != NULL); CHECK(result != NULL);
result->reserve(CalculateReserveForVector(full, delim)); result->reserve(CalculateReserveForVector(full, delim));
back_insert_iterator< std::vector<std::string> > it(*result); back_insert_iterator<std::vector<std::string> > it(*result);
SplitStringToIteratorUsing(full, delim, &it); SplitStringToIteratorUsing(full, delim, &it);
} }
void SplitStringToSetUsing(const std::string& full, void SplitStringToSetUsing(
const char* delim, const std::string& full, const char* delim, std::set<std::string>* result) {
std::set<std::string>* result) {
CHECK(delim != NULL); CHECK(delim != NULL);
CHECK(result != NULL); CHECK(result != NULL);
simple_insert_iterator<std::set<std::string> > it(result); simple_insert_iterator<std::set<std::string> > it(result);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment