"vscode:/vscode.git/clone" did not exist on "3033f08201e4ec46d87c37098b066b8b1924b052"
Unverified commit 8f0df39e authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] clang-format auto fix. (#4810)



* [Misc] clang-format auto fix.

* manual

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 401e1278
@@ -4,15 +4,17 @@
  * \brief Convert multigraphs to simple graphs
  */
-#include <dgl/base_heterograph.h>
-#include <dgl/transform.h>
 #include <dgl/array.h>
+#include <dgl/base_heterograph.h>
 #include <dgl/packed_func_ext.h>
-#include <vector>
+#include <dgl/transform.h>
 #include <utility>
+#include <vector>
+
+#include "../../c_api_common.h"
 #include "../heterograph.h"
 #include "../unit_graph.h"
-#include "../../c_api_common.h"
 
 namespace dgl {
@@ -25,7 +27,8 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
 ToSimpleGraph(const HeteroGraphPtr graph) {
   const int64_t num_etypes = graph->NumEdgeTypes();
   const auto metagraph = graph->meta_graph();
-  const auto &ugs = std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
+  const auto &ugs =
+      std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
   std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);
   std::vector<HeteroGraphPtr> rel_graphs(num_etypes);
@@ -35,31 +38,31 @@ ToSimpleGraph(const HeteroGraphPtr graph) {
     std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;
   }
-  const HeteroGraphPtr result = CreateHeteroGraph(
-      metagraph, rel_graphs, graph->NumVerticesPerType());
+  const HeteroGraphPtr result =
+      CreateHeteroGraph(metagraph, rel_graphs, graph->NumVerticesPerType());
   return std::make_tuple(result, counts, edge_maps);
 }
 
 DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero")
-.set_body([] (DGLArgs args, DGLRetValue *rv) {
+    .set_body([](DGLArgs args, DGLRetValue *rv) {
      const HeteroGraphRef graph_ref = args[0];
      const auto result = ToSimpleGraph(graph_ref.sptr());
      List<Value> counts, edge_maps;
      for (const IdArray &count : std::get<1>(result))
        counts.push_back(Value(MakeValue(count)));
      for (const IdArray &edge_map : std::get<2>(result))
        edge_maps.push_back(Value(MakeValue(edge_map)));
      List<ObjectRef> ret;
      ret.push_back(HeteroGraphRef(std::get<0>(result)));
      ret.push_back(counts);
      ret.push_back(edge_maps);
      *rv = ret;
    });
 
 };  // namespace transform
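For orientation: ToSimpleGraph returns the simplified graph together with per-edge-type count and edge-mapping arrays, which the packed function above repacks into a List. A minimal caller-side sketch of unpacking the tuple, assuming a valid HeteroGraphPtr g (the comments on the two arrays reflect my reading of their meaning, not a documented contract):

HeteroGraphPtr simple;
std::vector<IdArray> counts, edge_maps;
std::tie(simple, counts, edge_maps) = ToSimpleGraph(g);
// counts[etype]    : per-edge multiplicity information for edge type etype (my reading)
// edge_maps[etype] : mapping from original edge ids to simple-graph edge ids (my reading)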
...
@@ -9,8 +9,9 @@ using namespace dgl::runtime;
 namespace dgl {
 
 HeteroGraphPtr JointUnionHeteroGraph(
     GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
-  CHECK_GT(component_graphs.size(), 0) << "Input graph list has at least two graphs";
+  CHECK_GT(component_graphs.size(), 0)
+      << "Input graph list has at least two graphs";
   std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
   std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
@@ -24,18 +25,21 @@ HeteroGraphPtr JointUnionHeteroGraph(
     HeteroGraphPtr rgptr = nullptr;
     // ALL = CSC | CSR | COO
-    const dgl_format_code_t code =\
-      component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
+    const dgl_format_code_t code =
+        component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
     // get common format
     for (size_t i = 0; i < component_graphs.size(); ++i) {
       const auto& cg = component_graphs[i];
-      CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype)) << "Input graph[" << i <<
-        "] should have same number of src vertices as input graph[0]";
-      CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype)) << "Input graph[" << i <<
-        "] should have same number of dst vertices as input graph[0]";
+      CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype))
+          << "Input graph[" << i
+          << "] should have same number of src vertices as input graph[0]";
+      CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype))
+          << "Input graph[" << i
+          << "] should have same number of dst vertices as input graph[0]";
-      const dgl_format_code_t curr_code = cg->GetRelationGraph(etype)->GetAllowedFormats();
+      const dgl_format_code_t curr_code =
+          cg->GetRelationGraph(etype)->GetAllowedFormats();
       if (curr_code != code)
         LOG(FATAL) << "All components should have the same formats";
     }
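Aside: dgl_format_code_t is a small bit-set, as the "ALL = CSC | CSR | COO" comment above suggests, and the FORMAT_HAS_* macros test individual bits. A standalone sketch of how such a code could be laid out; the concrete bit values and helper names below are illustrative assumptions, not DGL's actual definitions:

#include <cstdint>

// Illustrative bit flags for sparse-format availability (assumed values).
using format_code_t = uint8_t;
constexpr format_code_t kCOO = 0x1;
constexpr format_code_t kCSR = 0x2;
constexpr format_code_t kCSC = 0x4;
constexpr format_code_t kAll = kCOO | kCSR | kCSC;  // ALL = CSC | CSR | COO

constexpr bool FormatHasCOO(format_code_t c) { return (c & kCOO) != 0; }
constexpr bool FormatHasCSR(format_code_t c) { return (c & kCSR) != 0; }
constexpr bool FormatHasCSC(format_code_t c) { return (c & kCSC) != 0; }

static_assert(FormatHasCSR(kAll), "ALL allows every format");
static_assert(!FormatHasCOO(kCSR | kCSC), "restricted code excludes COO");

This is why the loop above can insist that every component graph report the same code: the union result can only be materialized in formats every input supports.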
@@ -50,8 +54,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
       }
       aten::COOMatrix res = aten::UnionCoo(coos);
-      rgptr = UnitGraph::CreateFromCOO(
-        (src_vtype == dst_vtype) ? 1 : 2, res, code);
+      rgptr =
+          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
     } else if (FORMAT_HAS_CSR(code)) {
       std::vector<aten::CSRMatrix> csrs;
       for (size_t i = 0; i < component_graphs.size(); ++i) {
@@ -61,8 +65,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
       }
       aten::CSRMatrix res = aten::UnionCsr(csrs);
-      rgptr = UnitGraph::CreateFromCSR(
-        (src_vtype == dst_vtype) ? 1 : 2, res, code);
+      rgptr =
+          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
     } else if (FORMAT_HAS_CSC(code)) {
       // CSR and CSC have the same storage format, i.e. CSRMatrix
       std::vector<aten::CSRMatrix> cscs;
@@ -73,8 +77,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
       }
       aten::CSRMatrix res = aten::UnionCsr(cscs);
-      rgptr = UnitGraph::CreateFromCSC(
-        (src_vtype == dst_vtype) ? 1 : 2, res, code);
+      rgptr =
+          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
     }
     rel_graphs[etype] = rgptr;
@@ -82,7 +86,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
     num_nodes_per_type[dst_vtype] = num_dst_v;
   }
 
-  return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
+  return CreateHeteroGraph(
+      meta_graph, rel_graphs, std::move(num_nodes_per_type));
 }
@@ -94,8 +99,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
   // Loop over all ntypes
   for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
     uint64_t offset = 0;
-    for (const auto &cg : component_graphs)
-      offset += cg->NumVertices(vtype);
+    for (const auto& cg : component_graphs) offset += cg->NumVertices(vtype);
     num_nodes_per_type[vtype] = offset;
   }
@@ -106,11 +110,12 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
     const dgl_type_t dst_vtype = pair.second;
     HeteroGraphPtr rgptr = nullptr;
-    const dgl_format_code_t code =\
-      component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
+    const dgl_format_code_t code =
+        component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
     // do some preprocess
-    for (const auto &cg : component_graphs) {
-      const dgl_format_code_t cur_code = cg->GetRelationGraph(etype)->GetAllowedFormats();
+    for (const auto& cg : component_graphs) {
+      const dgl_format_code_t cur_code =
+          cg->GetRelationGraph(etype)->GetAllowedFormats();
       if (cur_code != code)
         LOG(FATAL) << "All components should have the same formats";
     }
@@ -118,51 +123,55 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
     // prefer COO
     if (FORMAT_HAS_COO(code)) {
       std::vector<aten::COOMatrix> coos;
-      for (const auto &cg : component_graphs) {
+      for (const auto& cg : component_graphs) {
         aten::COOMatrix coo = cg->GetCOOMatrix(etype);
         coos.push_back(coo);
       }
       aten::COOMatrix res = aten::DisjointUnionCoo(coos);
-      rgptr = UnitGraph::CreateFromCOO(
-        (src_vtype == dst_vtype) ? 1 : 2, res, code);
+      rgptr =
+          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
     } else if (FORMAT_HAS_CSR(code)) {
       std::vector<aten::CSRMatrix> csrs;
-      for (const auto &cg : component_graphs) {
+      for (const auto& cg : component_graphs) {
         aten::CSRMatrix csr = cg->GetCSRMatrix(etype);
         csrs.push_back(csr);
       }
       aten::CSRMatrix res = aten::DisjointUnionCsr(csrs);
-      rgptr = UnitGraph::CreateFromCSR(
-        (src_vtype == dst_vtype) ? 1 : 2, res, code);
+      rgptr =
+          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
     } else if (FORMAT_HAS_CSC(code)) {
       // CSR and CSC have the same storage format, i.e. CSRMatrix
       std::vector<aten::CSRMatrix> cscs;
-      for (const auto &cg : component_graphs) {
+      for (const auto& cg : component_graphs) {
         aten::CSRMatrix csc = cg->GetCSCMatrix(etype);
         cscs.push_back(csc);
       }
       aten::CSRMatrix res = aten::DisjointUnionCsr(cscs);
-      rgptr = UnitGraph::CreateFromCSC(
-        (src_vtype == dst_vtype) ? 1 : 2, res, code);
+      rgptr =
+          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
     }
     rel_graphs[etype] = rgptr;
   }
 
-  return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
+  return CreateHeteroGraph(
+      meta_graph, rel_graphs, std::move(num_nodes_per_type));
 }
 
 std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
-    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
+    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
+    IdArray edge_sizes) {
   // Sanity check for vertex sizes
-  CHECK_EQ(vertex_sizes->dtype.bits, 64) << "dtype of vertex_sizes should be int64";
+  CHECK_EQ(vertex_sizes->dtype.bits, 64)
+      << "dtype of vertex_sizes should be int64";
   CHECK_EQ(edge_sizes->dtype.bits, 64) << "dtype of edge_sizes should be int64";
   const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
-  const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data);
+  const uint64_t* vertex_sizes_data =
+      static_cast<uint64_t*>(vertex_sizes->data);
   const uint64_t num_vertex_types = meta_graph->NumVertices();
   const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
@@ -175,10 +184,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
     for (uint64_t g = 0; g < batch_size; ++g) {
       // We've flattened the number of vertices in the batch for all types
       vertex_cumsum[vtype].push_back(
-        vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
+          vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
     }
-    CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
-      << "Sum of the given sizes must equal to the number of nodes for type " << vtype;
+    CHECK_EQ(
+        vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
+        << "Sum of the given sizes must equal to the number of nodes for type "
+        << vtype;
   }
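Aside: the per-type offsets above are plain prefix sums over the flattened vertex_sizes layout (index vtype * batch_size + g). A self-contained sketch of the same bookkeeping with illustrative data (plain vectors, not DGL's IdArray):

#include <cstdint>
#include <vector>

int main() {
  // Two vertex types, batch of three graphs, flattened as
  // sizes[vtype * batch_size + g], mirroring vertex_sizes_data above.
  const uint64_t batch_size = 3, num_vertex_types = 2;
  const std::vector<uint64_t> sizes = {4, 2, 5,   // type 0, per graph
                                       1, 3, 2};  // type 1, per graph
  std::vector<std::vector<uint64_t>> cumsum(
      num_vertex_types, std::vector<uint64_t>{0});
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype)
    for (uint64_t g = 0; g < batch_size; ++g)
      cumsum[vtype].push_back(
          cumsum[vtype][g] + sizes[vtype * batch_size + g]);
  // cumsum[0] == {0, 4, 6, 11}; cumsum[1] == {0, 1, 4, 6}.
  // cumsum[vtype][batch_size] must match the batched graph's node count
  // for that type, which is exactly what the CHECK_EQ above verifies.
  return 0;
}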
@@ -193,10 +204,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
   // Sanity check for edge sizes
     for (uint64_t g = 0; g < batch_size; ++g) {
       // We've flattened the number of edges in the batch for all types
       edge_cumsum[etype].push_back(
-        edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
+          edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
     }
     CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
-      << "Sum of the given sizes must equal to the number of edges for type " << etype;
+        << "Sum of the given sizes must equal to the number of edges for type "
+        << etype;
   }
@@ -211,14 +223,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
   // Construct relation graphs for unbatched graphs
       const dgl_type_t src_vtype = pair.first;
       const dgl_type_t dst_vtype = pair.second;
       aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
-      auto res = aten::DisjointPartitionCooBySizes(coo,
-                                                   batch_size,
-                                                   edge_cumsum[etype],
-                                                   vertex_cumsum[src_vtype],
-                                                   vertex_cumsum[dst_vtype]);
+      auto res = aten::DisjointPartitionCooBySizes(
+          coo, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
+          vertex_cumsum[dst_vtype]);
       for (uint64_t g = 0; g < batch_size; ++g) {
         HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
             (src_vtype == dst_vtype) ? 1 : 2, res[g], code);
         rel_graphs[g].push_back(rgptr);
       }
     }
@@ -228,14 +238,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
       const dgl_type_t src_vtype = pair.first;
       const dgl_type_t dst_vtype = pair.second;
       aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
-      auto res = aten::DisjointPartitionCsrBySizes(csr,
-                                                   batch_size,
-                                                   edge_cumsum[etype],
-                                                   vertex_cumsum[src_vtype],
-                                                   vertex_cumsum[dst_vtype]);
+      auto res = aten::DisjointPartitionCsrBySizes(
+          csr, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
+          vertex_cumsum[dst_vtype]);
       for (uint64_t g = 0; g < batch_size; ++g) {
         HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR(
             (src_vtype == dst_vtype) ? 1 : 2, res[g], code);
         rel_graphs[g].push_back(rgptr);
       }
     }
@@ -246,14 +254,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
       const dgl_type_t dst_vtype = pair.second;
       // CSR and CSC have the same storage format, i.e. CSRMatrix
       aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
-      auto res = aten::DisjointPartitionCsrBySizes(csc,
-                                                   batch_size,
-                                                   edge_cumsum[etype],
-                                                   vertex_cumsum[dst_vtype],
-                                                   vertex_cumsum[src_vtype]);
+      auto res = aten::DisjointPartitionCsrBySizes(
+          csc, batch_size, edge_cumsum[etype], vertex_cumsum[dst_vtype],
+          vertex_cumsum[src_vtype]);
       for (uint64_t g = 0; g < batch_size; ++g) {
         HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC(
             (src_vtype == dst_vtype) ? 1 : 2, res[g], code);
         rel_graphs[g].push_back(rgptr);
       }
     }
@@ -264,20 +270,26 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
   for (uint64_t g = 0; g < batch_size; ++g) {
     for (uint64_t i = 0; i < num_vertex_types; ++i)
      num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
-    rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
+    rst.push_back(
+        CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
   }
   return rst;
 }
 
 HeteroGraphPtr SliceHeteroGraph(
-    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray num_nodes_per_type,
-    IdArray start_nid_per_type, IdArray num_edges_per_type, IdArray start_eid_per_type) {
+    GraphPtr meta_graph, HeteroGraphPtr batched_graph,
+    IdArray num_nodes_per_type, IdArray start_nid_per_type,
+    IdArray num_edges_per_type, IdArray start_eid_per_type) {
   std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
-  const uint64_t* start_nid_per_type_data = static_cast<uint64_t*>(start_nid_per_type->data);
-  const uint64_t* num_nodes_per_type_data = static_cast<uint64_t*>(num_nodes_per_type->data);
-  const uint64_t* start_eid_per_type_data = static_cast<uint64_t*>(start_eid_per_type->data);
-  const uint64_t* num_edges_per_type_data = static_cast<uint64_t*>(num_edges_per_type->data);
+  const uint64_t* start_nid_per_type_data =
+      static_cast<uint64_t*>(start_nid_per_type->data);
+  const uint64_t* num_nodes_per_type_data =
+      static_cast<uint64_t*>(num_nodes_per_type->data);
+  const uint64_t* start_eid_per_type_data =
+      static_cast<uint64_t*>(start_eid_per_type->data);
+  const uint64_t* num_edges_per_type_data =
+      static_cast<uint64_t*>(num_edges_per_type->data);
 
   // Map vertex type to the corresponding node range
   const uint64_t num_vertex_types = meta_graph->NumVertices();
@@ -287,7 +299,7 @@ HeteroGraphPtr SliceHeteroGraph(
   for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
     vertex_range[vtype].push_back(start_nid_per_type_data[vtype]);
     vertex_range[vtype].push_back(
-      start_nid_per_type_data[vtype] + num_nodes_per_type_data[vtype]);
+        start_nid_per_type_data[vtype] + num_nodes_per_type_data[vtype]);
   }
@@ -296,50 +308,52 @@ HeteroGraphPtr SliceHeteroGraph(
   // Loop over all canonical etypes
     const dgl_type_t src_vtype = pair.first;
     const dgl_type_t dst_vtype = pair.second;
     HeteroGraphPtr rgptr = nullptr;
-    const dgl_format_code_t code = batched_graph->GetRelationGraph(etype)->GetAllowedFormats();
+    const dgl_format_code_t code =
+        batched_graph->GetRelationGraph(etype)->GetAllowedFormats();
 
     // handle graph without edges
     std::vector<uint64_t> edge_range;
     edge_range.push_back(start_eid_per_type_data[etype]);
-    edge_range.push_back(start_eid_per_type_data[etype] + num_edges_per_type_data[etype]);
+    edge_range.push_back(
+        start_eid_per_type_data[etype] + num_edges_per_type_data[etype]);
 
     // prefer COO
     if (FORMAT_HAS_COO(code)) {
       aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
-      aten::COOMatrix res = aten::COOSliceContiguousChunk(coo,
-                                                          edge_range,
-                                                          vertex_range[src_vtype],
-                                                          vertex_range[dst_vtype]);
-      rgptr = UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
+      aten::COOMatrix res = aten::COOSliceContiguousChunk(
+          coo, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
+      rgptr =
+          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
     } else if (FORMAT_HAS_CSR(code)) {
       aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
-      aten::CSRMatrix res = aten::CSRSliceContiguousChunk(csr,
-                                                          edge_range,
-                                                          vertex_range[src_vtype],
-                                                          vertex_range[dst_vtype]);
-      rgptr = UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
+      aten::CSRMatrix res = aten::CSRSliceContiguousChunk(
+          csr, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
+      rgptr =
+          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
     } else if (FORMAT_HAS_CSC(code)) {
       // CSR and CSC have the same storage format, i.e. CSRMatrix
       aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
-      aten::CSRMatrix res = aten::CSRSliceContiguousChunk(csc,
-                                                          edge_range,
-                                                          vertex_range[dst_vtype],
-                                                          vertex_range[src_vtype]);
-      rgptr = UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
+      aten::CSRMatrix res = aten::CSRSliceContiguousChunk(
+          csc, edge_range, vertex_range[dst_vtype], vertex_range[src_vtype]);
+      rgptr =
+          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
     }
     rel_graphs[etype] = rgptr;
   }
 
-  return CreateHeteroGraph(meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());
+  return CreateHeteroGraph(
+      meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());
 }
 
 template <class IdType>
 std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
-    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
+    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
+    IdArray edge_sizes) {
   // Sanity check for vertex sizes
   const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
-  const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data);
+  const uint64_t* vertex_sizes_data =
+      static_cast<uint64_t*>(vertex_sizes->data);
   const uint64_t num_vertex_types = meta_graph->NumVertices();
   const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
 
   // Map vertex type to the corresponding node cum sum
@@ -351,10 +365,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
     for (uint64_t g = 0; g < batch_size; ++g) {
       // We've flattened the number of vertices in the batch for all types
       vertex_cumsum[vtype].push_back(
-        vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
+          vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
     }
-    CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
-      << "Sum of the given sizes must equal to the number of nodes for type " << vtype;
+    CHECK_EQ(
+        vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
+        << "Sum of the given sizes must equal to the number of nodes for type "
+        << vtype;
   }
@@ -369,10 +385,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
   // Sanity check for edge sizes
     for (uint64_t g = 0; g < batch_size; ++g) {
       // We've flattened the number of edges in the batch for all types
       edge_cumsum[etype].push_back(
-        edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
+          edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
     }
     CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
-      << "Sum of the given sizes must equal to the number of edges for type " << etype;
+        << "Sum of the given sizes must equal to the number of edges for type "
+        << etype;
   }
@@ -390,7 +407,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
   // Construct relation graphs for unbatched graphs
     for (uint64_t g = 0; g < batch_size; ++g) {
       std::vector<IdType> result_src, result_dst;
       // Loop over the chunk of edges for the specified graph and edge type
-      for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) {
+      for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1];
+           ++e) {
        // TODO(mufei): Should use array operations to implement this.
        result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);
        result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
@@ -410,15 +428,18 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
   for (uint64_t g = 0; g < batch_size; ++g) {
     for (uint64_t i = 0; i < num_vertex_types; ++i)
       num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
-    rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
+    rst.push_back(
+        CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
   }
   return rst;
 }
 
 template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(
-    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes);
+    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
+    IdArray edge_sizes);
 template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(
-    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes);
+    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
+    IdArray edge_sizes);
 
 }  // namespace dgl
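Aside: the partition loop above re-bases global endpoint IDs into per-graph local IDs by subtracting the cumulative offsets (edges_src_data[e] - vertex_cumsum[src_vtype][g]). A self-contained sketch of that re-basing with illustrative data only:

#include <cstdint>
#include <vector>

int main() {
  // A batch of two graphs over one node type: graph 0 has 3 nodes
  // (global ids 0..2), graph 1 has 2 nodes (global ids 3..4).
  const std::vector<uint64_t> vertex_cumsum = {0, 3, 5};
  // Global endpoints of graph 1's edges inside the batched graph.
  const std::vector<uint64_t> src = {3, 4}, dst = {4, 3};
  const uint64_t g = 1;
  std::vector<uint64_t> local_src, local_dst;
  for (size_t e = 0; e < src.size(); ++e) {
    // Same arithmetic as the partition loop: global id minus the offset
    // at which graph g's nodes start.
    local_src.push_back(src[e] - vertex_cumsum[g]);
    local_dst.push_back(dst[e] - vertex_cumsum[g]);
  }
  // local_src == {0, 1}, local_dst == {1, 0}
  return 0;
}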
@@ -3,10 +3,13 @@
  * \file graph/traversal.cc
  * \brief Graph traversal implementation
  */
-#include "./traversal.h"
 #include <dgl/packed_func_ext.h>
+
 #include <algorithm>
 #include <queue>
+
+#include "./traversal.h"
 #include "../c_api_common.h"
 
 using namespace dgl::runtime;
@@ -15,46 +18,36 @@ namespace dgl {
 namespace traverse {
 namespace {
 // A utility view class to wrap a vector into a queue.
-template<typename DType>
+template <typename DType>
 struct VectorQueueWrapper {
   std::vector<DType>* vec;
   size_t head = 0;
 
-  explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {}
+  explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}
 
-  void push(const DType& elem) {
-    vec->push_back(elem);
-  }
+  void push(const DType& elem) { vec->push_back(elem); }
 
-  DType top() const {
-    return vec->operator[](head);
-  }
+  DType top() const { return vec->operator[](head); }
 
-  void pop() {
-    ++head;
-  }
+  void pop() { ++head; }
 
-  bool empty() const {
-    return head == vec->size();
-  }
+  bool empty() const { return head == vec->size(); }
 
-  size_t size() const {
-    return vec->size() - head;
-  }
+  size_t size() const { return vec->size() - head; }
 };
 
 // Internal function to merge multiple traversal traces into one ndarray.
 // It is similar to zip the vectors together.
-template<typename DType>
-IdArray MergeMultipleTraversals(
-    const std::vector<std::vector<DType>>& traces) {
+template <typename DType>
+IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
   int64_t max_len = 0, total_len = 0;
   for (size_t i = 0; i < traces.size(); ++i) {
     const int64_t tracelen = traces[i].size();
     max_len = std::max(max_len, tracelen);
     total_len += traces[i].size();
   }
-  IdArray ret = IdArray::Empty({total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
+  IdArray ret = IdArray::Empty(
+      {total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
   int64_t* ret_data = static_cast<int64_t*>(ret->data);
   for (int64_t i = 0; i < max_len; ++i) {
     for (size_t j = 0; j < traces.size(); ++j) {
@@ -70,15 +63,15 @@ IdArray MergeMultipleTraversals(
 
 // Internal function to compute sections if multiple traversal traces
 // are merged into one ndarray.
-template<typename DType>
-IdArray ComputeMergedSections(
-    const std::vector<std::vector<DType>>& traces) {
+template <typename DType>
+IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
   int64_t max_len = 0;
   for (size_t i = 0; i < traces.size(); ++i) {
     const int64_t tracelen = traces[i].size();
     max_len = std::max(max_len, tracelen);
   }
-  IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
+  IdArray ret = IdArray::Empty(
+      {max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
   int64_t* ret_data = static_cast<int64_t*>(ret->data);
   for (int64_t i = 0; i < max_len; ++i) {
     int64_t sec_len = 0;
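Aside: MergeMultipleTraversals and ComputeMergedSections together round-robin ("zip") per-source traces into one flat array plus per-level section lengths. A small host-side sketch of the intended output, using plain vectors instead of IdArray; the loop bodies above are truncated in this view, so this reflects my reading of them:

#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
  // Two traversal traces, zipped level by level.
  const std::vector<std::vector<int64_t>> traces = {{1, 2, 3}, {4, 5}};
  std::vector<int64_t> ids, sections;
  size_t max_len = 0;
  for (const auto& t : traces) max_len = std::max(max_len, t.size());
  for (size_t i = 0; i < max_len; ++i) {
    int64_t sec_len = 0;
    for (const auto& t : traces) {
      if (i < t.size()) {
        ids.push_back(t[i]);  // element i of every trace, in trace order
        ++sec_len;
      }
    }
    sections.push_back(sec_len);
  }
  // ids == {1, 4, 2, 5, 3}; sections == {2, 2, 1}
  return 0;
}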
@@ -99,7 +92,8 @@ IdArray ComputeMergedSections(
  * \brief Class for representing frontiers.
  *
  * Each frontier is a list of nodes/edges (specified by their ids).
- * An optional tag can be specified on each node/edge (represented by an int value).
+ * An optional tag can be specified on each node/edge (represented by an int
+ * value).
  */
 struct Frontiers {
   /*!\brief a vector store for the nodes/edges in all the frontiers */
@@ -112,142 +106,145 @@ struct Frontiers {
   std::vector<int64_t> sections;
 };
 
-Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
+Frontiers BFSNodesFrontiers(
+    const GraphInterface& graph, IdArray source, bool reversed) {
   Frontiers front;
   VectorQueueWrapper<dgl_id_t> queue(&front.ids);
-  auto visit = [&] (const dgl_id_t v) { };
-  auto make_frontier = [&] () {
+  auto visit = [&](const dgl_id_t v) {};
+  auto make_frontier = [&]() {
     if (!queue.empty()) {
       // do not push zero-length frontier
       front.sections.push_back(queue.size());
     }
   };
   BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
   return front;
 }
 
 DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray src = args[1];
      bool reversed = args[2];
      const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
      IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
    });
 
-Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) {
+Frontiers BFSEdgesFrontiers(
+    const GraphInterface& graph, IdArray source, bool reversed) {
   Frontiers front;
   // NOTE: std::queue has no top() method.
   std::vector<dgl_id_t> nodes;
   VectorQueueWrapper<dgl_id_t> queue(&nodes);
-  auto visit = [&] (const dgl_id_t e) { front.ids.push_back(e); };
+  auto visit = [&](const dgl_id_t e) { front.ids.push_back(e); };
   bool first_frontier = true;
   auto make_frontier = [&] {
     if (first_frontier) {
      first_frontier = false;  // do not push the first section when doing edges
    } else if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(queue.size());
    }
  };
  BFSEdges(graph, source, reversed, &queue, visit, make_frontier);
  return front;
 }
 
 DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray src = args[1];
      bool reversed = args[2];
      const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
      IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
    });
 
-Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) {
+Frontiers TopologicalNodesFrontiers(
+    const GraphInterface& graph, bool reversed) {
   Frontiers front;
   VectorQueueWrapper<dgl_id_t> queue(&front.ids);
-  auto visit = [&] (const dgl_id_t v) { };
-  auto make_frontier = [&] () {
+  auto visit = [&](const dgl_id_t v) {};
+  auto make_frontier = [&]() {
    if (!queue.empty()) {
      // do not push zero-length frontier
      front.sections.push_back(queue.size());
    }
  };
  TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
  return front;
 }
 
 DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      bool reversed = args[1];
      const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
      IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
      *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
    });
 
 DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      const int64_t len = source->shape[0];
      const int64_t* src_data = static_cast<int64_t*>(source->data);
      std::vector<std::vector<dgl_id_t>> edges(len);
      for (int64_t i = 0; i < len; ++i) {
-       auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); };
+       auto visit = [&](dgl_id_t e, int tag) { edges[i].push_back(e); };
        DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
      }
      IdArray ids = MergeMultipleTraversals(edges);
      IdArray sections = ComputeMergedSections(edges);
      *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
    });
 
 DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
      GraphRef g = args[0];
      const IdArray source = args[1];
      const bool reversed = args[2];
      const bool has_reverse_edge = args[3];
      const bool has_nontree_edge = args[4];
      const bool return_labels = args[5];
      CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
      const int64_t len = source->shape[0];
      const int64_t* src_data = static_cast<int64_t*>(source->data);
      std::vector<std::vector<dgl_id_t>> edges(len);
      std::vector<std::vector<int64_t>> tags;
      if (return_labels) {
        tags.resize(len);
      }
      for (int64_t i = 0; i < len; ++i) {
-       auto visit = [&] (dgl_id_t e, int tag) {
+       auto visit = [&](dgl_id_t e, int tag) {
          edges[i].push_back(e);
          if (return_labels) {
            tags[i].push_back(tag);
          }
        };
-       DFSLabeledEdges(*g.sptr(), src_data[i], reversed,
-                       has_reverse_edge, has_nontree_edge, visit);
+       DFSLabeledEdges(
+           *g.sptr(), src_data[i], reversed, has_reverse_edge,
+           has_nontree_edge, visit);
      }
      IdArray ids = MergeMultipleTraversals(edges);
      IdArray sections = ComputeMergedSections(edges);
      if (return_labels) {
        IdArray labels = MergeMultipleTraversals(tags);
        *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
      } else {
        *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
      }
    });
 
 }  // namespace traverse
 }  // namespace dgl
@@ -3,15 +3,16 @@
  * \file graph/traversal.h
  * \brief Graph traversal routines.
  *
- * Traversal routines generate frontiers. Frontiers can be node frontiers or edge
- * frontiers depending on the traversal function. Each frontier is a
- * list of nodes/edges (specified by their ids). An optional tag can be specified
- * for each node/edge (represented by an int value).
+ * Traversal routines generate frontiers. Frontiers can be node frontiers or
+ * edge frontiers depending on the traversal function. Each frontier is a list
+ * of nodes/edges (specified by their ids). An optional tag can be specified for
+ * each node/edge (represented by an int value).
  */
 #ifndef DGL_GRAPH_TRAVERSAL_H_
 #define DGL_GRAPH_TRAVERSAL_H_
 
 #include <dgl/graph_interface.h>
+
 #include <stack>
 #include <tuple>
 #include <vector>
@@ -39,18 +40,16 @@ namespace traverse {
  *
  * \param graph The graph.
  * \param sources Source nodes.
- * \param reversed If true, BFS follows the in-edge direction
+ * \param reversed If true, BFS follows the in-edge direction.
  * \param queue The queue used to do bfs.
  * \param visit The function to call when a node is visited.
- * \param make_frontier The function to indicate that a new frontier can be made;
+ * \param make_frontier The function to indicate that a new frontier can be
+ *        made.
  */
-template<typename Queue, typename VisitFn, typename FrontierFn>
-void BFSNodes(const GraphInterface& graph,
-              IdArray source,
-              bool reversed,
-              Queue* queue,
-              VisitFn visit,
-              FrontierFn make_frontier) {
+template <typename Queue, typename VisitFn, typename FrontierFn>
+void BFSNodes(
+    const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,
+    VisitFn visit, FrontierFn make_frontier) {
   const int64_t len = source->shape[0];
   const int64_t* src_data = static_cast<int64_t*>(source->data);
@@ -63,7 +62,8 @@ void BFSNodes(const GraphInterface& graph,
   }
   make_frontier();
 
-  const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
+  const auto neighbor_iter =
+      reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
   while (!queue->empty()) {
     const size_t size = queue->size();
     for (size_t i = 0; i < size; ++i) {
@@ -102,19 +102,17 @@ void BFSNodes(const GraphInterface& graph,
  *
  * \param graph The graph.
  * \param sources Source nodes.
- * \param reversed If true, BFS follows the in-edge direction
+ * \param reversed If true, BFS follows the in-edge direction.
  * \param queue The queue used to do bfs.
  * \param visit The function to call when a node is visited.
  *              The argument would be edge ID.
- * \param make_frontier The function to indicate that a new frontier can be made;
+ * \param make_frontier The function to indicate that a new frontier can be
+ *        made.
  */
-template<typename Queue, typename VisitFn, typename FrontierFn>
-void BFSEdges(const GraphInterface& graph,
-              IdArray source,
-              bool reversed,
-              Queue* queue,
-              VisitFn visit,
-              FrontierFn make_frontier) {
+template <typename Queue, typename VisitFn, typename FrontierFn>
+void BFSEdges(
+    const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,
+    VisitFn visit, FrontierFn make_frontier) {
   const int64_t len = source->shape[0];
   const int64_t* src_data = static_cast<int64_t*>(source->data);
@@ -126,7 +124,8 @@ void BFSEdges(const GraphInterface& graph,
   }
   make_frontier();
 
-  const auto neighbor_iter = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
+  const auto neighbor_iter =
+      reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
   while (!queue->empty()) {
     const size_t size = queue->size();
     for (size_t i = 0; i < size; ++i) {
@@ -165,19 +164,20 @@ void BFSEdges(const GraphInterface& graph,
  *   void (*make_frontier)(void);
  *
  * \param graph The graph.
- * \param reversed If true, follows the in-edge direction
+ * \param reversed If true, follows the in-edge direction.
  * \param queue The queue used to do bfs.
  * \param visit The function to call when a node is visited.
- * \param make_frontier The function to indicate that a new frontier can be made;
+ * \param make_frontier The function to indicate that a new frontier can be
+ *        made.
  */
-template<typename Queue, typename VisitFn, typename FrontierFn>
-void TopologicalNodes(const GraphInterface& graph,
-                      bool reversed,
-                      Queue* queue,
-                      VisitFn visit,
-                      FrontierFn make_frontier) {
-  const auto get_degree = reversed? &GraphInterface::OutDegree : &GraphInterface::InDegree;
-  const auto neighbor_iter = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
+template <typename Queue, typename VisitFn, typename FrontierFn>
+void TopologicalNodes(
+    const GraphInterface& graph, bool reversed, Queue* queue, VisitFn visit,
+    FrontierFn make_frontier) {
+  const auto get_degree =
+      reversed ? &GraphInterface::OutDegree : &GraphInterface::InDegree;
+  const auto neighbor_iter =
+      reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
   uint64_t num_visited_nodes = 0;
   std::vector<uint64_t> degrees(graph.NumVertices(), 0);
   for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {
@@ -207,7 +207,8 @@ void TopologicalNodes(const GraphInterface& graph,
   }
 
   if (num_visited_nodes != graph.NumVertices()) {
-    LOG(FATAL) << "Error in topological traversal: loop detected in the given graph.";
+    LOG(FATAL)
+        << "Error in topological traversal: loop detected in the given graph.";
   }
 }
@@ -221,30 +222,28 @@ enum DFSEdgeTag {
 /*!
  * \brief Traverse the graph in a depth-first-search (DFS) order.
  *
  * The traversal visit edges in its DFS order. Edges have three tags:
- * FORWARD(0), REVERSE(1), NONTREE(2)
+ * FORWARD(0), REVERSE(1), NONTREE(2).
  *
  * A FORWARD edge is one in which `u` has been visited but `v` has not.
- * A REVERSE edge is one in which both `u` and `v` have been visited and the edge
- * is in the DFS tree.
- * A NONTREE edge is one in which both `u` and `v` have been visited but the edge
- * is NOT in the DFS tree.
+ * A REVERSE edge is one in which both `u` and `v` have been visited and the
+ * edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have
+ * been visited but the edge is NOT in the DFS tree.
  *
  * \param source Source node.
- * \param reversed If true, DFS follows the in-edge direction
- * \param has_reverse_edge If true, REVERSE edges are included
- * \param has_nontree_edge If true, NONTREE edges are included
- * \param visit The function to call when an edge is visited; the edge id and its
- *              tag will be given as the arguments.
+ * \param reversed If true, DFS follows the in-edge direction.
+ * \param has_reverse_edge If true, REVERSE edges are included.
+ * \param has_nontree_edge If true, NONTREE edges are included.
+ * \param visit The function to call when an edge is visited; the edge id and
+ *              its tag will be given as the arguments.
 */
-template<typename VisitFn>
-void DFSLabeledEdges(const GraphInterface& graph,
-                     dgl_id_t source,
-                     bool reversed,
-                     bool has_reverse_edge,
-                     bool has_nontree_edge,
-                     VisitFn visit) {
-  const auto succ = reversed? &GraphInterface::PredVec : &GraphInterface::SuccVec;
-  const auto out_edge = reversed? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
+template <typename VisitFn>
+void DFSLabeledEdges(
+    const GraphInterface& graph, dgl_id_t source, bool reversed,
+    bool has_reverse_edge, bool has_nontree_edge, VisitFn visit) {
+  const auto succ =
+      reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
+  const auto out_edge =
+      reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
 
   if ((graph.*succ)(source).size() == 0) {
     // no out-going edges from the source node
@@ -273,7 +272,7 @@ void DFSLabeledEdges(const GraphInterface& graph,
       stack.pop();
       // find next one.
       if (i < (graph.*succ)(u).size() - 1) {
-        stack.push(std::make_tuple(u, i+1, false));
+        stack.push(std::make_tuple(u, i + 1, false));
       }
     } else {
       visited[v] = true;
@@ -291,4 +290,3 @@ void DFSLabeledEdges(const GraphInterface& graph,
 }  // namespace dgl
-
 #endif  // DGL_GRAPH_TRAVERSAL_H_
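Aside: DFSLabeledEdges reports each edge together with its tag through the visit callback, exactly as _CAPI_DGLDFSLabeledEdges does above. A minimal caller-side sketch, assuming some concrete GraphInterface implementation bound to `g` (the variable names are illustrative; the header's own includes cover std::vector and std::tuple):

// Collect (edge id, tag) pairs from a labeled DFS starting at node 0.
// Tags arrive as ints matching the DFSEdgeTag values documented above:
// FORWARD(0), REVERSE(1), NONTREE(2).
std::vector<std::pair<dgl_id_t, int>> labeled;
auto visit = [&](dgl_id_t e, int tag) { labeled.emplace_back(e, tag); };
DFSLabeledEdges(
    g, /*source=*/0, /*reversed=*/false,
    /*has_reverse_edge=*/true, /*has_nontree_edge=*/true, visit);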
@@ -4,13 +4,12 @@
  * \brief Operations on partition implemented in CUDA.
  */
-#include "../partition_op.h"
 #include <dgl/runtime/device_api.h>
+
 #include "../../array/cuda/dgl_cub.cuh"
 #include "../../runtime/cuda/cuda_common.h"
 #include "../../runtime/workspace.h"
+#include "../partition_op.h"
 
 using namespace dgl::runtime;
@@ -21,22 +20,21 @@ namespace impl {
 namespace {
 
 /**
  * @brief Kernel to map global element IDs to partition IDs by remainder.
  *
  * @tparam IdType The type of ID.
  * @param global The global element IDs.
  * @param num_elements The number of element IDs.
  * @param num_parts The number of partitions.
  * @param part_id The mapped partition ID (output).
  */
-template<typename IdType>
+template <typename IdType>
 __global__ void _MapProcByRemainderKernel(
-    const IdType * const global,
-    const int64_t num_elements,
-    const int64_t num_parts,
-    IdType * const part_id) {
-  assert(num_elements <= gridDim.x*blockDim.x);
-  const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
+    const IdType* const global, const int64_t num_elements,
+    const int64_t num_parts, IdType* const part_id) {
+  assert(num_elements <= gridDim.x * blockDim.x);
+  const int64_t idx =
+      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
 
   if (idx < num_elements) {
     part_id[idx] = global[idx] % num_parts;
...@@ -44,24 +42,23 @@ __global__ void _MapProcByRemainderKernel( ...@@ -44,24 +42,23 @@ __global__ void _MapProcByRemainderKernel(
} }
/** /**
* @brief Kernel to map global element IDs to partition IDs, using a bit-mask. * @brief Kernel to map global element IDs to partition IDs, using a bit-mask.
* The number of partitions must be a power a two. * The number of partitions must be a power a two.
* *
* @tparam IdType The type of ID. * @tparam IdType The type of ID.
* @param global The global element IDs. * @param global The global element IDs.
* @param num_elements The number of element IDs. * @param num_elements The number of element IDs.
* @param mask The bit-mask with 1's for each bit to keep from the element ID to * @param mask The bit-mask with 1's for each bit to keep from the element ID to
* extract the partition ID (e.g., an 8 partition mask would be 0x07). * extract the partition ID (e.g., an 8 partition mask would be 0x07).
 * @param part_id The mapped partition ID (output). * @param part_id The mapped partition ID (output).
*/ */
template<typename IdType> template <typename IdType>
__global__ void _MapProcByMaskRemainderKernel( __global__ void _MapProcByMaskRemainderKernel(
const IdType * const global, const IdType* const global, const int64_t num_elements, const IdType mask,
const int64_t num_elements, IdType* const part_id) {
const IdType mask, assert(num_elements <= gridDim.x * blockDim.x);
IdType * const part_id) { const int64_t idx =
assert(num_elements <= gridDim.x*blockDim.x); blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
if (idx < num_elements) { if (idx < num_elements) {
part_id[idx] = global[idx] & mask; part_id[idx] = global[idx] & mask;
...@@ -69,22 +66,20 @@ __global__ void _MapProcByMaskRemainderKernel( ...@@ -69,22 +66,20 @@ __global__ void _MapProcByMaskRemainderKernel(
} }
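The mask trick relies on the identity id % 2^k == id & (2^k - 1), which replaces the division of the previous kernel. A small self-contained check with hypothetical values:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t num_parts = 8;         // a power of two
  const int64_t mask = num_parts - 1;  // 0x07, as in the comment above
  for (int64_t id = 0; id < 1000; ++id)
    assert((id & mask) == (id % num_parts));  // mask replaces the modulo
  return 0;
}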
/** /**
* @brief Kernel to map global element IDs to local element IDs. * @brief Kernel to map global element IDs to local element IDs.
* *
* @tparam IdType The type of ID. * @tparam IdType The type of ID.
* @param global The global element IDs. * @param global The global element IDs.
* @param num_elements The number of IDs. * @param num_elements The number of IDs.
* @param num_parts The number of partitions. * @param num_parts The number of partitions.
* @param local The local element IDs (output). * @param local The local element IDs (output).
*/ */
template<typename IdType> template <typename IdType>
__global__ void _MapLocalIndexByRemainderKernel( __global__ void _MapLocalIndexByRemainderKernel(
const IdType * const global, const IdType* const global, const int64_t num_elements, const int num_parts,
const int64_t num_elements, IdType* const local) {
const int num_parts, assert(num_elements <= gridDim.x * blockDim.x);
IdType * const local) { const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
if (idx < num_elements) { if (idx < num_elements) {
local[idx] = global[idx] / num_parts; local[idx] = global[idx] / num_parts;
...@@ -92,25 +87,22 @@ __global__ void _MapLocalIndexByRemainderKernel( ...@@ -92,25 +87,22 @@ __global__ void _MapLocalIndexByRemainderKernel(
} }
/** /**
* @brief Kernel to map local element IDs within a partition to their global * @brief Kernel to map local element IDs within a partition to their global
* IDs, using the remainder over the number of partitions. * IDs, using the remainder over the number of partitions.
* *
* @tparam IdType The type of ID. * @tparam IdType The type of ID.
* @param local The local element IDs. * @param local The local element IDs.
* @param part_id The partition to map local elements from. * @param part_id The partition to map local elements from.
* @param num_elements The number of elements to map. * @param num_elements The number of elements to map.
* @param num_parts The number of partitions. * @param num_parts The number of partitions.
* @param global The global element IDs (output). * @param global The global element IDs (output).
*/ */
template<typename IdType> template <typename IdType>
__global__ void _MapGlobalIndexByRemainderKernel( __global__ void _MapGlobalIndexByRemainderKernel(
const IdType * const local, const IdType* const local, const int part_id, const int64_t num_elements,
const int part_id, const int num_parts, IdType* const global) {
const int64_t num_elements, assert(num_elements <= gridDim.x * blockDim.x);
const int num_parts, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
IdType * const global) {
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
assert(part_id < num_parts); assert(part_id < num_parts);
...@@ -120,125 +112,113 @@ __global__ void _MapGlobalIndexByRemainderKernel( ...@@ -120,125 +112,113 @@ __global__ void _MapGlobalIndexByRemainderKernel(
} }
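Together with the kernel above it, this defines the remainder scheme's round trip: local = global / num_parts and global = local * num_parts + part_id. A host-side sanity check with assumed values, not DGL code:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t num_parts = 4;
  for (int64_t global = 0; global < 64; ++global) {
    const int64_t part_id = global % num_parts;  // _MapProcByRemainderKernel
    const int64_t local = global / num_parts;    // _MapLocalIndexByRemainderKernel
    assert(local * num_parts + part_id == global);  // _MapGlobalIndexByRemainderKernel
  }
  return 0;
}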
/** /**
* @brief Device function to perform a binary search to find to which partition a * @brief Device function to perform a binary search to find to which partition
* given ID belongs. * a given ID belongs.
* *
* @tparam RangeType The type of range. * @tparam RangeType The type of range.
* @param range The prefix-sum of IDs assigned to partitions. * @param range The prefix-sum of IDs assigned to partitions.
* @param num_parts The number of partitions. * @param num_parts The number of partitions.
* @param target The element ID to find the partition of. * @param target The element ID to find the partition of.
* *
* @return The partition. * @return The partition.
*/ */
template<typename RangeType> template <typename RangeType>
__device__ RangeType _SearchRange( __device__ RangeType _SearchRange(
const RangeType * const range, const RangeType* const range, const int num_parts, const RangeType target) {
const int num_parts,
const RangeType target) {
int start = 0; int start = 0;
int end = num_parts; int end = num_parts;
int cur = (end+start)/2; int cur = (end + start) / 2;
assert(range[0] == 0); assert(range[0] == 0);
assert(target < range[num_parts]); assert(target < range[num_parts]);
while (start+1 < end) { while (start + 1 < end) {
if (target < range[cur]) { if (target < range[cur]) {
end = cur; end = cur;
} else { } else {
start = cur; start = cur;
} }
cur = (start+end)/2; cur = (start + end) / 2;
} }
return cur; return cur;
} }
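On the host, the same lookup can be expressed with std::upper_bound; a sketch assuming an exclusive prefix-sum of length num_parts + 1, with hypothetical part sizes (SearchRangeHost is an illustrative name):

#include <algorithm>
#include <cassert>
#include <vector>

// The owning partition of `target` is the last index p with range[p] <= target,
// i.e. one before the first entry strictly greater than target.
int SearchRangeHost(const std::vector<int>& range, int target) {
  return static_cast<int>(
      std::upper_bound(range.begin(), range.end(), target) -
      range.begin() - 1);
}

int main() {
  const std::vector<int> range = {0, 4, 4, 10};  // parts of size 4, 0, 6
  assert(SearchRangeHost(range, 3) == 0);
  assert(SearchRangeHost(range, 4) == 2);  // part 1 is empty, so ID 4 is in part 2
  assert(SearchRangeHost(range, 9) == 2);
  return 0;
}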
/** /**
* @brief Kernel to map element IDs to partition IDs. * @brief Kernel to map element IDs to partition IDs.
* *
* @tparam IdType The type of element ID. * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range. * @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions. * @param range The prefix-sum of IDs assigned to partitions.
* @param global The global element IDs. * @param global The global element IDs.
* @param num_elements The number of element IDs. * @param num_elements The number of element IDs.
* @param num_parts The number of partitions. * @param num_parts The number of partitions.
* @param part_id The partition ID assigned to each element (output). * @param part_id The partition ID assigned to each element (output).
*/ */
template<typename IdType, typename RangeType> template <typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel( __global__ void _MapProcByRangeKernel(
const RangeType * const range, const RangeType* const range, const IdType* const global,
const IdType * const global, const int64_t num_elements, const int64_t num_parts,
const int64_t num_elements, IdType* const part_id) {
const int64_t num_parts, assert(num_elements <= gridDim.x * blockDim.x);
IdType * const part_id) { const int64_t idx =
assert(num_elements <= gridDim.x*blockDim.x); blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
// rely on caching to load the range into L1 cache // rely on caching to load the range into L1 cache
if (idx < num_elements) { if (idx < num_elements) {
part_id[idx] = static_cast<IdType>(_SearchRange( part_id[idx] = static_cast<IdType>(_SearchRange(
range, range, static_cast<int>(num_parts),
static_cast<int>(num_parts),
static_cast<RangeType>(global[idx]))); static_cast<RangeType>(global[idx])));
} }
} }
/** /**
* @brief Kernel to map global element IDs to their ID within their respective * @brief Kernel to map global element IDs to their ID within their respective
* partition. * partition.
* *
* @tparam IdType The type of element ID. * @tparam IdType The type of element ID.
* @tparam RangeType The type of the range. * @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions. * @param range The prefix-sum of IDs assigned to partitions.
* @param global The global element IDs. * @param global The global element IDs.
* @param num_elements The number of elements. * @param num_elements The number of elements.
* @param num_parts The number of partitions. * @param num_parts The number of partitions.
* @param local The local element IDs (output). * @param local The local element IDs (output).
*/ */
template<typename IdType, typename RangeType> template <typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel( __global__ void _MapLocalIndexByRangeKernel(
const RangeType * const range, const RangeType* const range, const IdType* const global,
const IdType * const global, const int64_t num_elements, const int num_parts, IdType* const local) {
const int64_t num_elements, assert(num_elements <= gridDim.x * blockDim.x);
const int num_parts, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
IdType * const local) {
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
// rely on caching to load the range into L1 cache // rely on caching to load the range into L1 cache
if (idx < num_elements) { if (idx < num_elements) {
const int proc = _SearchRange( const int proc = _SearchRange(
range, range, static_cast<int>(num_parts),
static_cast<int>(num_parts),
static_cast<RangeType>(global[idx])); static_cast<RangeType>(global[idx]));
local[idx] = global[idx] - range[proc]; local[idx] = global[idx] - range[proc];
} }
} }
/** /**
* @brief Kernel to map local element IDs within a partition to their global * @brief Kernel to map local element IDs within a partition to their global
* IDs. * IDs.
* *
* @tparam IdType The type of ID. * @tparam IdType The type of ID.
* @tparam RangeType The type of the range. * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions. * @param range The prefix-sum of IDs assigned to partitions.
* @param local The local element IDs. * @param local The local element IDs.
* @param part_id The partition to map local elements from. * @param part_id The partition to map local elements from.
* @param num_elements The number of elements to map. * @param num_elements The number of elements to map.
* @param num_parts The number of partitions. * @param num_parts The number of partitions.
* @param global The global element IDs (output). * @param global The global element IDs (output).
*/ */
template<typename IdType, typename RangeType> template <typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel( __global__ void _MapGlobalIndexByRangeKernel(
const RangeType * const range, const RangeType* const range, const IdType* const local, const int part_id,
const IdType * const local, const int64_t num_elements, const int num_parts, IdType* const global) {
const int part_id, assert(num_elements <= gridDim.x * blockDim.x);
const int64_t num_elements, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
const int num_parts,
IdType * const global) {
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
assert(part_id < num_parts); assert(part_id < num_parts);
...@@ -252,11 +232,8 @@ __global__ void _MapGlobalIndexByRangeKernel( ...@@ -252,11 +232,8 @@ __global__ void _MapGlobalIndexByRangeKernel(
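The range scheme's round trip mirrors the remainder one: within partition p, local = global - range[p], and global = range[p] + local. A CPU sketch with a hypothetical range array:

#include <cassert>
#include <vector>

int main() {
  const std::vector<int> range = {0, 4, 10, 16};  // exclusive prefix-sum of part sizes
  for (int global = 0; global < 16; ++global) {
    int p = 0;
    while (range[p + 1] <= global) ++p;    // linear stand-in for _SearchRange
    const int local = global - range[p];   // _MapLocalIndexByRangeKernel
    assert(range[p] + local == global);    // _MapGlobalIndexByRangeKernel
  }
  return 0;
}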
// Remainder Based Partition Operations // Remainder Based Partition Operations
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
GeneratePermutationFromRemainder( int64_t array_size, int num_parts, IdArray in_idx) {
int64_t array_size,
int num_parts,
IdArray in_idx) {
std::pair<IdArray, NDArray> result; std::pair<IdArray, NDArray> result;
const auto& ctx = in_idx->ctx; const auto& ctx = in_idx->ctx;
...@@ -265,19 +242,19 @@ GeneratePermutationFromRemainder( ...@@ -265,19 +242,19 @@ GeneratePermutationFromRemainder(
const int64_t num_in = in_idx->shape[0]; const int64_t num_in = in_idx->shape[0];
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts << CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
") must be at least 1."; << ") must be at least 1.";
if (num_parts == 1) { if (num_parts == 1) {
// no permutation // no permutation
result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx); result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx); result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);
return result; return result;
} }
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8); result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx); result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
int64_t * out_counts = static_cast<int64_t*>(result.second->data); int64_t* out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) { if (num_in == 0) {
// now that we've zero'd out_counts, nothing left to do for an empty // now that we've zero'd out_counts, nothing left to do for an empty
// mapping // mapping
...@@ -291,21 +268,20 @@ GeneratePermutationFromRemainder( ...@@ -291,21 +268,20 @@ GeneratePermutationFromRemainder(
Workspace<IdType> proc_id_in(device, ctx, num_in); Workspace<IdType> proc_id_in(device, ctx, num_in);
{ {
const dim3 block(256); const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x); const dim3 grid((num_in + block.x - 1) / block.x);
if (num_parts < (1 << part_bits)) { if (num_parts < (1 << part_bits)) {
// num_parts is not a power of 2 // num_parts is not a power of 2
CUDA_KERNEL_CALL(_MapProcByRemainderKernel, grid, block, 0, stream, CUDA_KERNEL_CALL(
static_cast<const IdType*>(in_idx->data), _MapProcByRemainderKernel, grid, block, 0, stream,
num_in, static_cast<const IdType*>(in_idx->data), num_in, num_parts,
num_parts,
proc_id_in.get()); proc_id_in.get());
} else { } else {
// num_parts is a power of 2 // num_parts is a power of 2
CUDA_KERNEL_CALL(_MapProcByMaskRemainderKernel, grid, block, 0, stream, CUDA_KERNEL_CALL(
static_cast<const IdType*>(in_idx->data), _MapProcByMaskRemainderKernel, grid, block, 0, stream,
num_in, static_cast<const IdType*>(in_idx->data), num_in,
static_cast<IdType>(num_parts-1), // bit mask static_cast<IdType>(num_parts - 1), // bit mask
proc_id_in.get()); proc_id_in.get());
} }
} }
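The grid computation above is the usual ceiling division, so a final partial block is still launched; a tiny compile-time check (CeilDiv is an illustrative name):

#include <cstdint>

constexpr int64_t CeilDiv(int64_t n, int64_t block) {
  return (n + block - 1) / block;
}
// e.g. 1000 elements at 256 threads per block need (1000 + 255) / 256 = 4 blocks.
static_assert(CeilDiv(1000, 256) == 4, "4 blocks of 256 threads cover 1000 items");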
...@@ -313,18 +289,20 @@ GeneratePermutationFromRemainder( ...@@ -313,18 +289,20 @@ GeneratePermutationFromRemainder(
// then create a permutation array that groups processors together by // then create a permutation array that groups processors together by
// performing a radix sort // performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in); Workspace<IdType> proc_id_out(device, ctx, num_in);
IdType * perm_out = static_cast<IdType*>(result.first->data); IdType* perm_out = static_cast<IdType*>(result.first->data);
{ {
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx); IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
size_t sort_workspace_size; size_t sort_workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size, CUDA_CALL(cub::DeviceRadixSort::SortPairs(
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out, nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
num_in, 0, part_bits, stream)); static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
stream));
Workspace<void> sort_workspace(device, ctx, sort_workspace_size); Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(sort_workspace.get(), sort_workspace_size, CUDA_CALL(cub::DeviceRadixSort::SortPairs(
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out, sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream)); num_in, 0, part_bits, stream));
} }
// explicitly free so workspace can be re-used // explicitly free so workspace can be re-used
...@@ -334,83 +312,58 @@ GeneratePermutationFromRemainder( ...@@ -334,83 +312,58 @@ GeneratePermutationFromRemainder(
// Count the number of values to be sent to each processor // Count the number of values to be sent to each processor
{ {
using AtomicCount = unsigned long long; // NOLINT using AtomicCount = unsigned long long; // NOLINT
static_assert(sizeof(AtomicCount) == sizeof(*out_counts), static_assert(
sizeof(AtomicCount) == sizeof(*out_counts),
"AtomicCount must be the same width as int64_t for atomicAdd " "AtomicCount must be the same width as int64_t for atomicAdd "
"in cub::DeviceHistogram::HistogramEven() to work"); "in cub::DeviceHistogram::HistogramEven() to work");
// TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged, // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
// add a compile time check against the cub version to allow // add a compile time check against the cub version to allow
// num_in > (2 << 31). // num_in > (2 << 31).
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) << CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
"number of values to insert into histogram must be less than max " << "number of values to insert into histogram must be less than max "
"value of int."; "value of int.";
size_t hist_workspace_size; size_t hist_workspace_size;
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr, nullptr, hist_workspace_size, proc_id_out.get(),
hist_workspace_size, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
proc_id_out.get(), static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
reinterpret_cast<AtomicCount*>(out_counts), static_cast<int>(num_in), stream));
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size); Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(), hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
hist_workspace_size, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
proc_id_out.get(), static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
reinterpret_cast<AtomicCount*>(out_counts), static_cast<int>(num_in), stream));
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
} }
return result; return result;
} }
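End to end, this function sorts a permutation by partition ID and then histograms the IDs to obtain per-partition counts. A CPU analogue using std::stable_sort in place of cub::DeviceRadixSort and a plain loop in place of cub::DeviceHistogram::HistogramEven (a sketch, not the DGL implementation):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::pair<std::vector<int64_t>, std::vector<int64_t>> PermutationFromRemainder(
    const std::vector<int64_t>& in_idx, int64_t num_parts) {
  // Permutation that groups indices by their partition (stable, like radix sort).
  std::vector<int64_t> perm(in_idx.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    return in_idx[a] % num_parts < in_idx[b] % num_parts;
  });
  // Number of elements destined for each partition.
  std::vector<int64_t> counts(num_parts, 0);
  for (const int64_t id : in_idx) ++counts[id % num_parts];
  return {perm, counts};
}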
template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
template std::pair<IdArray, IdArray> kDGLCUDA, int32_t>(int64_t array_size, int num_parts, IdArray in_idx);
GeneratePermutationFromRemainder<kDGLCUDA, int32_t>( template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
int64_t array_size, kDGLCUDA, int64_t>(int64_t array_size, int num_parts, IdArray in_idx);
int num_parts,
IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDGLCUDA, int64_t>(
int64_t array_size,
int num_parts,
IdArray in_idx);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder( IdArray MapToLocalFromRemainder(const int num_parts, IdArray global_idx) {
const int num_parts,
IdArray global_idx) {
const auto& ctx = global_idx->ctx; const auto& ctx = global_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) { if (num_parts > 1) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx, IdArray local_idx =
sizeof(IdType)*8); aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128); const dim3 block(128);
const dim3 grid((global_idx->shape[0] +block.x-1)/block.x); const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_MapLocalIndexByRemainderKernel, _MapLocalIndexByRemainderKernel, grid, block, 0, stream,
grid, static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
block, num_parts, static_cast<IdType*>(local_idx->data));
0,
stream,
static_cast<const IdType*>(global_idx->data),
global_idx->shape[0],
num_parts,
static_cast<IdType*>(local_idx->data));
return local_idx; return local_idx;
} else { } else {
...@@ -419,45 +372,33 @@ IdArray MapToLocalFromRemainder( ...@@ -419,45 +372,33 @@ IdArray MapToLocalFromRemainder(
} }
} }
template IdArray template IdArray MapToLocalFromRemainder<kDGLCUDA, int32_t>(
MapToLocalFromRemainder<kDGLCUDA, int32_t>( int num_parts, IdArray in_idx);
int num_parts, template IdArray MapToLocalFromRemainder<kDGLCUDA, int64_t>(
IdArray in_idx); int num_parts, IdArray in_idx);
template IdArray
MapToLocalFromRemainder<kDGLCUDA, int64_t>(
int num_parts,
IdArray in_idx);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder( IdArray MapToGlobalFromRemainder(
const int num_parts, const int num_parts, IdArray local_idx, const int part_id) {
IdArray local_idx, CHECK_LT(part_id, num_parts)
const int part_id) { << "Invalid partition id " << part_id << "/" << num_parts;
CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id << CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
"/" << num_parts; << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
"/" << num_parts;
const auto& ctx = local_idx->ctx; const auto& ctx = local_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) { if (num_parts > 1) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx, IdArray global_idx =
sizeof(IdType)*8); aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128); const dim3 block(128);
const dim3 grid((local_idx->shape[0] +block.x-1)/block.x); const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_MapGlobalIndexByRemainderKernel, _MapGlobalIndexByRemainderKernel, grid, block, 0, stream,
grid, static_cast<const IdType*>(local_idx->data), part_id,
block, global_idx->shape[0], num_parts,
0,
stream,
static_cast<const IdType*>(local_idx->data),
part_id,
global_idx->shape[0],
num_parts,
static_cast<IdType*>(global_idx->data)); static_cast<IdType*>(global_idx->data));
return global_idx; return global_idx;
...@@ -467,27 +408,16 @@ IdArray MapToGlobalFromRemainder( ...@@ -467,27 +408,16 @@ IdArray MapToGlobalFromRemainder(
} }
} }
template IdArray template IdArray MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
MapToGlobalFromRemainder<kDGLCUDA, int32_t>( int num_parts, IdArray in_idx, int part_id);
int num_parts, template IdArray MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
IdArray in_idx, int num_parts, IdArray in_idx, int part_id);
int part_id);
template IdArray
MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
int num_parts,
IdArray in_idx,
int part_id);
// Range Based Partition Operations // Range Based Partition Operations
template <DGLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray> GeneratePermutationFromRange(
GeneratePermutationFromRange( int64_t array_size, int num_parts, IdArray range, IdArray in_idx) {
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx) {
std::pair<IdArray, NDArray> result; std::pair<IdArray, NDArray> result;
const auto& ctx = in_idx->ctx; const auto& ctx = in_idx->ctx;
...@@ -496,19 +426,19 @@ GeneratePermutationFromRange( ...@@ -496,19 +426,19 @@ GeneratePermutationFromRange(
const int64_t num_in = in_idx->shape[0]; const int64_t num_in = in_idx->shape[0];
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts << CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
") must be at least 1."; << ") must be at least 1.";
if (num_parts == 1) { if (num_parts == 1) {
// no permutation // no permutation
result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx); result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx); result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);
return result; return result;
} }
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8); result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx); result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
int64_t * out_counts = static_cast<int64_t*>(result.second->data); int64_t* out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) { if (num_in == 0) {
// now that we've zero'd out_counts, nothing left to do for an empty // now that we've zero'd out_counts, nothing left to do for an empty
// mapping // mapping
...@@ -522,31 +452,32 @@ GeneratePermutationFromRange( ...@@ -522,31 +452,32 @@ GeneratePermutationFromRange(
Workspace<IdType> proc_id_in(device, ctx, num_in); Workspace<IdType> proc_id_in(device, ctx, num_in);
{ {
const dim3 block(256); const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x); const dim3 grid((num_in + block.x - 1) / block.x);
CUDA_KERNEL_CALL(_MapProcByRangeKernel, grid, block, 0, stream, CUDA_KERNEL_CALL(
_MapProcByRangeKernel, grid, block, 0, stream,
static_cast<const RangeType*>(range->data), static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(in_idx->data), static_cast<const IdType*>(in_idx->data), num_in, num_parts,
num_in,
num_parts,
proc_id_in.get()); proc_id_in.get());
} }
// then create a permutation array that groups processors together by // then create a permutation array that groups processors together by
// performing a radix sort // performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in); Workspace<IdType> proc_id_out(device, ctx, num_in);
IdType * perm_out = static_cast<IdType*>(result.first->data); IdType* perm_out = static_cast<IdType*>(result.first->data);
{ {
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx); IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
size_t sort_workspace_size; size_t sort_workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size, CUDA_CALL(cub::DeviceRadixSort::SortPairs(
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out, nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
num_in, 0, part_bits, stream)); static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
stream));
Workspace<void> sort_workspace(device, ctx, sort_workspace_size); Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(sort_workspace.get(), sort_workspace_size, CUDA_CALL(cub::DeviceRadixSort::SortPairs(
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out, sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream)); num_in, 0, part_bits, stream));
} }
// explicitly free so workspace can be re-used // explicitly free so workspace can be re-used
...@@ -556,98 +487,68 @@ GeneratePermutationFromRange( ...@@ -556,98 +487,68 @@ GeneratePermutationFromRange(
// Count the number of values to be sent to each processor // Count the number of values to be sent to each processor
{ {
using AtomicCount = unsigned long long; // NOLINT using AtomicCount = unsigned long long; // NOLINT
static_assert(sizeof(AtomicCount) == sizeof(*out_counts), static_assert(
sizeof(AtomicCount) == sizeof(*out_counts),
"AtomicCount must be the same width as int64_t for atomicAdd " "AtomicCount must be the same width as int64_t for atomicAdd "
"in cub::DeviceHistogram::HistogramEven() to work"); "in cub::DeviceHistogram::HistogramEven() to work");
// TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged, // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
// add a compile time check against the cub version to allow // add a compile time check against the cub version to allow
// num_in > (2 << 31). // num_in > (2 << 31).
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) << CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
"number of values to insert into histogram must be less than max " << "number of values to insert into histogram must be less than max "
"value of int."; "value of int.";
size_t hist_workspace_size; size_t hist_workspace_size;
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr, nullptr, hist_workspace_size, proc_id_out.get(),
hist_workspace_size, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
proc_id_out.get(), static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
reinterpret_cast<AtomicCount*>(out_counts), static_cast<int>(num_in), stream));
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size); Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven( CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(), hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
hist_workspace_size, reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
proc_id_out.get(), static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
reinterpret_cast<AtomicCount*>(out_counts), static_cast<int>(num_in), stream));
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
} }
return result; return result;
} }
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>( GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
int64_t array_size, int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
int num_parts,
IdArray range,
IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>( GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
int64_t array_size, int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
int num_parts,
IdArray range,
IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>( GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
int64_t array_size, int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
int num_parts,
IdArray range,
IdArray in_idx);
template std::pair<IdArray, IdArray> template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>( GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
int64_t array_size, int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
int num_parts,
IdArray range,
IdArray in_idx);
template <DGLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange( IdArray MapToLocalFromRange(
const int num_parts, const int num_parts, IdArray range, IdArray global_idx) {
IdArray range,
IdArray global_idx) {
const auto& ctx = global_idx->ctx; const auto& ctx = global_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && global_idx->shape[0] > 0) { if (num_parts > 1 && global_idx->shape[0] > 0) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx, IdArray local_idx =
sizeof(IdType)*8); aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128); const dim3 block(128);
const dim3 grid((global_idx->shape[0] +block.x-1)/block.x); const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_MapLocalIndexByRangeKernel, _MapLocalIndexByRangeKernel, grid, block, 0, stream,
grid,
block,
0,
stream,
static_cast<const RangeType*>(range->data), static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(global_idx->data), static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
global_idx->shape[0], num_parts, static_cast<IdType*>(local_idx->data));
num_parts,
static_cast<IdType*>(local_idx->data));
return local_idx; return local_idx;
} else { } else {
...@@ -656,60 +557,38 @@ IdArray MapToLocalFromRange( ...@@ -656,60 +557,38 @@ IdArray MapToLocalFromRange(
} }
} }
template IdArray template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>( int num_parts, IdArray range, IdArray in_idx);
int num_parts, template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
IdArray range, int num_parts, IdArray range, IdArray in_idx);
IdArray in_idx); template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
template IdArray int num_parts, IdArray range, IdArray in_idx);
MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>( template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts, int num_parts, IdArray range, IdArray in_idx);
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template <DGLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange( IdArray MapToGlobalFromRange(
const int num_parts, const int num_parts, IdArray range, IdArray local_idx, const int part_id) {
IdArray range, CHECK_LT(part_id, num_parts)
IdArray local_idx, << "Invalid partition id " << part_id << "/" << num_parts;
const int part_id) { CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id << << num_parts;
"/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
"/" << num_parts;
const auto& ctx = local_idx->ctx; const auto& ctx = local_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && local_idx->shape[0] > 0) { if (num_parts > 1 && local_idx->shape[0] > 0) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx, IdArray global_idx =
sizeof(IdType)*8); aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128); const dim3 block(128);
const dim3 grid((local_idx->shape[0] +block.x-1)/block.x); const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_MapGlobalIndexByRangeKernel, _MapGlobalIndexByRangeKernel, grid, block, 0, stream,
grid,
block,
0,
stream,
static_cast<const RangeType*>(range->data), static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(local_idx->data), static_cast<const IdType*>(local_idx->data), part_id,
part_id, global_idx->shape[0], num_parts,
global_idx->shape[0],
num_parts,
static_cast<IdType*>(global_idx->data)); static_cast<IdType*>(global_idx->data));
return global_idx; return global_idx;
...@@ -719,31 +598,14 @@ IdArray MapToGlobalFromRange( ...@@ -719,31 +598,14 @@ IdArray MapToGlobalFromRange(
} }
} }
template IdArray template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>( int num_parts, IdArray range, IdArray in_idx, int part_id);
int num_parts, template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
IdArray range, int num_parts, IdArray range, IdArray in_idx, int part_id);
IdArray in_idx, template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
int part_id); int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>( int num_parts, IdArray range, IdArray in_idx, int part_id);
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
} // namespace impl } // namespace impl
} // namespace partition } // namespace partition
......
...@@ -6,10 +6,11 @@ ...@@ -6,10 +6,11 @@
#include "ndarray_partition.h" #include "ndarray_partition.h"
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <utility> #include <dgl/runtime/registry.h>
#include <memory> #include <memory>
#include <utility>
#include "partition_op.h" #include "partition_op.h"
...@@ -19,30 +20,21 @@ namespace dgl { ...@@ -19,30 +20,21 @@ namespace dgl {
namespace partition { namespace partition {
NDArrayPartition::NDArrayPartition( NDArrayPartition::NDArrayPartition(
const int64_t array_size, const int num_parts) : const int64_t array_size, const int num_parts)
array_size_(array_size), : array_size_(array_size), num_parts_(num_parts) {}
num_parts_(num_parts) {
}
int64_t NDArrayPartition::ArraySize() const { int64_t NDArrayPartition::ArraySize() const { return array_size_; }
return array_size_;
}
int NDArrayPartition::NumParts() const {
return num_parts_;
}
int NDArrayPartition::NumParts() const { return num_parts_; }
class RemainderPartition : public NDArrayPartition { class RemainderPartition : public NDArrayPartition {
public: public:
RemainderPartition( RemainderPartition(const int64_t array_size, const int num_parts)
const int64_t array_size, const int num_parts) : : NDArrayPartition(array_size, num_parts) {
NDArrayPartition(array_size, num_parts) {
// do nothing // do nothing
} }
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray> GeneratePermutation(
GeneratePermutation(
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
...@@ -55,13 +47,12 @@ class RemainderPartition : public NDArrayPartition { ...@@ -55,13 +47,12 @@ class RemainderPartition : public NDArrayPartition {
#endif #endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet " LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented."; "implemented.";
// should be unreachable // should be unreachable
return std::pair<IdArray, NDArray>{}; return std::pair<IdArray, NDArray>{};
} }
IdArray MapToLocal( IdArray MapToLocal(IdArray in_idx) const override {
IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) { if (ctx.device_type == kDGLCUDA) {
...@@ -73,14 +64,12 @@ class RemainderPartition : public NDArrayPartition { ...@@ -73,14 +64,12 @@ class RemainderPartition : public NDArrayPartition {
#endif #endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet " LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented."; "implemented.";
// should be unreachable // should be unreachable
return IdArray{}; return IdArray{};
} }
IdArray MapToGlobal( IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
IdArray in_idx,
const int part_id) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) { if (ctx.device_type == kDGLCUDA) {
...@@ -92,41 +81,39 @@ class RemainderPartition : public NDArrayPartition { ...@@ -92,41 +81,39 @@ class RemainderPartition : public NDArrayPartition {
#endif #endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet " LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented."; "implemented.";
// should be unreachable // should be unreachable
return IdArray{}; return IdArray{};
} }
int64_t PartSize(const int part_id) const override { int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for " CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id
"partition of size " << NumParts() << "."; << ") for "
"partition of size "
<< NumParts() << ".";
return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts()); return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());
} }
}; };
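PartSize here spreads ArraySize() as evenly as possible: the first ArraySize() % NumParts() parts receive one extra row. A worked check with hypothetical sizes (RemainderPartSize is an illustrative name):

#include <cassert>
#include <cstdint>

int64_t RemainderPartSize(int64_t array_size, int num_parts, int part_id) {
  return array_size / num_parts + (part_id < array_size % num_parts);
}

int main() {
  // 10 rows over 4 parts: rows map to parts {0,1,2,3,0,1,2,3,0,1},
  // so parts 0-1 hold 3 rows each and parts 2-3 hold 2 rows each.
  assert(RemainderPartSize(10, 4, 0) == 3);
  assert(RemainderPartSize(10, 4, 1) == 3);
  assert(RemainderPartSize(10, 4, 2) == 2);
  assert(RemainderPartSize(10, 4, 3) == 2);
  return 0;
}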
class RangePartition : public NDArrayPartition { class RangePartition : public NDArrayPartition {
public: public:
RangePartition( RangePartition(const int64_t array_size, const int num_parts, IdArray range)
const int64_t array_size, : NDArrayPartition(array_size, num_parts),
const int num_parts, range_(range),
IdArray range) : // We also need a copy of the range on the CPU, to compute partition
NDArrayPartition(array_size, num_parts), // sizes. We require the input range on the GPU, as if we have multiple
range_(range), // GPUs, we can't know which is the proper one to copy the array to, but
// We also need a copy of the range on the CPU, to compute partition // we have only one CPU context, and can safely copy the array to that.
// sizes. We require the input range on the GPU, as if we have multiple range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
// GPUs, we can't know which is the proper one to copy the array to, but we
// have only one CPU context, and can safely copy the array to that.
range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
auto ctx = range->ctx; auto ctx = range->ctx;
if (ctx.device_type != kDGLCUDA) { if (ctx.device_type != kDGLCUDA) {
LOG(FATAL) << "The range for an NDArrayPartition is only supported " LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before " " on GPUs. Transfer the range to the target device before "
"creating the partition."; "creating the partition.";
} }
} }
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray> GeneratePermutation(
GeneratePermutation(
IdArray in_idx) const override { IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
...@@ -134,11 +121,13 @@ class RangePartition : public NDArrayPartition { ...@@ -134,11 +121,13 @@ class RangePartition : public NDArrayPartition {
if (ctx.device_type != range_->ctx.device_type || if (ctx.device_type != range_->ctx.device_type ||
ctx.device_id != range_->ctx.device_id) { ctx.device_id != range_->ctx.device_id) {
LOG(FATAL) << "The range for the NDArrayPartition and the input " LOG(FATAL) << "The range for the NDArrayPartition and the input "
"array must be on the same device: " << ctx << " vs. " << range_->ctx; "array must be on the same device: "
<< ctx << " vs. " << range_->ctx;
} }
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::GeneratePermutationFromRange<kDGLCUDA, IdType, RangeType>( return impl::GeneratePermutationFromRange<
kDGLCUDA, IdType, RangeType>(
ArraySize(), NumParts(), range_, in_idx); ArraySize(), NumParts(), range_, in_idx);
}); });
}); });
...@@ -146,13 +135,12 @@ class RangePartition : public NDArrayPartition { ...@@ -146,13 +135,12 @@ class RangePartition : public NDArrayPartition {
#endif #endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet " LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented."; "implemented.";
// should be unreachable // should be unreachable
return std::pair<IdArray, NDArray>{}; return std::pair<IdArray, NDArray>{};
} }
IdArray MapToLocal( IdArray MapToLocal(IdArray in_idx) const override {
IdArray in_idx) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) { if (ctx.device_type == kDGLCUDA) {
...@@ -166,14 +154,12 @@ class RangePartition : public NDArrayPartition { ...@@ -166,14 +154,12 @@ class RangePartition : public NDArrayPartition {
#endif #endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet " LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented."; "implemented.";
// should be unreachable // should be unreachable
return IdArray{}; return IdArray{};
} }
IdArray MapToGlobal( IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
IdArray in_idx,
const int part_id) const override {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx; auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) { if (ctx.device_type == kDGLCUDA) {
...@@ -187,17 +173,20 @@ class RangePartition : public NDArrayPartition { ...@@ -187,17 +173,20 @@ class RangePartition : public NDArrayPartition {
#endif #endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet " LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented."; "implemented.";
// should be unreachable // should be unreachable
return IdArray{}; return IdArray{};
} }
int64_t PartSize(const int part_id) const override { int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for " CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id
"partition of size " << NumParts() << "."; << ") for "
"partition of size "
<< NumParts() << ".";
ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, { ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, {
const RangeType * const ptr = static_cast<const RangeType*>(range_cpu_->data); const RangeType* const ptr =
return ptr[part_id+1]-ptr[part_id]; static_cast<const RangeType*>(range_cpu_->data);
return ptr[part_id + 1] - ptr[part_id];
}); });
} }
...@@ -207,66 +196,58 @@ class RangePartition : public NDArrayPartition { ...@@ -207,66 +196,58 @@ class RangePartition : public NDArrayPartition {
}; };
NDArrayPartitionRef CreatePartitionRemainderBased( NDArrayPartitionRef CreatePartitionRemainderBased(
const int64_t array_size, const int64_t array_size, const int num_parts) {
const int num_parts) { return NDArrayPartitionRef(
return NDArrayPartitionRef(std::make_shared<RemainderPartition>( std::make_shared<RemainderPartition>(array_size, num_parts));
array_size, num_parts));
} }
NDArrayPartitionRef CreatePartitionRangeBased( NDArrayPartitionRef CreatePartitionRangeBased(
const int64_t array_size, const int64_t array_size, const int num_parts, IdArray range) {
const int num_parts, return NDArrayPartitionRef(
IdArray range) { std::make_shared<RangePartition>(array_size, num_parts, range));
return NDArrayPartitionRef(std::make_shared<RangePartition>(
array_size,
num_parts,
range));
} }
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t array_size = args[0]; int64_t array_size = args[0];
int num_parts = args[1]; int num_parts = args[1];
*rv = CreatePartitionRemainderBased(array_size, num_parts); *rv = CreatePartitionRemainderBased(array_size, num_parts);
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRangeBased") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRangeBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const int64_t array_size = args[0]; const int64_t array_size = args[0];
const int num_parts = args[1]; const int num_parts = args[1];
IdArray range = args[2]; IdArray range = args[2];
*rv = CreatePartitionRangeBased(array_size, num_parts, range);
});
*rv = CreatePartitionRangeBased(array_size, num_parts, range);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0]; NDArrayPartitionRef part = args[0];
int part_id = args[1]; int part_id = args[1];
*rv = part->PartSize(part_id); *rv = part->PartSize(part_id);
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0]; NDArrayPartitionRef part = args[0];
IdArray idxs = args[1]; IdArray idxs = args[1];
*rv = part->MapToLocal(idxs); *rv = part->MapToLocal(idxs);
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0]; NDArrayPartitionRef part = args[0];
IdArray idxs = args[1]; IdArray idxs = args[1];
const int part_id = args[2]; const int part_id = args[2];
*rv = part->MapToGlobal(idxs, part_id);
});
*rv = part->MapToGlobal(idxs, part_id);
});
} // namespace partition } // namespace partition
} // namespace dgl } // namespace dgl
/*! /*!
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* \file ndarray_partition.h * \file ndarray_partition.h
* \brief DGL utilities for working with the partitioned NDArrays * \brief DGL utilities for working with the partitioned NDArrays
*/ */
#ifndef DGL_PARTITION_NDARRAY_PARTITION_H_ #ifndef DGL_PARTITION_NDARRAY_PARTITION_H_
#define DGL_PARTITION_NDARRAY_PARTITION_H_ #define DGL_PARTITION_NDARRAY_PARTITION_H_
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/object.h>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
...@@ -28,9 +28,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -28,9 +28,7 @@ class NDArrayPartition : public runtime::Object {
* @param array_size The first dimension of the partitioned array. * @param array_size The first dimension of the partitioned array.
 * @param num_parts The number of parts the array is split into. * @param num_parts The number of parts the array is split into.
*/ */
NDArrayPartition( NDArrayPartition(int64_t array_size, int num_parts);
int64_t array_size,
int num_parts);
virtual ~NDArrayPartition() = default; virtual ~NDArrayPartition() = default;
...@@ -50,8 +48,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -50,8 +48,7 @@ class NDArrayPartition : public runtime::Object {
* @return A pair containing 0) the permutation to re-order the indices by * @return A pair containing 0) the permutation to re-order the indices by
* partition, 1) the number of indices per partition (int64_t). * partition, 1) the number of indices per partition (int64_t).
*/ */
virtual std::pair<IdArray, NDArray> virtual std::pair<IdArray, NDArray> GeneratePermutation(
GeneratePermutation(
IdArray in_idx) const = 0; IdArray in_idx) const = 0;
/** /**
...@@ -62,8 +59,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -62,8 +59,7 @@ class NDArrayPartition : public runtime::Object {
* *
* @return The local indices. * @return The local indices.
*/ */
virtual IdArray MapToLocal( virtual IdArray MapToLocal(IdArray in_idx) const = 0;
IdArray in_idx) const = 0;
/** /**
* @brief Generate the global indices (the numbering unique across all * @brief Generate the global indices (the numbering unique across all
...@@ -74,9 +70,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -74,9 +70,7 @@ class NDArrayPartition : public runtime::Object {
* *
* @return The global indices. * @return The global indices.
*/ */
virtual IdArray MapToGlobal( virtual IdArray MapToGlobal(IdArray in_idx, int part_id) const = 0;
IdArray in_idx,
int part_id) const = 0;
/** /**
* @brief Get the number of rows/items assigned to the given part. * @brief Get the number of rows/items assigned to the given part.
...@@ -85,8 +79,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -85,8 +79,7 @@ class NDArrayPartition : public runtime::Object {
* *
* @return The size. * @return The size.
*/ */
virtual int64_t PartSize( virtual int64_t PartSize(int part_id) const = 0;
int part_id) const = 0;
/** /**
* @brief Get the first dimension of the partitioned array. * @brief Get the first dimension of the partitioned array.
...@@ -119,9 +112,7 @@ DGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition); ...@@ -119,9 +112,7 @@ DGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition);
* @return The partition object. * @return The partition object.
*/ */
NDArrayPartitionRef CreatePartitionRemainderBased( NDArrayPartitionRef CreatePartitionRemainderBased(
int64_t array_size, int64_t array_size, int num_parts);
int num_parts);
/** /**
* @brief Create a new partition object, using the range (exclusive prefix-sum) * @brief Create a new partition object, using the range (exclusive prefix-sum)
...@@ -136,9 +127,7 @@ NDArrayPartitionRef CreatePartitionRemainderBased( ...@@ -136,9 +127,7 @@ NDArrayPartitionRef CreatePartitionRemainderBased(
* @return The partition object. * @return The partition object.
*/ */
NDArrayPartitionRef CreatePartitionRangeBased( NDArrayPartitionRef CreatePartitionRangeBased(
int64_t array_size, int64_t array_size, int num_parts, IdArray range);
int num_parts,
IdArray range);
} // namespace partition } // namespace partition
} // namespace dgl } // namespace dgl
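A deliberately simplified, CPU-only analogue of the interface declared above, showing how MapToLocal and MapToGlobal relate under remainder partitioning; MiniRemainderPartition is illustrative only (DGL's classes operate on IdArray/NDArray and dispatch to the CUDA kernels shown earlier, and PartSize/array size are omitted here):

#include <cassert>
#include <cstdint>

class MiniRemainderPartition {
 public:
  explicit MiniRemainderPartition(int num_parts) : num_parts_(num_parts) {}
  // Strided numbering within a part: global 7 with 4 parts is local 1 of part 3.
  int64_t MapToLocal(int64_t global) const { return global / num_parts_; }
  int64_t MapToGlobal(int64_t local, int part_id) const {
    return local * num_parts_ + part_id;
  }

 private:
  int num_parts_;
};

int main() {
  const MiniRemainderPartition part(4);
  for (int64_t g = 0; g < 10; ++g)
    assert(part.MapToGlobal(part.MapToLocal(g), g % 4) == g);
  return 0;
}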
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
* \brief DGL utilities for working with the partitioned NDArrays * \brief DGL utilities for working with the partitioned NDArrays
*/ */
#ifndef DGL_PARTITION_PARTITION_OP_H_ #ifndef DGL_PARTITION_PARTITION_OP_H_
#define DGL_PARTITION_PARTITION_OP_H_ #define DGL_PARTITION_PARTITION_OP_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
...@@ -33,11 +33,8 @@ namespace impl { ...@@ -33,11 +33,8 @@ namespace impl {
* indices in each part. * indices in each part.
*/ */
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> std::pair<IdArray, IdArray> GeneratePermutationFromRemainder(
GeneratePermutationFromRemainder( int64_t array_size, int num_parts, IdArray in_idx);
int64_t array_size,
int num_parts,
IdArray in_idx);
/** /**
* @brief Generate the set of local indices from the global indices, using * @brief Generate the set of local indices from the global indices, using
...@@ -52,9 +49,7 @@ GeneratePermutationFromRemainder( ...@@ -52,9 +49,7 @@ GeneratePermutationFromRemainder(
* @return The array of local indices. * @return The array of local indices.
*/ */
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder( IdArray MapToLocalFromRemainder(int num_parts, IdArray global_idx);
int num_parts,
IdArray global_idx);
/** /**
* @brief Generate the set of global indices from the local indices, using * @brief Generate the set of global indices from the local indices, using
...@@ -70,10 +65,7 @@ IdArray MapToLocalFromRemainder( ...@@ -70,10 +65,7 @@ IdArray MapToLocalFromRemainder(
* @return The array of global indices. * @return The array of global indices.
*/ */
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder( IdArray MapToGlobalFromRemainder(int num_parts, IdArray local_idx, int part_id);
int num_parts,
IdArray local_idx,
int part_id);
/** /**
* @brief Create a permutation that groups indices by the part id when used for * @brief Create a permutation that groups indices by the part id when used for
...@@ -96,12 +88,8 @@ IdArray MapToGlobalFromRemainder( ...@@ -96,12 +88,8 @@ IdArray MapToGlobalFromRemainder(
* indices in each part. * indices in each part.
*/ */
template <DGLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, IdArray> std::pair<IdArray, IdArray> GeneratePermutationFromRange(
GeneratePermutationFromRange( int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
/** /**
* @brief Generate the set of local indices from the global indices, using * @brief Generate the set of local indices from the global indices, using
...@@ -119,10 +107,7 @@ GeneratePermutationFromRange( ...@@ -119,10 +107,7 @@ GeneratePermutationFromRange(
* @return The array of local indices. * @return The array of local indices.
*/ */
template <DGLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange( IdArray MapToLocalFromRange(int num_parts, IdArray range, IdArray global_idx);
int num_parts,
IdArray range,
IdArray global_idx);
/** /**
* @brief Generate the set of global indices from the local indices, using * @brief Generate the set of global indices from the local indices, using
...@@ -142,12 +127,7 @@ IdArray MapToLocalFromRange( ...@@ -142,12 +127,7 @@ IdArray MapToLocalFromRange(
*/ */
template <DGLDeviceType XPU, typename IdType, typename RangeType> template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange( IdArray MapToGlobalFromRange(
int num_parts, int num_parts, IdArray range, IdArray local_idx, int part_id);
IdArray range,
IdArray local_idx,
int part_id);
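The range-based variants replace the modulo with a lookup in the exclusive prefix-sum `range`. A scalar sketch, assuming `range` holds num_parts + 1 boundaries {0, ..., array_size}; the shipped kernels are batched over IdArray and templated on RangeType:

#include <algorithm>
#include <vector>

// Part that owns `global`: index of the last boundary <= global.
int PartOf(const std::vector<int64_t>& range, int64_t global) {
  return static_cast<int>(
      std::upper_bound(range.begin(), range.end(), global) -
      range.begin() - 1);
}
int64_t ToLocal(const std::vector<int64_t>& range, int64_t global) {
  return global - range[PartOf(range, global)];
}
int64_t ToGlobal(const std::vector<int64_t>& range, int64_t local, int part) {
  return local + range[part];
}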
} // namespace impl } // namespace impl
} // namespace partition } // namespace partition
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
* \brief Random number generator interfaces * \brief Random number generator interfaces
*/ */
#include <dmlc/omp.h> #include <dgl/array.h>
#include <dgl/runtime/registry.h> #include <dgl/random.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <dgl/random.h> #include <dgl/runtime/registry.h>
#include <dgl/array.h> #include <dmlc/omp.h>
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
#include "../runtime/cuda/cuda_common.h" #include "../runtime/cuda/cuda_common.h"
...@@ -20,53 +20,57 @@ using namespace dgl::runtime; ...@@ -20,53 +20,57 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const int seed = args[0]; const int seed = args[0];
runtime::parallel_for(0, omp_get_max_threads(), [&](size_t b, size_t e) { runtime::parallel_for(0, omp_get_max_threads(), [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
RandomEngine::ThreadLocal()->SetSeed(seed); RandomEngine::ThreadLocal()->SetSeed(seed);
} }
}); });
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
if (DeviceAPI::Get(kDGLCUDA)->IsAvailable()) { if (DeviceAPI::Get(kDGLCUDA)->IsAvailable()) {
auto* thr_entry = CUDAThreadEntry::ThreadLocal(); auto *thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) { if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT)); CURAND_CALL(curandCreateGenerator(
&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen, static_cast<uint64_t>(seed)));
} }
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
}
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
}); });
DGL_REGISTER_GLOBAL("rng._CAPI_Choice") DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const int64_t num = args[0]; const int64_t num = args[0];
const int64_t population = args[1]; const int64_t population = args[1];
const NDArray prob = args[2]; const NDArray prob = args[2];
const bool replace = args[3]; const bool replace = args[3];
const int bits = args[4]; const int bits = args[4];
CHECK(bits == 32 || bits == 64) CHECK(bits == 32 || bits == 64)
<< "Supported bit widths are 32 and 64, but got " << bits << "."; << "Supported bit widths are 32 and 64, but got " << bits << ".";
if (aten::IsNullArray(prob)) { if (aten::IsNullArray(prob)) {
if (bits == 32) { if (bits == 32) {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(num, population, replace); *rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(
num, population, replace);
} else {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(
num, population, replace);
}
} else { } else {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(num, population, replace); if (bits == 32) {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(
num, prob, replace);
});
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(
num, prob, replace);
});
}
} }
} else { });
if (bits == 32) {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(num, prob, replace);
});
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(num, prob, replace);
});
}
}
});
}; // namespace dgl }; // namespace dgl
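Both branches of `_CAPI_Choice` above bottom out in `RandomEngine`. A direct C++ sketch of the same calls (the Python frontend reaches them through the registered globals instead):

using dgl::RandomEngine;

// Seed the thread-local engine, then sample 10 ids out of [0, 1000)
// without replacement using 64-bit ids, as in the null-prob branch above.
RandomEngine::ThreadLocal()->SetSeed(42);
auto picked = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(
    /*num=*/10, /*population=*/1000, /*replace=*/false);
// With a float `prob` NDArray, the weighted branch calls
// Choice<int64_t, FloatType>(num, prob, replace) instead.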
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_RPC_NET_TYPE_H_ #define DGL_RPC_NET_TYPE_H_
#include <string> #include <string>
#include "rpc_msg.h" #include "rpc_msg.h"
namespace dgl { namespace dgl {
...@@ -29,11 +30,12 @@ struct RPCBase { ...@@ -29,11 +30,12 @@ struct RPCBase {
struct RPCSender : RPCBase { struct RPCSender : RPCBase {
/*! /*!
* \brief Connect to a receiver. * \brief Connect to a receiver.
* *
* When there are multiple receivers to be connected, the application calls * When there are multiple receivers to be connected, the application calls
* `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize` * `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize`
* to check whether all the connections were successfully established or some * to check whether all the connections were successfully established or some
* of them failed. * of them failed.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091' * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID * \param recv_id receiver's ID
* \return True for success and False for fail * \return True for success and False for fail
...@@ -49,13 +51,11 @@ struct RPCSender : RPCBase { ...@@ -49,13 +51,11 @@ struct RPCSender : RPCBase {
* *
* The function is *not* thread-safe; only one thread can invoke this API. * The function is *not* thread-safe; only one thread can invoke this API.
*/ */
virtual bool ConnectReceiverFinalize(const int max_try_times) { virtual bool ConnectReceiverFinalize(const int max_try_times) { return true; }
return true;
}
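A hedged sketch of the two-phase protocol described above; `sender` stands for any concrete RPCSender and the addresses are examples:

std::vector<std::string> addrs = {
    "tcp://127.0.0.1:50091", "tcp://127.0.0.1:50092"};
for (size_t i = 0; i < addrs.size(); ++i) {
  // Phase 1: request a connection to every receiver.
  CHECK(sender.ConnectReceiver(addrs[i], /*recv_id=*/static_cast<int>(i)));
}
// Phase 2: block until all connections are up (or retries are exhausted).
CHECK(sender.ConnectReceiverFinalize(/*max_try_times=*/10));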
/*! /*!
* \brief Send RPCMessage to specified Receiver. * \brief Send RPCMessage to specified Receiver.
* \param msg data message * \param msg data message
* \param recv_id receiver's ID * \param recv_id receiver's ID
*/ */
virtual void Send(const RPCMessage &msg, int recv_id) = 0; virtual void Send(const RPCMessage &msg, int recv_id) = 0;
...@@ -71,13 +71,14 @@ struct RPCReceiver : RPCBase { ...@@ -71,13 +71,14 @@ struct RPCReceiver : RPCBase {
* *
* Wait() is not thread-safe and only one thread can invoke this API. * Wait() is not thread-safe and only one thread can invoke this API.
*/ */
virtual bool Wait(const std::string &addr, int num_sender, virtual bool Wait(
bool blocking = true) = 0; const std::string &addr, int num_sender, bool blocking = true) = 0;
/*! /*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue. * \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage * \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut. * \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/ */
virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0; virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0;
......
...@@ -17,7 +17,8 @@ namespace network { ...@@ -17,7 +17,8 @@ namespace network {
// In most cases, delim contains only one character. In this case, we // In most cases, delim contains only one character. In this case, we
// use CalculateReserveForVector to count the number of elements that should // use CalculateReserveForVector to count the number of elements that should
// be reserved in the result vector, and thus optimize SplitStringUsing. // be reserved in the result vector, and thus optimize SplitStringUsing.
static int CalculateReserveForVector(const std::string& full, const char* delim) { static int CalculateReserveForVector(
const std::string& full, const char* delim) {
int count = 0; int count = 0;
if (delim[0] != '\0' && delim[1] == '\0') { if (delim[0] != '\0' && delim[1] == '\0') {
// Optimize the common case where delim is a single character. // Optimize the common case where delim is a single character.
...@@ -38,19 +39,18 @@ static int CalculateReserveForVector(const std::string& full, const char* delim) ...@@ -38,19 +39,18 @@ static int CalculateReserveForVector(const std::string& full, const char* delim)
return count; return count;
} }
void SplitStringUsing(const std::string& full, void SplitStringUsing(
const char* delim, const std::string& full, const char* delim,
std::vector<std::string>* result) { std::vector<std::string>* result) {
CHECK(delim != NULL); CHECK(delim != NULL);
CHECK(result != NULL); CHECK(result != NULL);
result->reserve(CalculateReserveForVector(full, delim)); result->reserve(CalculateReserveForVector(full, delim));
back_insert_iterator< std::vector<std::string> > it(*result); back_insert_iterator<std::vector<std::string> > it(*result);
SplitStringToIteratorUsing(full, delim, &it); SplitStringToIteratorUsing(full, delim, &it);
} }
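A small usage sketch of the function just defined; splitting an RPC address on ':' is representative of how callers use it:

std::vector<std::string> fields;
dgl::network::SplitStringUsing("tcp://127.0.0.1:50091", ":", &fields);
// fields: {"tcp", "//127.0.0.1", "50091"}; the single-character delimiter
// hits the optimized reserve path in CalculateReserveForVector above.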
void SplitStringToSetUsing(const std::string& full, void SplitStringToSetUsing(
const char* delim, const std::string& full, const char* delim, std::set<std::string>* result) {
std::set<std::string>* result) {
CHECK(delim != NULL); CHECK(delim != NULL);
CHECK(result != NULL); CHECK(result != NULL);
simple_insert_iterator<std::set<std::string> > it(result); simple_insert_iterator<std::set<std::string> > it(result);
......
...@@ -33,19 +33,18 @@ namespace network { ...@@ -33,19 +33,18 @@ namespace network {
// substrings[2] == "banana" // substrings[2] == "banana"
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
void SplitStringUsing(const std::string& full, void SplitStringUsing(
const char* delim, const std::string& full, const char* delim,
std::vector<std::string>* result); std::vector<std::string>* result);
// This function has the same semantic as SplitStringUsing. Results // This function has the same semantic as SplitStringUsing. Results
// are saved in an STL set container. // are saved in an STL set container.
void SplitStringToSetUsing(const std::string& full, void SplitStringToSetUsing(
const char* delim, const std::string& full, const char* delim, std::set<std::string>* result);
std::set<std::string>* result);
template <typename T> template <typename T>
struct simple_insert_iterator { struct simple_insert_iterator {
explicit simple_insert_iterator(T* t) : t_(t) { } explicit simple_insert_iterator(T* t) : t_(t) {}
simple_insert_iterator<T>& operator=(const typename T::value_type& value) { simple_insert_iterator<T>& operator=(const typename T::value_type& value) {
t_->insert(value); t_->insert(value);
...@@ -76,10 +75,8 @@ struct back_insert_iterator { ...@@ -76,10 +75,8 @@ struct back_insert_iterator {
}; };
template <typename StringType, typename ITR> template <typename StringType, typename ITR>
static inline static inline void SplitStringToIteratorUsing(
void SplitStringToIteratorUsing(const StringType& full, const StringType& full, const char* delim, ITR* result) {
const char* delim,
ITR* result) {
CHECK_NOTNULL(delim); CHECK_NOTNULL(delim);
// Optimize the common case where delim is a single character. // Optimize the common case where delim is a single character.
if (delim[0] != '\0' && delim[1] == '\0') { if (delim[0] != '\0' && delim[1] == '\0') {
......
...@@ -19,16 +19,17 @@ namespace network { ...@@ -19,16 +19,17 @@ namespace network {
/*! /*!
* \brief Network Sender for DGL distributed training. * \brief Network Sender for DGL distributed training.
* *
* Sender is an abstract class that defines a set of APIs for sending binary * Sender is an abstract class that defines a set of APIs for sending binary
* data messages over the network. It can be implemented by different * data messages over the network. It can be implemented by different
* underlying networking libraries such as TCP socket and MPI. One Sender can * underlying networking libraries such as TCP socket and MPI. One Sender can
* connect to multiple receivers and it can send data to a specified receiver * connect to multiple receivers and it can send data to a specified receiver
* via the receiver's ID. * via the receiver's ID.
*/ */
class Sender : public rpc::RPCSender { class Sender : public rpc::RPCSender {
public: public:
/*! /*!
* \brief Sender constructor * \brief Sender constructor
* \param queue_size size (bytes) of message queue. * \param queue_size size (bytes) of message queue.
* \param max_thread_count size of thread pool. 0 for no limit * \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional. * Note that, the queue_size parameter is optional.
*/ */
...@@ -47,12 +48,12 @@ class Sender : public rpc::RPCSender { ...@@ -47,12 +48,12 @@ class Sender : public rpc::RPCSender {
* \param recv_id receiver's ID * \param recv_id receiver's ID
* \return Status code * \return Status code
* *
* (1) The send is non-blocking. There is no guarantee that the message has been * (1) The send is non-blocking. There is no guarantee that the message has
* physically sent out when the function returns. * been physically sent out when the function returns. (2) The communicator
* (2) The communicator will assume the responsibility of the given message. * will assume the responsibility of the given message. (3) The API is
* (3) The API is multi-thread safe. * multi-thread safe. (4) Messages sent to the same receiver are guaranteed to
* (4) Messages sent to the same receiver are guaranteed to be received in the same order. * be received in the same order. There is no guarantee for messages sent to
* There is no guarantee for messages sent to different receivers. * different receivers.
*/ */
virtual STATUS Send(Message msg, int recv_id) = 0; virtual STATUS Send(Message msg, int recv_id) = 0;
...@@ -70,10 +71,11 @@ class Sender : public rpc::RPCSender { ...@@ -70,10 +71,11 @@ class Sender : public rpc::RPCSender {
/*! /*!
* \brief Network Receiver for DGL distributed training. * \brief Network Receiver for DGL distributed training.
* *
* Receiver is an abstract class that defines a set of APIs for receiving * Receiver is an abstract class that defines a set of APIs for receiving
* binary data messages over the network. It can be implemented by different * binary data messages over the network. It can be implemented by different
* underlying networking libraries such as TCP socket and MPI. One Receiver * underlying networking libraries such as TCP socket and MPI. One Receiver
* can connect with multiple Senders and it can receive data from multiple * can connect with multiple Senders and it can receive data from multiple
* Senders concurrently. * Senders concurrently.
*/ */
class Receiver : public rpc::RPCReceiver { class Receiver : public rpc::RPCReceiver {
public: public:
...@@ -98,11 +100,13 @@ class Receiver : public rpc::RPCReceiver { ...@@ -98,11 +100,13 @@ class Receiver : public rpc::RPCReceiver {
* \brief Recv data from Sender * \brief Recv data from Sender
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id which sender current msg comes from * \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code * \return Status code
* *
* (1) The Recv() API is thread-safe. * (1) The Recv() API is thread-safe.
* (2) The memory is allocated by the communicator, but the communicator will * (2) The memory is allocated by the communicator, but the communicator will
* not own it after the function returns. * not own it after the function returns.
*/ */
virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0; virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;
...@@ -110,11 +114,13 @@ class Receiver : public rpc::RPCReceiver { ...@@ -110,11 +114,13 @@ class Receiver : public rpc::RPCReceiver {
* \brief Recv data from a specified Sender * \brief Recv data from a specified Sender
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id sender's ID * \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code * \return Status code
* *
* (1) The RecvFrom() API is thread-safe. * (1) The RecvFrom() API is thread-safe.
* (2) The memory is allocated by the communicator, but the communicator will * (2) The memory is allocated by the communicator, but the communicator will
* not own it after the function returns. * not own it after the function returns.
*/ */
virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0; virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
* \file msg_queue.cc * \file msg_queue.cc
* \brief Message queue for DGL distributed training. * \brief Message queue for DGL distributed training.
*/ */
#include "msg_queue.h"
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <cstring>
#include "msg_queue.h" #include <cstring>
namespace dgl { namespace dgl {
namespace network { namespace network {
...@@ -38,9 +39,7 @@ STATUS MessageQueue::Add(Message msg, bool is_blocking) { ...@@ -38,9 +39,7 @@ STATUS MessageQueue::Add(Message msg, bool is_blocking) {
if (msg.size > free_size_ && !is_blocking) { if (msg.size > free_size_ && !is_blocking) {
return QUEUE_FULL; return QUEUE_FULL;
} }
cond_not_full_.wait(lock, [&]() { cond_not_full_.wait(lock, [&]() { return msg.size <= free_size_; });
return msg.size <= free_size_;
});
// Add data pointer to queue // Add data pointer to queue
queue_.push(msg); queue_.push(msg);
free_size_ -= msg.size; free_size_ -= msg.size;
...@@ -61,9 +60,8 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) { ...@@ -61,9 +60,8 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
} }
} }
cond_not_empty_.wait(lock, [this] { cond_not_empty_.wait(
return !queue_.empty() || exit_flag_.load(); lock, [this] { return !queue_.empty() || exit_flag_.load(); });
});
if (finished_producers_.size() >= num_producers_ && queue_.empty()) { if (finished_producers_.size() >= num_producers_ && queue_.empty()) {
return QUEUE_CLOSE; return QUEUE_CLOSE;
} }
...@@ -98,8 +96,7 @@ bool MessageQueue::Empty() const { ...@@ -98,8 +96,7 @@ bool MessageQueue::Empty() const {
bool MessageQueue::EmptyAndNoMoreAdd() const { bool MessageQueue::EmptyAndNoMoreAdd() const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return queue_.size() == 0 && return queue_.size() == 0 && finished_producers_.size() >= num_producers_;
finished_producers_.size() >= num_producers_;
} }
} // namespace network } // namespace network
......
...@@ -8,14 +8,14 @@ ...@@ -8,14 +8,14 @@
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue> #include <queue>
#include <set> #include <set>
#include <string> #include <string>
#include <utility> // for pair #include <utility> // for pair
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <functional>
namespace dgl { namespace dgl {
namespace network { namespace network {
...@@ -25,13 +25,13 @@ typedef int STATUS; ...@@ -25,13 +25,13 @@ typedef int STATUS;
/*! /*!
* \brief Status code of message queue * \brief Status code of message queue
*/ */
#define ADD_SUCCESS 3400 // Add message successfully #define ADD_SUCCESS 3400 // Add message successfully
#define MSG_GT_SIZE 3401 // Message size beyond queue size #define MSG_GT_SIZE 3401 // Message size beyond queue size
#define MSG_LE_ZERO 3402 // Message size is not a positive number #define MSG_LE_ZERO 3402 // Message size is not a positive number
#define QUEUE_CLOSE 3403 // Cannot add message when queue is closed #define QUEUE_CLOSE 3403 // Cannot add message when queue is closed
#define QUEUE_FULL 3404 // Cannot add message when queue is full #define QUEUE_FULL 3404 // Cannot add message when queue is full
#define REMOVE_SUCCESS 3405 // Remove message successfully #define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty #define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
/*! /*!
* \brief Message used by network communicator and message queue. * \brief Message used by network communicator and message queue.
...@@ -40,13 +40,13 @@ struct Message { ...@@ -40,13 +40,13 @@ struct Message {
/*! /*!
* \brief Constructor * \brief Constructor
*/ */
Message() { } Message() {}
/*! /*!
* \brief Constructor * \brief Constructor
*/ */
Message(char* data_ptr, int64_t data_size) Message(char* data_ptr, int64_t data_size)
: data(data_ptr), size(data_size) { } : data(data_ptr), size(data_size) {}
/*! /*!
* \brief message data * \brief message data
...@@ -69,22 +69,23 @@ struct Message { ...@@ -69,22 +69,23 @@ struct Message {
/*! /*!
* \brief Free memory buffer of message * \brief Free memory buffer of message
*/ */
inline void DefaultMessageDeleter(Message* msg) { delete [] msg->data; } inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
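A sketch of the ownership convention `deallocator` encodes: the side that drains the queue, not the producer, releases the payload. The buffer size is arbitrary:

// Build a heap-backed Message that the consumer can free later.
char* payload = new char[64];
dgl::network::Message msg(payload, /*data_size=*/64);
msg.deallocator = dgl::network::DefaultMessageDeleter;  // delete[] msg->data
// ... after the consumer has finished with msg.data:
msg.deallocator(&msg);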
/*! /*!
* \brief Message Queue for network communication. * \brief Message Queue for network communication.
* *
* MessageQueue is a FIFO queue that adopts the producer/consumer model for * MessageQueue is a FIFO queue that adopts the producer/consumer model for
* data messages. It supports one or more producer threads and one or more * data messages. It supports one or more producer threads and one or more
* consumer threads. Producers invoke Add() to push data messages into the * consumer threads. Producers invoke Add() to push data messages into the
* queue, and consumers invoke Remove() to pop data messages from the queue. * queue, and consumers invoke Remove() to pop data messages from the queue.
* Add() and Remove() use two condition variables to synchronize producer * Add() and Remove() use two condition variables to synchronize producer
* threads and consumer threads. Each producer invokes * threads and consumer threads. Each producer invokes
* SignalFinished(producer_id) to claim that it is about to finish, where * SignalFinished(producer_id) to claim that it is about to finish, where
* producer_id is an integer that uniquely identifies a producer thread. This * producer_id is an integer that uniquely identifies a producer thread. This
* signaling mechanism prevents consumers from waiting after all producers * signaling mechanism prevents consumers from waiting after all producers
* have finished their jobs. * have finished their jobs.
* *
* MessageQueue is thread-safe. * MessageQueue is thread-safe.
* *
*/ */
class MessageQueue { class MessageQueue {
public: public:
...@@ -93,8 +94,8 @@ class MessageQueue { ...@@ -93,8 +94,8 @@ class MessageQueue {
* \param queue_size size (bytes) of message queue * \param queue_size size (bytes) of message queue
* \param num_producers number of producers, use 1 by default * \param num_producers number of producers, use 1 by default
*/ */
explicit MessageQueue(int64_t queue_size /* in bytes */, explicit MessageQueue(
int num_producers = 1); int64_t queue_size /* in bytes */, int num_producers = 1);
/*! /*!
* \brief MessageQueue deconstructor * \brief MessageQueue deconstructor
...@@ -134,48 +135,48 @@ class MessageQueue { ...@@ -134,48 +135,48 @@ class MessageQueue {
bool EmptyAndNoMoreAdd() const; bool EmptyAndNoMoreAdd() const;
protected: protected:
/*! /*!
* \brief message queue * \brief message queue
*/ */
std::queue<Message> queue_; std::queue<Message> queue_;
/*! /*!
* \brief Size of the queue in bytes * \brief Size of the queue in bytes
*/ */
int64_t queue_size_; int64_t queue_size_;
/*! /*!
* \brief Free size of the queue * \brief Free size of the queue
*/ */
int64_t free_size_; int64_t free_size_;
/*! /*!
* \brief Used to check that all producers will no longer produce anything * \brief Used to check that all producers will no longer produce anything
*/ */
size_t num_producers_; size_t num_producers_;
/*! /*!
* \brief Store finished producer id * \brief Store finished producer id
*/ */
std::set<int /* producer_id */> finished_producers_; std::set<int /* producer_id */> finished_producers_;
/*! /*!
* \brief Condition when consumer should wait * \brief Condition when consumer should wait
*/ */
std::condition_variable cond_not_full_; std::condition_variable cond_not_full_;
/*! /*!
* \brief Condition when producer should wait * \brief Condition when producer should wait
*/ */
std::condition_variable cond_not_empty_; std::condition_variable cond_not_empty_;
/*! /*!
* \brief Signal for exit wait * \brief Signal for exit wait
*/ */
std::atomic<bool> exit_flag_{false}; std::atomic<bool> exit_flag_{false};
/*! /*!
* \brief Protect all above data and conditions * \brief Protect all above data and conditions
*/ */
mutable std::mutex mutex_; mutable std::mutex mutex_;
}; };
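A hedged end-to-end sketch of the producer/consumer contract documented above, trimmed to the queue calls; error handling and real payloads are omitted:

#include <thread>

dgl::network::MessageQueue q(/*queue_size=*/1 << 20, /*num_producers=*/1);

std::thread producer([&q]() {
  char* buf = new char[8];
  dgl::network::Message m(buf, /*data_size=*/8);
  m.deallocator = dgl::network::DefaultMessageDeleter;
  CHECK_EQ(q.Add(m, /*is_blocking=*/true), ADD_SUCCESS);
  q.SignalFinished(/*producer_id=*/0);  // lets consumers stop waiting
});

std::thread consumer([&q]() {
  dgl::network::Message m;
  if (q.Remove(&m, /*is_blocking=*/true) == REMOVE_SUCCESS) {
    // ... use m.data / m.size, then release the payload:
    m.deallocator(&m);
  }  // Remove returns QUEUE_CLOSE once all producers signalled and q is empty.
});

producer.join();
consumer.join();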
......
...@@ -3,29 +3,29 @@ ...@@ -3,29 +3,29 @@
* \file communicator.cc * \file communicator.cc
* \brief SocketCommunicator for DGL distributed training. * \brief SocketCommunicator for DGL distributed training.
*/ */
#include <dmlc/logging.h> #include "socket_communicator.h"
#include <string.h> #include <dmlc/logging.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h>
#include <time.h> #include <time.h>
#include <memory> #include <memory>
#include "socket_communicator.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "socket_pool.h" #include "socket_pool.h"
#ifdef _WIN32 #ifdef _WIN32
#include <windows.h> #include <windows.h>
#else // !_WIN32 #else // !_WIN32
#include <unistd.h> #include <unistd.h>
#endif // _WIN32 #endif // _WIN32
namespace dgl { namespace dgl {
namespace network { namespace network {
/////////////////////////////////////// SocketSender /////////////////////////////////////////// /////////////////////////////////////// SocketSender
//////////////////////////////////////////////
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) { bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
if (recv_id < 0) { if (recv_id < 0) {
...@@ -92,9 +92,7 @@ bool SocketSender::ConnectReceiverFinalize(const int max_try_times) { ...@@ -92,9 +92,7 @@ bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_)); msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
// Create a new thread for this socket connection // Create a new thread for this socket connection
threads_.push_back(std::make_shared<std::thread>( threads_.push_back(std::make_shared<std::thread>(
SendLoop, SendLoop, sockets_[thread_id], msg_queue_[thread_id]));
sockets_[thread_id],
msg_queue_[thread_id]));
} }
return true; return true;
...@@ -105,14 +103,13 @@ void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) { ...@@ -105,14 +103,13 @@ void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true); StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
zc_write_strm.Write(msg); zc_write_strm.Write(msg);
int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size(); int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
zerocopy_blob->append(reinterpret_cast<char*>(&nonempty_ndarray_count), zerocopy_blob->append(
sizeof(int32_t)); reinterpret_cast<char*>(&nonempty_ndarray_count), sizeof(int32_t));
Message rpc_meta_msg; Message rpc_meta_msg;
rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data()); rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
rpc_meta_msg.size = zerocopy_blob->size(); rpc_meta_msg.size = zerocopy_blob->size();
rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {}; rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};
CHECK_EQ(Send( CHECK_EQ(Send(rpc_meta_msg, recv_id), ADD_SUCCESS);
rpc_meta_msg, recv_id), ADD_SUCCESS);
// send real ndarray data // send real ndarray data
for (auto ptr : zc_write_strm.buffer_list()) { for (auto ptr : zc_write_strm.buffer_list()) {
Message ndarray_data_msg; Message ndarray_data_msg;
...@@ -123,8 +120,7 @@ void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) { ...@@ -123,8 +120,7 @@ void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
ndarray_data_msg.size = ptr.size; ndarray_data_msg.size = ptr.size;
NDArray tensor = ptr.tensor; NDArray tensor = ptr.tensor;
ndarray_data_msg.deallocator = [tensor](Message*) {}; ndarray_data_msg.deallocator = [tensor](Message*) {};
CHECK_EQ(Send( CHECK_EQ(Send(ndarray_data_msg, recv_id), ADD_SUCCESS);
ndarray_data_msg, recv_id), ADD_SUCCESS);
} }
} }
...@@ -156,7 +152,7 @@ void SocketSender::Finalize() { ...@@ -156,7 +152,7 @@ void SocketSender::Finalize() {
} }
// Clear all sockets // Clear all sockets
for (auto& group_sockets_ : sockets_) { for (auto& group_sockets_ : sockets_) {
for (auto &socket : group_sockets_) { for (auto& socket : group_sockets_) {
socket.second->Close(); socket.second->Close();
} }
} }
...@@ -168,9 +164,8 @@ void SendCore(Message msg, TCPSocket* socket) { ...@@ -168,9 +164,8 @@ void SendCore(Message msg, TCPSocket* socket) {
int64_t sent_bytes = 0; int64_t sent_bytes = 0;
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) { while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes; int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket->Send( int64_t tmp =
reinterpret_cast<char*>(&msg.size) + sent_bytes, socket->Send(reinterpret_cast<char*>(&msg.size) + sent_bytes, max_len);
max_len);
CHECK_NE(tmp, -1); CHECK_NE(tmp, -1);
sent_bytes += tmp; sent_bytes += tmp;
} }
...@@ -178,7 +173,7 @@ void SendCore(Message msg, TCPSocket* socket) { ...@@ -178,7 +173,7 @@ void SendCore(Message msg, TCPSocket* socket) {
sent_bytes = 0; sent_bytes = 0;
while (sent_bytes < msg.size) { while (sent_bytes < msg.size) {
int64_t max_len = msg.size - sent_bytes; int64_t max_len = msg.size - sent_bytes;
int64_t tmp = socket->Send(msg.data+sent_bytes, max_len); int64_t tmp = socket->Send(msg.data + sent_bytes, max_len);
CHECK_NE(tmp, -1); CHECK_NE(tmp, -1);
sent_bytes += tmp; sent_bytes += tmp;
} }
...@@ -189,8 +184,8 @@ void SendCore(Message msg, TCPSocket* socket) { ...@@ -189,8 +184,8 @@ void SendCore(Message msg, TCPSocket* socket) {
} }
void SocketSender::SendLoop( void SocketSender::SendLoop(
std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets, std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue) { std::shared_ptr<MessageQueue> queue) {
for (;;) { for (;;) {
Message msg; Message msg;
STATUS code = queue->Remove(&msg); STATUS code = queue->Remove(&msg);
...@@ -205,8 +200,10 @@ void SocketSender::SendLoop( ...@@ -205,8 +200,10 @@ void SocketSender::SendLoop(
} }
} }
/////////////////////////////////////// SocketReceiver /////////////////////////////////////////// /////////////////////////////////////// SocketReceiver
bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) { //////////////////////////////////////////////
bool SocketReceiver::Wait(
const std::string& addr, int num_sender, bool blocking) {
CHECK_GT(num_sender, 0); CHECK_GT(num_sender, 0);
CHECK_EQ(blocking, true); CHECK_EQ(blocking, true);
std::vector<std::string> substring; std::vector<std::string> substring;
...@@ -231,7 +228,7 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking ...@@ -231,7 +228,7 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
num_sender_ = num_sender; num_sender_ = num_sender;
#ifdef USE_EPOLL #ifdef USE_EPOLL
if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) { if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
max_thread_count_ = num_sender_; max_thread_count_ = num_sender_;
} }
#else #else
max_thread_count_ = num_sender_; max_thread_count_ = num_sender_;
...@@ -256,7 +253,8 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking ...@@ -256,7 +253,8 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
auto socket = std::make_shared<TCPSocket>(); auto socket = std::make_shared<TCPSocket>();
sockets_[thread_id][i] = socket; sockets_[thread_id][i] = socket;
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_); msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) { if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) ==
false) {
LOG(WARNING) << "Error on accept socket."; LOG(WARNING) << "Error on accept socket.";
return false; return false;
} }
...@@ -266,10 +264,7 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking ...@@ -266,10 +264,7 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) { for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
// create new thread for each socket // create new thread for each socket
threads_.push_back(std::make_shared<std::thread>( threads_.push_back(std::make_shared<std::thread>(
RecvLoop, RecvLoop, sockets_[thread_id], msg_queue_, &queue_sem_));
sockets_[thread_id],
msg_queue_,
&queue_sem_));
} }
return true; return true;
...@@ -285,7 +280,7 @@ rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) { ...@@ -285,7 +280,7 @@ rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
return rpc::kRPCTimeOut; return rpc::kRPCTimeOut;
} }
CHECK_EQ(status, REMOVE_SUCCESS); CHECK_EQ(status, REMOVE_SUCCESS);
char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t); char* count_ptr = rpc_meta_msg.data + rpc_meta_msg.size - sizeof(int32_t);
int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr)); int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
// Recv real ndarray data // Recv real ndarray data
std::vector<void*> buffer_list(nonempty_ndarray_count); std::vector<void*> buffer_list(nonempty_ndarray_count);
...@@ -305,7 +300,8 @@ rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) { ...@@ -305,7 +300,8 @@ rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
CHECK_EQ(status, REMOVE_SUCCESS); CHECK_EQ(status, REMOVE_SUCCESS);
buffer_list[i] = ndarray_data_msg.data; buffer_list[i] = ndarray_data_msg.data;
} }
StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list); StreamWithBuffer zc_read_strm(
rpc_meta_msg.data, rpc_meta_msg.size - sizeof(int32_t), buffer_list);
zc_read_strm.Read(msg); zc_read_strm.Read(msg);
rpc_meta_msg.deallocator(&rpc_meta_msg); rpc_meta_msg.deallocator(&rpc_meta_msg);
return rpc::kRPCSuccess; return rpc::kRPCSuccess;
...@@ -352,7 +348,7 @@ void SocketReceiver::Finalize() { ...@@ -352,7 +348,7 @@ void SocketReceiver::Finalize() {
for (auto& mq : msg_queue_) { for (auto& mq : msg_queue_) {
// wait until queue is empty // wait until queue is empty
while (mq.second->Empty() == false) { while (mq.second->Empty() == false) {
std::this_thread::sleep_for(std::chrono::seconds(1)); std::this_thread::sleep_for(std::chrono::seconds(1));
} }
mq.second->SignalFinished(mq.first); mq.second->SignalFinished(mq.first);
} }
...@@ -376,8 +372,7 @@ int64_t RecvDataSize(TCPSocket* socket) { ...@@ -376,8 +372,7 @@ int64_t RecvDataSize(TCPSocket* socket) {
while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) { while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - received_bytes; int64_t max_len = sizeof(int64_t) - received_bytes;
int64_t tmp = socket->Receive( int64_t tmp = socket->Receive(
reinterpret_cast<char*>(&data_size) + received_bytes, reinterpret_cast<char*>(&data_size) + received_bytes, max_len);
max_len);
if (tmp == -1) { if (tmp == -1) {
if (received_bytes > 0) { if (received_bytes > 0) {
// We want to finish reading full data_size // We want to finish reading full data_size
...@@ -390,8 +385,9 @@ int64_t RecvDataSize(TCPSocket* socket) { ...@@ -390,8 +385,9 @@ int64_t RecvDataSize(TCPSocket* socket) {
return data_size; return data_size;
} }
void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size, void RecvData(
int64_t *received_bytes) { TCPSocket* socket, char* buffer, const int64_t& data_size,
int64_t* received_bytes) {
while (*received_bytes < data_size) { while (*received_bytes < data_size) {
int64_t max_len = data_size - *received_bytes; int64_t max_len = data_size - *received_bytes;
int64_t tmp = socket->Receive(buffer + *received_bytes, max_len); int64_t tmp = socket->Receive(buffer + *received_bytes, max_len);
...@@ -404,15 +400,17 @@ void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size, ...@@ -404,15 +400,17 @@ void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
} }
void SocketReceiver::RecvLoop( void SocketReceiver::RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */, std::unordered_map<
std::shared_ptr<TCPSocket>> sockets, int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
std::unordered_map<int /* Sender (virtual) ID */, sockets,
std::shared_ptr<MessageQueue>> queues, std::unordered_map<
runtime::Semaphore *queue_sem) { int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
queues,
runtime::Semaphore* queue_sem) {
std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts; std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
SocketPool socket_pool; SocketPool socket_pool;
for (auto& socket : sockets) { for (auto& socket : sockets) {
auto &sender_id = socket.first; auto& sender_id = socket.first;
socket_pool.AddSocket(socket.second, sender_id); socket_pool.AddSocket(socket.second, sender_id);
recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext()); recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
} }
...@@ -432,9 +430,9 @@ void SocketReceiver::RecvLoop( ...@@ -432,9 +430,9 @@ void SocketReceiver::RecvLoop(
// Nonblocking socket might be interrupted at any point. So we need to // Nonblocking socket might be interrupted at any point. So we need to
// store the partially received data // store the partially received data
std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id]; std::unique_ptr<RecvContext>& ctx = recv_contexts[sender_id];
int64_t &data_size = ctx->data_size; int64_t& data_size = ctx->data_size;
int64_t &received_bytes = ctx->received_bytes; int64_t& received_bytes = ctx->received_bytes;
char*& buffer = ctx->buffer; char*& buffer = ctx->buffer;
if (data_size == -1) { if (data_size == -1) {
...@@ -443,7 +441,7 @@ void SocketReceiver::RecvLoop( ...@@ -443,7 +441,7 @@ void SocketReceiver::RecvLoop(
if (data_size > 0) { if (data_size > 0) {
try { try {
buffer = new char[data_size]; buffer = new char[data_size];
} catch(const std::bad_alloc&) { } catch (const std::bad_alloc&) {
LOG(FATAL) << "Cannot allocate enough memory for message, " LOG(FATAL) << "Cannot allocate enough memory for message, "
<< "(message size: " << data_size << ")"; << "(message size: " << data_size << ")";
} }
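The RecvContext bookkeeping exists because a nonblocking socket can stop mid-message. A stripped-down sketch of the resume step, using a hypothetical helper `TryFinishRecv` built from the RecvDataSize/RecvData routines above (epoll wiring and error paths omitted):

// Returns true once the whole length-prefixed payload has arrived.
bool TryFinishRecv(dgl::network::TCPSocket* socket, RecvContext* ctx) {
  if (ctx->data_size == -1) {
    ctx->data_size = RecvDataSize(socket);  // may still be -1: not ready yet
    if (ctx->data_size == -1) return false;
    if (ctx->data_size > 0) ctx->buffer = new char[ctx->data_size];
  }
  RecvData(socket, ctx->buffer, ctx->data_size, &ctx->received_bytes);
  return ctx->received_bytes == ctx->data_size;
}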
......
...@@ -6,22 +6,23 @@ ...@@ -6,22 +6,23 @@
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_ #ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_ #define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#include <thread> #include <memory>
#include <vector>
#include <string> #include <string>
#include <thread>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <vector>
#include "../../runtime/semaphore_wrapper.h" #include "../../runtime/semaphore_wrapper.h"
#include "common.h"
#include "communicator.h" #include "communicator.h"
#include "msg_queue.h" #include "msg_queue.h"
#include "tcp_socket.h" #include "tcp_socket.h"
#include "common.h"
namespace dgl { namespace dgl {
namespace network { namespace network {
static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout static constexpr int kTimeOut =
10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024 static constexpr int kMaxConnection = 1024; // maximal connection: 1024
/*! /*!
...@@ -41,19 +42,20 @@ class SocketSender : public Sender { ...@@ -41,19 +42,20 @@ class SocketSender : public Sender {
public: public:
/*! /*!
* \brief Sender constructor * \brief Sender constructor
* \param queue_size size of message queue * \param queue_size size of message queue
* \param max_thread_count size of thread pool. 0 for no limit * \param max_thread_count size of thread pool. 0 for no limit
*/ */
SocketSender(int64_t queue_size, int max_thread_count) SocketSender(int64_t queue_size, int max_thread_count)
: Sender(queue_size, max_thread_count) {} : Sender(queue_size, max_thread_count) {}
/*! /*!
* \brief Connect to a receiver. * \brief Connect to a receiver.
* *
* When there are multiple receivers to be connected, the application calls * When there are multiple receivers to be connected, the application calls
* `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize` * `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize`
* to check whether all the connections were successfully established or some * to check whether all the connections were successfully established or some
* of them failed. * of them failed.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091' * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID * \param recv_id receiver's ID
* \return True for success and False for fail * \return True for success and False for fail
...@@ -73,7 +75,7 @@ class SocketSender : public Sender { ...@@ -73,7 +75,7 @@ class SocketSender : public Sender {
/*! /*!
* \brief Send RPCMessage to specified Receiver. * \brief Send RPCMessage to specified Receiver.
* \param msg data message * \param msg data message
* \param recv_id receiver's ID * \param recv_id receiver's ID
*/ */
void Send(const rpc::RPCMessage& msg, int recv_id) override; void Send(const rpc::RPCMessage& msg, int recv_id) override;
...@@ -86,60 +88,63 @@ class SocketSender : public Sender { ...@@ -86,60 +88,63 @@ class SocketSender : public Sender {
/*! /*!
* \brief Communicator type: 'socket' * \brief Communicator type: 'socket'
*/ */
const std::string &NetType() const override { const std::string& NetType() const override {
static const std::string net_type = "socket"; static const std::string net_type = "socket";
return net_type; return net_type;
} }
/*! /*!
* \brief Send data to specified Receiver. Actually pushing message to message queue. * \brief Send data to specified Receiver. Actually pushing message to message
* \param msg data message * queue.
* \param recv_id receiver's ID * \param msg data message.
* \return Status code * \param recv_id receiver's ID.
* \return Status code.
* *
* (1) The send is non-blocking. There is no guarantee that the message has been * (1) The send is non-blocking. There is no guarantee that the message has
* physically sent out when the function returns. * been physically sent out when the function returns. (2) The communicator
* (2) The communicator will assume the responsibility of the given message. * will assume the responsibility of the given message. (3) The API is
* (3) The API is multi-thread safe. * multi-thread safe. (4) Messages sent to the same receiver are guaranteed to
* (4) Messages sent to the same receiver are guaranteed to be received in the same order. * be received in the same order. There is no guarantee for messages sent to
* There is no guarantee for messages sent to different receivers. * different receivers.
*/ */
STATUS Send(Message msg, int recv_id) override; STATUS Send(Message msg, int recv_id) override;
private: private:
/*! /*!
* \brief socket for each connection of receiver * \brief socket for each connection of receiver
*/ */
std::vector<std::unordered_map<int /* receiver ID */, std::vector<
std::shared_ptr<TCPSocket>>> sockets_; std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>
sockets_;
/*! /*!
* \brief receivers' address * \brief receivers' address
*/ */
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_; std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*! /*!
* \brief message queue for each thread * \brief message queue for each thread
*/ */
std::vector<std::shared_ptr<MessageQueue>> msg_queue_; std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
/*! /*!
* \brief Independent thread * \brief Independent thread
*/ */
std::vector<std::shared_ptr<std::thread>> threads_; std::vector<std::shared_ptr<std::thread>> threads_;
/*! /*!
* \brief Send-loop for each thread * \brief Send-loop for each thread
* \param sockets TCPSockets for current thread * \param sockets TCPSockets for current thread
* \param queue message_queue for current thread * \param queue message_queue for current thread
* *
* Note that, the SendLoop will finish its loop-job and exit thread * Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue. * when the main thread invokes Signal() API on the message queue.
*/ */
static void SendLoop( static void SendLoop(
std::unordered_map<int /* Receiver (virtual) ID */, std::unordered_map<
std::shared_ptr<TCPSocket>> sockets, int /* Receiver (virtual) ID */, std::shared_ptr<TCPSocket>>
std::shared_ptr<MessageQueue> queue); sockets,
std::shared_ptr<MessageQueue> queue);
}; };
/*! /*!
...@@ -155,7 +160,7 @@ class SocketReceiver : public Receiver { ...@@ -155,7 +160,7 @@ class SocketReceiver : public Receiver {
* \param max_thread_count size of thread pool. 0 for no limit * \param max_thread_count size of thread pool. 0 for no limit
*/ */
SocketReceiver(int64_t queue_size, int max_thread_count) SocketReceiver(int64_t queue_size, int max_thread_count)
: Receiver(queue_size, max_thread_count) {} : Receiver(queue_size, max_thread_count) {}
/*! /*!
* \brief Wait for all the Senders to connect * \brief Wait for all the Senders to connect
...@@ -166,13 +171,14 @@ class SocketReceiver : public Receiver { ...@@ -166,13 +171,14 @@ class SocketReceiver : public Receiver {
* *
* Wait() is not thread-safe and only one thread can invoke this API. * Wait() is not thread-safe and only one thread can invoke this API.
*/ */
bool Wait(const std::string &addr, int num_sender, bool Wait(
bool blocking = true) override; const std::string& addr, int num_sender, bool blocking = true) override;
/*! /*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue. * \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage * \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut. * \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/ */
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override; rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
...@@ -181,23 +187,28 @@ class SocketReceiver : public Receiver { ...@@ -181,23 +187,28 @@ class SocketReceiver : public Receiver {
* \brief Recv data from Sender. Actually removing data from msg_queue. * \brief Recv data from Sender. Actually removing data from msg_queue.
* \param msg pointer of data message * \param msg pointer of data message
* \param send_id which sender current msg comes from * \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code * \return Status code
* *
* (1) The Recv() API is thread-safe. * (1) The Recv() API is thread-safe.
* (2) The memory is allocated by the communicator, but the communicator will * (2) The memory is allocated by the communicator, but the communicator will
* not own it after the function returns. * not own it after the function returns.
*/ */
STATUS Recv(Message* msg, int* send_id, int timeout = 0) override; STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
/*! /*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue. * \brief Recv data from a specified Sender. Actually removing data from
* \param msg pointer of data message * msg_queue.
* \param msg pointer of data message.
* \param send_id sender's ID * \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely. * \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code * \return Status code
* *
* (1) The RecvFrom() API is thread-safe. * (1) The RecvFrom() API is thread-safe.
* (2) The memory is allocated by the communicator, but the communicator will * (2) The memory is allocated by the communicator, but the communicator will
* not own it after the function returns. * not own it after the function returns.
*/ */
STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override; STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
...@@ -211,7 +222,7 @@ class SocketReceiver : public Receiver { ...@@ -211,7 +222,7 @@ class SocketReceiver : public Receiver {
/*! /*!
* \brief Communicator type: 'socket' * \brief Communicator type: 'socket'
*/ */
const std::string &NetType() const override { const std::string& NetType() const override {
static const std::string net_type = "socket"; static const std::string net_type = "socket";
return net_type; return net_type;
} }
...@@ -220,7 +231,7 @@ class SocketReceiver : public Receiver { ...@@ -220,7 +231,7 @@ class SocketReceiver : public Receiver {
struct RecvContext { struct RecvContext {
int64_t data_size = -1; int64_t data_size = -1;
int64_t received_bytes = 0; int64_t received_bytes = 0;
char *buffer = nullptr; char* buffer = nullptr;
}; };
/*! /*!
* \brief number of sender * \brief number of sender
...@@ -229,25 +240,27 @@ class SocketReceiver : public Receiver { ...@@ -229,25 +240,27 @@ class SocketReceiver : public Receiver {
/*! /*!
* \brief server socket for listening connections * \brief server socket for listening connections
*/ */
TCPSocket* server_socket_; TCPSocket* server_socket_;
/*! /*!
* \brief socket for each client connections * \brief socket for each client connections
*/ */
std::vector<std::unordered_map<int /* Sender (virtual) ID */, std::vector<std::unordered_map<
std::shared_ptr<TCPSocket>>> sockets_; int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>>
sockets_;
/*! /*!
* \brief Message queue for each socket connection * \brief Message queue for each socket connection
*/ */
std::unordered_map<int /* Sender (virtual) ID */, std::unordered_map<
std::shared_ptr<MessageQueue>> msg_queue_; int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
msg_queue_;
std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_; std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
/*! /*!
* \brief Independent thread * \brief Independent thread
*/ */
std::vector<std::shared_ptr<std::thread>> threads_; std::vector<std::shared_ptr<std::thread>> threads_;
/*! /*!
...@@ -263,13 +276,15 @@ class SocketReceiver : public Receiver { ...@@ -263,13 +276,15 @@ class SocketReceiver : public Receiver {
* *
* Note that, the RecvLoop will finish its loop-job and exit thread * Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue. * when the main thread invokes Signal() API on the message queue.
*/ */
static void RecvLoop( static void RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */, std::unordered_map<
std::shared_ptr<TCPSocket>> sockets, int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
std::unordered_map<int /* Sender (virtual) ID */, sockets,
std::shared_ptr<MessageQueue>> queues, std::unordered_map<
runtime::Semaphore *queue_sem); int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
queues,
runtime::Semaphore* queue_sem);
}; };
} // namespace network } // namespace network
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "socket_pool.h" #include "socket_pool.h"
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include "tcp_socket.h" #include "tcp_socket.h"
#ifdef USE_EPOLL #ifdef USE_EPOLL
...@@ -24,8 +25,8 @@ SocketPool::SocketPool() { ...@@ -24,8 +25,8 @@ SocketPool::SocketPool() {
#endif #endif
} }
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id, void SocketPool::AddSocket(
int events) { std::shared_ptr<TCPSocket> socket, int socket_id, int events) {
int fd = socket->Socket(); int fd = socket->Socket();
tcp_sockets_[fd] = socket; tcp_sockets_[fd] = socket;
socket_ids_[fd] = socket_id; socket_ids_[fd] = socket_id;
...@@ -47,7 +48,7 @@ void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id, ...@@ -47,7 +48,7 @@ void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
#else #else
if (tcp_sockets_.size() > 1) { if (tcp_sockets_.size() > 1) {
LOG(FATAL) << "SocketPool supports only one socket if not use epoll." LOG(FATAL) << "SocketPool supports only one socket if not use epoll."
"Please turn on USE_EPOLL on building"; "Please turn on USE_EPOLL on building";
} }
#endif #endif
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_ #ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_ #define DGL_RPC_NETWORK_SOCKET_POOL_H_
#include <unordered_map>
#include <queue>
#include <memory> #include <memory>
#include <queue>
#include <unordered_map>
namespace dgl { namespace dgl {
namespace network { namespace network {
...@@ -19,7 +19,7 @@ class TCPSocket; ...@@ -19,7 +19,7 @@ class TCPSocket;
* \brief SocketPool maintains a group of nonblocking sockets, and can provide * \brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets. * active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification * Currently SocketPool is based on epoll, a scalable I/O event notification
* mechanism in the Linux operating system. * mechanism in the Linux operating system.
*/ */
class SocketPool { class SocketPool {
public: public:
...@@ -42,8 +42,8 @@ class SocketPool { ...@@ -42,8 +42,8 @@ class SocketPool {
* \param socket_id receiver/sender id of the socket * \param socket_id receiver/sender id of the socket
* \param events READ, WRITE or READ + WRITE * \param events READ, WRITE or READ + WRITE
*/ */
void AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id, void AddSocket(
int events = READ); std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);
/*! /*!
* \brief Remove socket from SocketPool * \brief Remove socket from SocketPool
......
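As background for the SocketPool class above: its readiness notification is epoll. A minimal level-triggered epoll loop in plain POSIX, illustrating the mechanism rather than DGL's implementation (`sock_fd` is an assumed, already-connected descriptor):

#include <sys/epoll.h>
#include <unistd.h>

int epfd = epoll_create1(0);
struct epoll_event ev = {};
ev.events = EPOLLIN;  // "READ" in SocketPool's terms
ev.data.fd = sock_fd;
epoll_ctl(epfd, EPOLL_CTL_ADD, sock_fd, &ev);

struct epoll_event ready[16];
int n = epoll_wait(epfd, ready, 16, /*timeout_ms=*/-1);
for (int i = 0; i < n; ++i) {
  // ready[i].data.fd is readable; hand it to the receive path.
}
close(epfd);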
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include <sys/socket.h> #include <sys/socket.h>
#include <unistd.h> #include <unistd.h>
#endif // !_WIN32 #endif // !_WIN32
#include <string.h>
#include <errno.h> #include <errno.h>
#include <string.h>
namespace dgl { namespace dgl {
namespace network { namespace network {
@@ -31,28 +31,29 @@ TCPSocket::TCPSocket() {
     LOG(FATAL) << "Can't create new socket. Error: " << strerror(errno);
   }
 #ifndef _WIN32
-  // This is to make sure the same port can be reused right after the socket is closed.
+  // This is to make sure the same port can be reused right after the socket is
+  // closed.
   int enable = 1;
   if (setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0) {
-    LOG(WARNING) << "cannot make the socket reusable. Error: " << strerror(errno);
+    LOG(WARNING) << "cannot make the socket reusable. Error: "
+                 << strerror(errno);
   }
 #endif  // _WIN32
 }
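Why SO_REUSEADDR here: after a server socket closes, the kernel keeps the port in TIME_WAIT for a while, and a restarted process would otherwise fail its bind() in the meantime. The idiom in isolation, as a standalone sketch rather than the commit's code:

// Standalone SO_REUSEADDR sketch: lets bind() succeed on a port whose
// previous socket is still in TIME_WAIT after a recent close().
int fd = socket(AF_INET, SOCK_STREAM, 0);
int enable = 1;
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < 0) {
  perror("setsockopt(SO_REUSEADDR)");
}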
-TCPSocket::~TCPSocket() {
-  Close();
-}
+TCPSocket::~TCPSocket() { Close(); }
-bool TCPSocket::Connect(const char * ip, int port) {
+bool TCPSocket::Connect(const char *ip, int port) {
   SAI sa_server;
   sa_server.sin_family = AF_INET;
   sa_server.sin_port = htons(port);
   int retval = 0;
   do {  // retry if EINTR failure appears
     if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
-        0 <= (retval = connect(socket_, reinterpret_cast<SA*>(&sa_server),
-                               sizeof(sa_server)))) {
+        0 <= (retval = connect(
+                  socket_, reinterpret_cast<SA *>(&sa_server),
+                  sizeof(sa_server)))) {
       return true;
     }
   } while (retval == -1 && errno == EINTR);
@@ -60,10 +61,10 @@ bool TCPSocket::Connect(const char *ip, int port) {
   return false;
 }
-bool TCPSocket::Bind(const char * ip, int port) {
+bool TCPSocket::Bind(const char *ip, int port) {
   SAI sa_server;
   sa_server.sin_family = AF_INET;
   sa_server.sin_port = htons(port);
   int ret = 0;
   ret = inet_pton(AF_INET, ip, &sa_server.sin_addr);
   if (ret == 0) {
@@ -75,13 +76,15 @@ bool TCPSocket::Bind(const char *ip, int port) {
     return false;
   }
   do {  // retry if EINTR failure appears
-    if (0 <= (ret = bind(socket_, reinterpret_cast<SA *>(&sa_server),
-                         sizeof(sa_server)))) {
+    if (0 <=
+        (ret = bind(
+             socket_, reinterpret_cast<SA *>(&sa_server), sizeof(sa_server)))) {
       return true;
     }
   } while (ret == -1 && errno == EINTR);
-  LOG(ERROR) << "Failed bind on " << ip << ":" << port << " , error: " << strerror(errno);
+  LOG(ERROR) << "Failed bind on " << ip << ":" << port
+             << " , error: " << strerror(errno);
   return false;
 }
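Bind treats inet_pton's two failure modes separately: a return of 0 means the string is not a valid IPv4 dotted-quad (caught by the ret == 0 branch above), while -1 signals a system error such as an unsupported address family. Illustrated in isolation:

// inet_pton (from <arpa/inet.h>) returns 1 on success, 0 for a
// malformed address string, and -1 (errno set) for errors such as a
// bad address family.
struct in_addr addr;
int rc = inet_pton(AF_INET, "256.1.1.1", &addr);
// rc == 0 here: 256 is out of range, so the string is rejected.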
@@ -93,17 +96,18 @@ bool TCPSocket::Listen(int max_connection) {
     }
   } while (retval == -1 && errno == EINTR);
-  LOG(ERROR) << "Failed listen on socket fd: " << socket_ << " , error: " << strerror(errno);
+  LOG(ERROR) << "Failed listen on socket fd: " << socket_
+             << " , error: " << strerror(errno);
   return false;
 }
-bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
+bool TCPSocket::Accept(TCPSocket *socket, std::string *ip, int *port) {
   int sock_client;
   SAI sa_client;
   socklen_t len = sizeof(sa_client);
   do {  // retry if EINTR failure appears
-    sock_client = accept(socket_, reinterpret_cast<SA*>(&sa_client), &len);
+    sock_client = accept(socket_, reinterpret_cast<SA *>(&sa_client), &len);
   } while (sock_client == -1 && errno == EINTR);
   if (sock_client < 0) {
@@ -114,10 +118,8 @@ bool TCPSocket::Accept(TCPSocket *socket, std::string *ip, int *port) {
   }
   char tmp[INET_ADDRSTRLEN];
-  const char * ip_client = inet_ntop(AF_INET,
-                                     &sa_client.sin_addr,
-                                     tmp,
-                                     sizeof(tmp));
+  const char *ip_client =
+      inet_ntop(AF_INET, &sa_client.sin_addr, tmp, sizeof(tmp));
   CHECK(ip_client != nullptr);
   ip->assign(ip_client);
   *port = ntohs(sa_client.sin_port);
@@ -166,22 +168,20 @@ bool TCPSocket::SetNonBlocking(bool flag) {
 #endif  // _WIN32
 void TCPSocket::SetTimeout(int timeout) {
 #ifdef _WIN32
   timeout = timeout * 1000;  // WIN API accepts millsec
-  setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
-             reinterpret_cast<char*>(&timeout), sizeof(timeout));
-#else  // !_WIN32
+  setsockopt(
+      socket_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&timeout),
+      sizeof(timeout));
+#else   // !_WIN32
   struct timeval tv;
   tv.tv_sec = timeout;
   tv.tv_usec = 0;
-  setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
-             &tv, sizeof(tv));
+  setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
 #endif  // _WIN32
 }
-bool TCPSocket::ShutDown(int ways) {
-  return 0 == shutdown(socket_, ways);
-}
+bool TCPSocket::ShutDown(int ways) { return 0 == shutdown(socket_, ways); }
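SetTimeout hides a real platform difference: SO_RCVTIMEO takes a millisecond count on Windows but a struct timeval on POSIX, so callers always pass seconds and the conversion happens inside. The caller's view, as an illustrative snippet:

// Caller-side view of SetTimeout: the argument is seconds on both
// platforms; the Windows millisecond conversion shown above is internal.
TCPSocket sock;
sock.SetTimeout(5);  // subsequent Receive() calls time out after ~5s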
 void TCPSocket::Close() {
   if (socket_ >= 0) {
@@ -194,7 +194,7 @@ void TCPSocket::Close() {
   }
 }
-int64_t TCPSocket::Send(const char * data, int64_t len_data) {
+int64_t TCPSocket::Send(const char *data, int64_t len_data) {
   int64_t number_send;
   do {  // retry if EINTR failure appears
@@ -207,7 +207,7 @@ int64_t TCPSocket::Send(const char *data, int64_t len_data) {
   return number_send;
 }
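Send retries only on EINTR and returns whatever a single send() transmitted, which may be fewer bytes than requested. A caller needing to push a whole buffer would loop over it, e.g. the sketch below; SendAll is a hypothetical helper built on the Send above, not part of this file:

// Hypothetical helper: drive TCPSocket::Send until the whole buffer is
// out, since one send() may accept fewer than `len` bytes.
bool SendAll(TCPSocket *sock, const char *data, int64_t len) {
  int64_t total = 0;
  while (total < len) {
    int64_t n = sock->Send(data + total, len - total);
    if (n <= 0) return false;  // error or connection closed
    total += n;
  }
  return true;
}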
-int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
+int64_t TCPSocket::Receive(char *buffer, int64_t size_buffer) {
   int64_t number_recv;
   do {  // retry if EINTR failure appears
@@ -220,9 +220,7 @@ int64_t TCPSocket::Receive(char *buffer, int64_t size_buffer) {
   return number_recv;
 }
-int TCPSocket::Socket() const {
-  return socket_;
-}
+int TCPSocket::Socket() const { return socket_; }
 }  // namespace network
 }  // namespace dgl