Unverified Commit 8f0df39e authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4810)



* [Misc] clang-format auto fix.

* manual

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 401e1278
......@@ -4,15 +4,17 @@
* \brief Convert multigraphs to simple graphs
*/
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/transform.h>
#include <utility>
#include <vector>
#include "../../c_api_common.h"
#include "../heterograph.h"
#include "../unit_graph.h"
namespace dgl {
......@@ -25,7 +27,8 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToSimpleGraph(const HeteroGraphPtr graph) {
const int64_t num_etypes = graph->NumEdgeTypes();
const auto metagraph = graph->meta_graph();
const auto &ugs =
std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);
std::vector<HeteroGraphPtr> rel_graphs(num_etypes);
......@@ -35,31 +38,31 @@ ToSimpleGraph(const HeteroGraphPtr graph) {
std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;
}
const HeteroGraphPtr result =
CreateHeteroGraph(metagraph, rel_graphs, graph->NumVerticesPerType());
return std::make_tuple(result, counts, edge_maps);
}
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero")
.set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0];
const auto result = ToSimpleGraph(graph_ref.sptr());
List<Value> counts, edge_maps;
for (const IdArray &count : std::get<1>(result))
counts.push_back(Value(MakeValue(count)));
for (const IdArray &edge_map : std::get<2>(result))
edge_maps.push_back(Value(MakeValue(edge_map)));
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(std::get<0>(result)));
ret.push_back(counts);
ret.push_back(edge_maps);
*rv = ret;
});
}; // namespace transform
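For context, here is a minimal standalone sketch of the per-relation step that ToSimpleGraph performs: collapsing duplicate edges while recording a count per kept edge and a map from original edge ids to new ones. It works on plain std::vector edge lists rather than DGL's IdArray/UnitGraph types, and the SimpleCoo/DedupCoo names are hypothetical, not part of this commit.

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

struct SimpleCoo {
  std::vector<int64_t> src, dst;   // deduplicated edge list
  std::vector<int64_t> count;      // multiplicity of each kept edge
  std::vector<int64_t> edge_map;   // original edge id -> new edge id
};

SimpleCoo DedupCoo(
    const std::vector<int64_t>& src, const std::vector<int64_t>& dst) {
  SimpleCoo out;
  std::map<std::pair<int64_t, int64_t>, int64_t> seen;  // (u, v) -> new id
  out.edge_map.reserve(src.size());
  for (size_t e = 0; e < src.size(); ++e) {
    const auto key = std::make_pair(src[e], dst[e]);
    const auto it = seen.find(key);
    if (it == seen.end()) {
      const int64_t new_id = static_cast<int64_t>(out.src.size());
      seen.emplace(key, new_id);
      out.src.push_back(src[e]);
      out.dst.push_back(dst[e]);
      out.count.push_back(1);
      out.edge_map.push_back(new_id);
    } else {
      ++out.count[it->second];
      out.edge_map.push_back(it->second);
    }
  }
  return out;
}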
......
......@@ -9,8 +9,9 @@ using namespace dgl::runtime;
namespace dgl {
HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0)
<< "Input graph list has at least two graphs";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
......@@ -24,18 +25,21 @@ HeteroGraphPtr JointUnionHeteroGraph(
HeteroGraphPtr rgptr = nullptr;
// ALL = CSC | CSR | COO
const dgl_format_code_t code =
component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
// get common format
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype))
<< "Input graph[" << i
<< "] should have same number of src vertices as input graph[0]";
CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype))
<< "Input graph[" << i
<< "] should have same number of dst vertices as input graph[0]";
const dgl_format_code_t curr_code =
cg->GetRelationGraph(etype)->GetAllowedFormats();
if (curr_code != code)
LOG(FATAL) << "All components should have the same formats";
}
......@@ -50,8 +54,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
}
aten::COOMatrix res = aten::UnionCoo(coos);
rgptr =
UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) {
std::vector<aten::CSRMatrix> csrs;
for (size_t i = 0; i < component_graphs.size(); ++i) {
......@@ -61,8 +65,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
}
aten::CSRMatrix res = aten::UnionCsr(csrs);
rgptr =
UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs;
......@@ -73,8 +77,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
}
aten::CSRMatrix res = aten::UnionCsr(cscs);
rgptr =
UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
}
rel_graphs[etype] = rgptr;
......@@ -82,7 +86,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
num_nodes_per_type[dst_vtype] = num_dst_v;
}
return CreateHeteroGraph(
meta_graph, rel_graphs, std::move(num_nodes_per_type));
}
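As a rough illustration of the per-relation work above (not part of this commit): a joint union of graphs that share one vertex space simply concatenates their COO edge lists, with no id remapping. EdgeList and UnionCooLists are hypothetical stand-ins for aten::COOMatrix and aten::UnionCoo.

#include <cstdint>
#include <vector>

struct EdgeList {
  std::vector<int64_t> src, dst;
};

EdgeList UnionCooLists(const std::vector<EdgeList>& components) {
  EdgeList merged;
  for (const EdgeList& c : components) {
    // Every component indexes the same vertex set, so ids are copied verbatim.
    merged.src.insert(merged.src.end(), c.src.begin(), c.src.end());
    merged.dst.insert(merged.dst.end(), c.dst.begin(), c.dst.end());
  }
  return merged;
}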
HeteroGraphPtr DisjointUnionHeteroGraph2(
......@@ -94,8 +99,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
// Loop over all ntypes
for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
uint64_t offset = 0;
for (const auto& cg : component_graphs) offset += cg->NumVertices(vtype);
num_nodes_per_type[vtype] = offset;
}
......@@ -106,11 +110,12 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
const dgl_type_t dst_vtype = pair.second;
HeteroGraphPtr rgptr = nullptr;
const dgl_format_code_t code =
component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
// do some preprocess
for (const auto& cg : component_graphs) {
const dgl_format_code_t cur_code =
cg->GetRelationGraph(etype)->GetAllowedFormats();
if (cur_code != code)
LOG(FATAL) << "All components should have the same formats";
}
......@@ -118,51 +123,55 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
// prefer COO
if (FORMAT_HAS_COO(code)) {
std::vector<aten::COOMatrix> coos;
for (const auto& cg : component_graphs) {
aten::COOMatrix coo = cg->GetCOOMatrix(etype);
coos.push_back(coo);
}
aten::COOMatrix res = aten::DisjointUnionCoo(coos);
rgptr =
UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) {
std::vector<aten::CSRMatrix> csrs;
for (const auto& cg : component_graphs) {
aten::CSRMatrix csr = cg->GetCSRMatrix(etype);
csrs.push_back(csr);
}
aten::CSRMatrix res = aten::DisjointUnionCsr(csrs);
rgptr =
UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs;
for (const auto& cg : component_graphs) {
aten::CSRMatrix csc = cg->GetCSCMatrix(etype);
cscs.push_back(csc);
}
aten::CSRMatrix res = aten::DisjointUnionCsr(cscs);
rgptr =
UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
}
rel_graphs[etype] = rgptr;
}
return CreateHeteroGraph(
meta_graph, rel_graphs, std::move(num_nodes_per_type));
}
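By contrast with the joint union, the disjoint union above must shift each component's vertex ids by the counts accumulated so far. A standalone sketch of that per-relation step (Coo and DisjointUnionCooLists are hypothetical names; aten::DisjointUnionCoo does this on real COO matrices):

#include <cstdint>
#include <vector>

struct Coo {
  int64_t num_src = 0, num_dst = 0;
  std::vector<int64_t> src, dst;
};

Coo DisjointUnionCooLists(const std::vector<Coo>& components) {
  Coo merged;
  for (const Coo& c : components) {
    // Shift ids by the number of src/dst vertices seen in earlier components.
    for (int64_t u : c.src) merged.src.push_back(u + merged.num_src);
    for (int64_t v : c.dst) merged.dst.push_back(v + merged.num_dst);
    merged.num_src += c.num_src;
    merged.num_dst += c.num_dst;
  }
  return merged;
}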
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes) {
// Sanity check for vertex sizes
CHECK_EQ(vertex_sizes->dtype.bits, 64)
<< "dtype of vertex_sizes should be int64";
CHECK_EQ(edge_sizes->dtype.bits, 64) << "dtype of edge_sizes should be int64";
const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
const uint64_t* vertex_sizes_data =
static_cast<uint64_t*>(vertex_sizes->data);
const uint64_t num_vertex_types = meta_graph->NumVertices();
const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
......@@ -175,10 +184,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
for (uint64_t g = 0; g < batch_size; ++g) {
// We've flattened the number of vertices in the batch for all types
vertex_cumsum[vtype].push_back(
vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
}
CHECK_EQ(
vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
<< "Sum of the given sizes must equal to the number of nodes for type "
<< vtype;
}
// Sanity check for edge sizes
......@@ -193,10 +204,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
for (uint64_t g = 0; g < batch_size; ++g) {
// We've flattened the number of edges in the batch for all types
edge_cumsum[etype].push_back(
edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
}
CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
<< "Sum of the given sizes must equal to the number of edges for type "
<< etype;
}
// Construct relation graphs for unbatched graphs
......@@ -211,14 +223,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
auto res = aten::DisjointPartitionCooBySizes(
coo, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
vertex_cumsum[dst_vtype]);
for (uint64_t g = 0; g < batch_size; ++g) {
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype) ? 1 : 2, res[g], code);
rel_graphs[g].push_back(rgptr);
}
}
......@@ -228,14 +238,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
auto res = aten::DisjointPartitionCsrBySizes(
csr, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],
vertex_cumsum[dst_vtype]);
for (uint64_t g = 0; g < batch_size; ++g) {
HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR(
(src_vtype == dst_vtype) ? 1 : 2, res[g], code);
rel_graphs[g].push_back(rgptr);
}
}
......@@ -246,14 +254,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
const dgl_type_t dst_vtype = pair.second;
// CSR and CSC have the same storage format, i.e. CSRMatrix
aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
auto res = aten::DisjointPartitionCsrBySizes(
csc, batch_size, edge_cumsum[etype], vertex_cumsum[dst_vtype],
vertex_cumsum[src_vtype]);
for (uint64_t g = 0; g < batch_size; ++g) {
HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC(
(src_vtype == dst_vtype) ? 1 : 2, res[g], code);
rel_graphs[g].push_back(rgptr);
}
}
......@@ -264,20 +270,26 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
for (uint64_t g = 0; g < batch_size; ++g) {
for (uint64_t i = 0; i < num_vertex_types; ++i)
num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
rst.push_back(
CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
}
return rst;
}
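The per-type cumulative sums built above are the only bookkeeping the unbatching needs. A standalone sketch of that step, with the hypothetical helper BuildCumsum standing in for the loops over vertex_sizes_data and edge_sizes_data:

#include <cstdint>
#include <vector>

std::vector<std::vector<uint64_t>> BuildCumsum(
    const std::vector<uint64_t>& sizes,  // flattened layout: [type][graph]
    uint64_t num_types, uint64_t batch_size) {
  std::vector<std::vector<uint64_t>> cumsum(num_types);
  for (uint64_t t = 0; t < num_types; ++t) {
    cumsum[t].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g)
      cumsum[t].push_back(cumsum[t][g] + sizes[t * batch_size + g]);
    // cumsum[t][batch_size] is the total count for type t, which the code
    // above checks against the batched graph.
  }
  return cumsum;
}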
HeteroGraphPtr SliceHeteroGraph(
GraphPtr meta_graph, HeteroGraphPtr batched_graph,
IdArray num_nodes_per_type, IdArray start_nid_per_type,
IdArray num_edges_per_type, IdArray start_eid_per_type) {
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
const uint64_t* start_nid_per_type_data =
static_cast<uint64_t*>(start_nid_per_type->data);
const uint64_t* num_nodes_per_type_data =
static_cast<uint64_t*>(num_nodes_per_type->data);
const uint64_t* start_eid_per_type_data =
static_cast<uint64_t*>(start_eid_per_type->data);
const uint64_t* num_edges_per_type_data =
static_cast<uint64_t*>(num_edges_per_type->data);
// Map vertex type to the corresponding node range
const uint64_t num_vertex_types = meta_graph->NumVertices();
......@@ -287,7 +299,7 @@ HeteroGraphPtr SliceHeteroGraph(
for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
vertex_range[vtype].push_back(start_nid_per_type_data[vtype]);
vertex_range[vtype].push_back(
start_nid_per_type_data[vtype] + num_nodes_per_type_data[vtype]);
}
// Loop over all canonical etypes
......@@ -296,50 +308,52 @@ HeteroGraphPtr SliceHeteroGraph(
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
HeteroGraphPtr rgptr = nullptr;
const dgl_format_code_t code =
batched_graph->GetRelationGraph(etype)->GetAllowedFormats();
// handle graph without edges
std::vector<uint64_t> edge_range;
edge_range.push_back(start_eid_per_type_data[etype]);
edge_range.push_back(
start_eid_per_type_data[etype] + num_edges_per_type_data[etype]);
// prefer COO
if (FORMAT_HAS_COO(code)) {
aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
aten::COOMatrix res = aten::COOSliceContiguousChunk(
coo, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
rgptr =
UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) {
aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
aten::CSRMatrix res = aten::CSRSliceContiguousChunk(
csr, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);
rgptr =
UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix
aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
aten::CSRMatrix res = aten::CSRSliceContiguousChunk(
csc, edge_range, vertex_range[dst_vtype], vertex_range[src_vtype]);
rgptr =
UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);
}
rel_graphs[etype] = rgptr;
}
return CreateHeteroGraph(
meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());
}
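SliceHeteroGraph relies on each component of a batched graph occupying a contiguous id range per type. A rough sketch of the per-relation slice on plain vectors (CooChunk and SliceCooChunk are hypothetical; aten::COOSliceContiguousChunk is the real operation on aten::COOMatrix):

#include <cstdint>
#include <vector>

struct CooChunk {
  std::vector<int64_t> src, dst;
};

CooChunk SliceCooChunk(
    const std::vector<int64_t>& src, const std::vector<int64_t>& dst,
    uint64_t start_eid, uint64_t end_eid,  // contiguous edge id range
    uint64_t src_start, uint64_t dst_start) {
  CooChunk out;
  for (uint64_t e = start_eid; e < end_eid; ++e) {
    // Copy the edge and shift ids back to component-local numbering.
    out.src.push_back(src[e] - static_cast<int64_t>(src_start));
    out.dst.push_back(dst[e] - static_cast<int64_t>(dst_start));
  }
  return out;
}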
template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes) {
// Sanity check for vertex sizes
const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
const uint64_t* vertex_sizes_data =
static_cast<uint64_t*>(vertex_sizes->data);
const uint64_t num_vertex_types = meta_graph->NumVertices();
const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
// Map vertex type to the corresponding node cum sum
......@@ -351,10 +365,12 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
for (uint64_t g = 0; g < batch_size; ++g) {
// We've flattened the number of vertices in the batch for all types
vertex_cumsum[vtype].push_back(
vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
}
CHECK_EQ(
vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
<< "Sum of the given sizes must equal to the number of nodes for type "
<< vtype;
}
// Sanity check for edge sizes
......@@ -369,10 +385,11 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
for (uint64_t g = 0; g < batch_size; ++g) {
// We've flattened the number of edges in the batch for all types
edge_cumsum[etype].push_back(
edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
}
CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
<< "Sum of the given sizes must equal to the number of edges for type "
<< etype;
}
// Construct relation graphs for unbatched graphs
......@@ -390,7 +407,8 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
for (uint64_t g = 0; g < batch_size; ++g) {
std::vector<IdType> result_src, result_dst;
// Loop over the chunk of edges for the specified graph and edge type
for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1];
++e) {
// TODO(mufei): Should use array operations to implement this.
result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);
result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
......@@ -410,15 +428,18 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
for (uint64_t g = 0; g < batch_size; ++g) {
for (uint64_t i = 0; i < num_vertex_types; ++i)
num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
rst.push_back(
CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
}
return rst;
}
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes);
template std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes);
} // namespace dgl
......@@ -3,10 +3,13 @@
* \file graph/traversal.cc
* \brief Graph traversal implementation
*/
#include "./traversal.h"
#include <dgl/packed_func_ext.h>
#include <algorithm>
#include <queue>
#include "../c_api_common.h"
using namespace dgl::runtime;
......@@ -15,46 +18,36 @@ namespace dgl {
namespace traverse {
namespace {
// A utility view class to wrap a vector into a queue.
template <typename DType>
struct VectorQueueWrapper {
std::vector<DType>* vec;
size_t head = 0;
explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}
void push(const DType& elem) { vec->push_back(elem); }
DType top() const { return vec->operator[](head); }
void pop() { ++head; }
bool empty() const { return head == vec->size(); }
size_t size() const { return vec->size() - head; }
};
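A short usage note on the wrapper above (DemoVectorQueue is a hypothetical illustration, not part of this commit): because pop() only advances a head index, the backing vector still holds the complete visit order once a traversal drains the queue, which is exactly what the frontier functions below copy into the result arrays.

inline void DemoVectorQueue() {
  std::vector<dgl_id_t> order;
  VectorQueueWrapper<dgl_id_t> q(&order);
  q.push(3);
  q.push(1);
  q.push(4);
  while (!q.empty()) q.pop();  // consumes the queue view only
  // `order` still contains {3, 1, 4} and can be copied into an IdArray.
}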
// Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together.
template <typename DType>
IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0, total_len = 0;
for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen);
total_len += traces[i].size();
}
IdArray ret = IdArray::Empty(
{total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) {
......@@ -70,15 +63,15 @@ IdArray MergeMultipleTraversals(
// Internal function to compute sections if multiple traversal traces
// are merged into one ndarray.
template <typename DType>
IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0;
for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen);
}
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray ret = IdArray::Empty(
{max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0;
......@@ -99,7 +92,8 @@ IdArray ComputeMergedSections(
* \brief Class for representing frontiers.
*
* Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int
* value).
*/
struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */
......@@ -112,142 +106,145 @@ struct Frontiers {
std::vector<int64_t> sections;
};
Frontiers BFSNodesFrontiers(
const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&](const dgl_id_t v) {};
auto make_frontier = [&]() {
if (!queue.empty()) {
// do not push zero-length frontier
front.sections.push_back(queue.size());
}
};
BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
return front;
}
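A small worked example of the ids/sections encoding produced above (toy graph, illustrative only):

// For the graph 0->1, 0->2, 1->3 with source {0}:
//   front.ids      == {0, 1, 2, 3}
//   front.sections == {1, 2, 1}
// i.e. frontier 0 is {0}, frontier 1 is {1, 2}, frontier 2 is {3};
// sections[i] is the number of consecutive entries of ids in frontier i.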
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
});
Frontiers BFSEdgesFrontiers(
const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front;
// NOTE: std::queue has no top() method.
std::vector<dgl_id_t> nodes;
VectorQueueWrapper<dgl_id_t> queue(&nodes);
auto visit = [&](const dgl_id_t e) { front.ids.push_back(e); };
bool first_frontier = true;
auto make_frontier = [&] {
if (first_frontier) {
first_frontier = false; // do not push the first section when doing edges
} else if (!queue.empty()) {
// do not push zero-length frontier
front.sections.push_back(queue.size());
}
};
BFSEdges(graph, source, reversed, &queue, visit, make_frontier);
return front;
}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
});
Frontiers TopologicalNodesFrontiers(
const GraphInterface& graph, bool reversed) {
Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&](const dgl_id_t v) {};
auto make_frontier = [&]() {
if (!queue.empty()) {
// do not push zero-length frontier
front.sections.push_back(queue.size());
}
};
TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
return front;
}
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len);
for (int64_t i = 0; i < len; ++i) {
auto visit = [&](dgl_id_t e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
}
IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges);
*rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
const bool has_reverse_edge = args[3];
const bool has_nontree_edge = args[4];
const bool return_labels = args[5];
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len);
std::vector<std::vector<int64_t>> tags;
if (return_labels) {
tags.resize(len);
}
for (int64_t i = 0; i < len; ++i) {
auto visit = [&](dgl_id_t e, int tag) {
edges[i].push_back(e);
if (return_labels) {
tags[i].push_back(tag);
}
};
DFSLabeledEdges(
*g.sptr(), src_data[i], reversed, has_reverse_edge,
has_nontree_edge, visit);
}
IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges);
if (return_labels) {
IdArray labels = MergeMultipleTraversals(tags);
*rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
} else {
*rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
}
});
} // namespace traverse
} // namespace dgl
......@@ -3,15 +3,16 @@
* \file graph/traversal.h
* \brief Graph traversal routines.
*
* Traversal routines generate frontiers. Frontiers can be node frontiers or
* edge frontiers depending on the traversal function. Each frontier is a list
* of nodes/edges (specified by their ids). An optional tag can be specified for
* each node/edge (represented by an int value).
*/
#ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_
#include <dgl/graph_interface.h>
#include <stack>
#include <tuple>
#include <vector>
......@@ -39,18 +40,16 @@ namespace traverse {
*
* \param graph The graph.
* \param sources Source nodes.
* \param reversed If true, BFS follows the in-edge direction.
* \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited.
* \param make_frontier The function to indicate that a new frontier can be
* made.
*/
template <typename Queue, typename VisitFn, typename FrontierFn>
void BFSNodes(
const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,
VisitFn visit, FrontierFn make_frontier) {
const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data);
......@@ -63,7 +62,8 @@ void BFSNodes(const GraphInterface& graph,
}
make_frontier();
const auto neighbor_iter =
reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
......@@ -102,19 +102,17 @@ void BFSNodes(const GraphInterface& graph,
*
* \param graph The graph.
* \param sources Source nodes.
* \param reversed If true, BFS follows the in-edge direction.
* \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited.
* The argument would be edge ID.
* \param make_frontier The function to indicate that a new frontier can be
* made.
*/
template <typename Queue, typename VisitFn, typename FrontierFn>
void BFSEdges(
const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,
VisitFn visit, FrontierFn make_frontier) {
const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data);
......@@ -126,7 +124,8 @@ void BFSEdges(const GraphInterface& graph,
}
make_frontier();
const auto neighbor_iter =
reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
while (!queue->empty()) {
const size_t size = queue->size();
for (size_t i = 0; i < size; ++i) {
......@@ -165,19 +164,20 @@ void BFSEdges(const GraphInterface& graph,
* void (*make_frontier)(void);
*
* \param graph The graph.
* \param reversed If true, follows the in-edge direction.
* \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited.
* \param make_frontier The function to indicate that a new frontier can be
* made.
*/
template <typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(
const GraphInterface& graph, bool reversed, Queue* queue, VisitFn visit,
FrontierFn make_frontier) {
const auto get_degree =
reversed ? &GraphInterface::OutDegree : &GraphInterface::InDegree;
const auto neighbor_iter =
reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
uint64_t num_visited_nodes = 0;
std::vector<uint64_t> degrees(graph.NumVertices(), 0);
for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {
......@@ -207,7 +207,8 @@ void TopologicalNodes(const GraphInterface& graph,
}
if (num_visited_nodes != graph.NumVertices()) {
LOG(FATAL)
<< "Error in topological traversal: loop detected in the given graph.";
}
}
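TopologicalNodes above is Kahn's algorithm driven through the queue/visit callbacks. A self-contained sketch of the same idea on a plain adjacency list (KahnOrder is a hypothetical name, not part of this commit); it returns an empty vector when a cycle is detected, mirroring the fatal check above:

#include <cstdint>
#include <queue>
#include <vector>

inline std::vector<uint64_t> KahnOrder(
    const std::vector<std::vector<uint64_t>>& succ) {
  const uint64_t n = succ.size();
  std::vector<uint64_t> indeg(n, 0);
  for (const auto& nbrs : succ)
    for (uint64_t v : nbrs) ++indeg[v];
  std::queue<uint64_t> q;
  for (uint64_t v = 0; v < n; ++v)
    if (indeg[v] == 0) q.push(v);
  std::vector<uint64_t> order;
  while (!q.empty()) {
    const uint64_t u = q.front();
    q.pop();
    order.push_back(u);
    for (uint64_t v : succ[u])
      if (--indeg[v] == 0) q.push(v);
  }
  if (order.size() != n) order.clear();  // cycle detected
  return order;
}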
......@@ -221,30 +222,28 @@ enum DFSEdgeTag {
* \brief Traverse the graph in a depth-first-search (DFS) order.
*
* The traversal visit edges in its DFS order. Edges have three tags:
* FORWARD(0), REVERSE(1), NONTREE(2).
*
* A FORWARD edge is one in which `u` has been visited but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visited and the
* edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have
* been visited but the edge is NOT in the DFS tree.
*
* \param source Source node.
* \param reversed If true, DFS follows the in-edge direction.
* \param has_reverse_edge If true, REVERSE edges are included.
* \param has_nontree_edge If true, NONTREE edges are included.
* \param visit The function to call when an edge is visited; the edge id and
* its tag will be given as the arguments.
*/
template <typename VisitFn>
void DFSLabeledEdges(
const GraphInterface& graph, dgl_id_t source, bool reversed,
bool has_reverse_edge, bool has_nontree_edge, VisitFn visit) {
const auto succ =
reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;
const auto out_edge =
reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;
if ((graph.*succ)(source).size() == 0) {
// no out-going edges from the source node
......@@ -273,7 +272,7 @@ void DFSLabeledEdges(const GraphInterface& graph,
stack.pop();
// find next one.
if (i < (graph.*succ)(u).size() - 1) {
stack.push(std::make_tuple(u, i + 1, false));
}
} else {
visited[v] = true;
......@@ -291,4 +290,3 @@ void DFSLabeledEdges(const GraphInterface& graph,
} // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_
......@@ -4,13 +4,12 @@
* \brief Operations on partition implemented in CUDA.
*/
#include <dgl/runtime/device_api.h>
#include "../../array/cuda/dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/workspace.h"
#include "../partition_op.h"
using namespace dgl::runtime;
......@@ -21,22 +20,21 @@ namespace impl {
namespace {
/**
* @brief Kernel to map global element IDs to partition IDs by remainder.
*
* @tparam IdType The type of ID.
* @param global The global element IDs.
* @param num_elements The number of element IDs.
* @param num_parts The number of partitions.
* @param part_id The mapped partition ID (output).
*/
template <typename IdType>
__global__ void _MapProcByRemainderKernel(
const IdType* const global, const int64_t num_elements,
const int64_t num_parts, IdType* const part_id) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx =
blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
if (idx < num_elements) {
part_id[idx] = global[idx] % num_parts;
......@@ -44,24 +42,23 @@ __global__ void _MapProcByRemainderKernel(
}
/**
* @brief Kernel to map global element IDs to partition IDs, using a bit-mask.
* The number of partitions must be a power of two.
*
* @tparam IdType The type of ID.
* @param global The global element IDs.
* @param num_elements The number of element IDs.
* @param mask The bit-mask with 1's for each bit to keep from the element ID to
* extract the partition ID (e.g., an 8 partition mask would be 0x07).
* @param part_id The mapped partition ID (output).
*/
template <typename IdType>
__global__ void _MapProcByMaskRemainderKernel(
const IdType* const global, const int64_t num_elements, const IdType mask,
IdType* const part_id) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx =
blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
if (idx < num_elements) {
part_id[idx] = global[idx] & mask;
......@@ -69,22 +66,20 @@ __global__ void _MapProcByMaskRemainderKernel(
}
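A quick check of why the mask kernel above agrees with the remainder kernel whenever the partition count is a power of two (CheckMaskEqualsRemainder is a hypothetical illustration, not part of this commit):

#include <cassert>
#include <cstdint>

inline void CheckMaskEqualsRemainder() {
  const int64_t num_parts = 8;         // a power of two
  const int64_t mask = num_parts - 1;  // 0x07
  for (int64_t id = 0; id < 100; ++id)
    assert(id % num_parts == (id & mask));  // x % 2^k == x & (2^k - 1)
}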
/**
* @brief Kernel to map global element IDs to local element IDs.
*
* @tparam IdType The type of ID.
* @param global The global element IDs.
* @param num_elements The number of IDs.
* @param num_parts The number of partitions.
* @param local The local element IDs (output).
*/
template <typename IdType>
__global__ void _MapLocalIndexByRemainderKernel(
const IdType* const global, const int64_t num_elements, const int num_parts,
IdType* const local) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < num_elements) {
local[idx] = global[idx] / num_parts;
......@@ -92,25 +87,22 @@ __global__ void _MapLocalIndexByRemainderKernel(
}
/**
* @brief Kernel to map local element IDs within a partition to their global
* IDs, using the remainder over the number of partitions.
*
* @tparam IdType The type of ID.
* @param local The local element IDs.
* @param part_id The partition to map local elements from.
* @param num_elements The number of elements to map.
* @param num_parts The number of partitions.
* @param global The global element IDs (output).
*/
template <typename IdType>
__global__ void _MapGlobalIndexByRemainderKernel(
const IdType* const local, const int part_id, const int64_t num_elements,
const int num_parts, IdType* const global) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
assert(part_id < num_parts);
......@@ -120,125 +112,113 @@ __global__ void _MapGlobalIndexByRemainderKernel(
}
/**
* @brief Device function to perform a binary search to find to which partition
* a given ID belongs.
*
* @tparam RangeType The type of range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param num_parts The number of partitions.
* @param target The element ID to find the partition of.
*
* @return The partition.
*/
template <typename RangeType>
__device__ RangeType _SearchRange(
const RangeType* const range, const int num_parts, const RangeType target) {
int start = 0;
int end = num_parts;
int cur = (end + start) / 2;
assert(range[0] == 0);
assert(target < range[num_parts]);
while (start + 1 < end) {
if (target < range[cur]) {
end = cur;
} else {
start = cur;
}
cur = (start + end) / 2;
}
return cur;
}
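_SearchRange returns the partition p with range[p] <= target < range[p + 1]. On the host the same lookup can be written with std::upper_bound; SearchRangeHost is a hypothetical helper shown only for comparison:

#include <algorithm>
#include <cstdint>

inline int SearchRangeHost(
    const int64_t* range, int num_parts, int64_t target) {
  // upper_bound yields the first prefix-sum entry greater than target; the
  // partition index is the slot immediately before it.
  const int64_t* it = std::upper_bound(range, range + num_parts + 1, target);
  return static_cast<int>(it - range) - 1;
}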
/**
* @brief Kernel to map element IDs to partition IDs.
*
* @tparam IdType The type of element ID.
* @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param global The global element IDs.
* @param num_elements The number of element IDs.
* @param num_parts The number of partitions.
* @param part_id The partition ID assigned to each element (output).
*/
template <typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel(
const RangeType* const range, const IdType* const global,
const int64_t num_elements, const int64_t num_parts,
IdType* const part_id) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx =
blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;
// rely on caching to load the range into L1 cache
if (idx < num_elements) {
part_id[idx] = static_cast<IdType>(_SearchRange(
range, static_cast<int>(num_parts),
static_cast<RangeType>(global[idx])));
}
}
/**
* @brief Kernel to map global element IDs to their ID within their respective
* partition.
*
* @tparam IdType The type of element ID.
* @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param global The global element IDs.
* @param num_elements The number of elements.
* @param num_parts The number of partitions.
* @param local The local element IDs (output).
*/
template <typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel(
const RangeType* const range, const IdType* const global,
const int64_t num_elements, const int num_parts, IdType* const local) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
// rely on caching to load the range into L1 cache
if (idx < num_elements) {
const int proc = _SearchRange(
range, static_cast<int>(num_parts),
static_cast<RangeType>(global[idx]));
local[idx] = global[idx] - range[proc];
}
}
/**
* @brief Kernel to map local element IDs within a partition to their global
* IDs.
*
* @tparam IdType The type of ID.
* @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param local The local element IDs.
* @param part_id The partition to map local elements from.
* @param num_elements The number of elements to map.
* @param num_parts The number of partitions.
* @param global The global element IDs (output).
*/
template <typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel(
const RangeType* const range, const IdType* const local, const int part_id,
const int64_t num_elements, const int num_parts, IdType* const global) {
assert(num_elements <= gridDim.x * blockDim.x);
const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
assert(part_id < num_parts);
......@@ -252,11 +232,8 @@ __global__ void _MapGlobalIndexByRangeKernel(
// Remainder Based Partition Operations
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
int64_t array_size, int num_parts, IdArray in_idx) {
std::pair<IdArray, NDArray> result;
const auto& ctx = in_idx->ctx;
......@@ -265,19 +242,19 @@ GeneratePermutationFromRemainder(
const int64_t num_in = in_idx->shape[0];
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
<< ") must be at least 1.";
if (num_parts == 1) {
// no permutation
result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);
return result;
}
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
int64_t* out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) {
// now that we've zero'd out_counts, nothing left to do for an empty
// mapping
......@@ -291,21 +268,20 @@ GeneratePermutationFromRemainder(
Workspace<IdType> proc_id_in(device, ctx, num_in);
{
const dim3 block(256);
const dim3 grid((num_in + block.x - 1) / block.x);
if (num_parts < (1 << part_bits)) {
// num_parts is not a power of 2
CUDA_KERNEL_CALL(
_MapProcByRemainderKernel, grid, block, 0, stream,
static_cast<const IdType*>(in_idx->data), num_in, num_parts,
proc_id_in.get());
} else {
// num_parts is a power of 2
CUDA_KERNEL_CALL(
_MapProcByMaskRemainderKernel, grid, block, 0, stream,
static_cast<const IdType*>(in_idx->data), num_in,
static_cast<IdType>(num_parts - 1), // bit mask
proc_id_in.get());
}
}
......@@ -313,18 +289,20 @@ GeneratePermutationFromRemainder(
// then create a permutation array that groups processors together by
// performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in);
IdType* perm_out = static_cast<IdType*>(result.first->data);
{
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
size_t sort_workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
stream));
Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
}
// explicitly free so workspace can be re-used
......@@ -334,83 +312,58 @@ GeneratePermutationFromRemainder(
// Count the number of values to be sent to each processor
{
using AtomicCount = unsigned long long; // NOLINT
static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
using AtomicCount = unsigned long long; // NOLINT
static_assert(
sizeof(AtomicCount) == sizeof(*out_counts),
"AtomicCount must be the same width as int64_t for atomicAdd "
"in cub::DeviceHistogram::HistogramEven() to work");
// TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
// add a compile time check against the cub version to allow
// num_in > (2 << 31).
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
"number of values to insert into histogram must be less than max "
"value of int.";
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
<< "number of values to insert into histogram must be less than max "
"value of int.";
size_t hist_workspace_size;
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr,
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
nullptr, hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<int>(num_in), stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(),
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<int>(num_in), stream));
}
return result;
}
template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDGLCUDA, int32_t>(
int64_t array_size,
int num_parts,
IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDGLCUDA, int64_t>(
int64_t array_size,
int num_parts,
IdArray in_idx);
template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
kDGLCUDA, int32_t>(int64_t array_size, int num_parts, IdArray in_idx);
template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
kDGLCUDA, int64_t>(int64_t array_size, int num_parts, IdArray in_idx);
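The permutation and counts produced above can be pictured on the CPU: assign each index to part idx % num_parts, stable-sort positions by part, and tally a histogram per part. The standalone sketch below is a hypothetical illustration of that contract (names and layout are assumptions, not DGL code), mirroring the kernel + radix-sort + histogram pipeline.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical standalone sketch, assuming the conventional assignment
// part = idx % num_parts.
struct RemainderPermutation {
  std::vector<int64_t> perm;    // positions of in_idx, grouped by owning part
  std::vector<int64_t> counts;  // number of indices owned by each part
};

inline RemainderPermutation PermuteByRemainder(
    const std::vector<int64_t>& in_idx, int num_parts) {
  RemainderPermutation out;
  out.perm.resize(in_idx.size());
  out.counts.assign(num_parts, 0);
  std::iota(out.perm.begin(), out.perm.end(), 0);
  // Stable sort keeps the relative order within each part, like the radix
  // sort on proc_id above.
  std::stable_sort(
      out.perm.begin(), out.perm.end(), [&](int64_t a, int64_t b) {
        return in_idx[a] % num_parts < in_idx[b] % num_parts;
      });
  for (int64_t idx : in_idx) ++out.counts[idx % num_parts];
  return out;
}
// E.g. in_idx = {7, 2, 5, 4}, num_parts = 3 -> perm = {0, 3, 1, 2},
// counts = {0, 2, 2}.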
template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
const int num_parts,
IdArray global_idx) {
IdArray MapToLocalFromRemainder(const int num_parts, IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
sizeof(IdType)*8);
IdArray local_idx =
aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128);
const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);
const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_MapLocalIndexByRemainderKernel,
grid,
block,
0,
stream,
static_cast<const IdType*>(global_idx->data),
global_idx->shape[0],
num_parts,
static_cast<IdType*>(local_idx->data));
_MapLocalIndexByRemainderKernel, grid, block, 0, stream,
static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
num_parts, static_cast<IdType*>(local_idx->data));
return local_idx;
} else {
......@@ -419,45 +372,33 @@ IdArray MapToLocalFromRemainder(
}
}
template IdArray
MapToLocalFromRemainder<kDGLCUDA, int32_t>(
int num_parts,
IdArray in_idx);
template IdArray
MapToLocalFromRemainder<kDGLCUDA, int64_t>(
int num_parts,
IdArray in_idx);
template IdArray MapToLocalFromRemainder<kDGLCUDA, int32_t>(
int num_parts, IdArray in_idx);
template IdArray MapToLocalFromRemainder<kDGLCUDA, int64_t>(
int num_parts, IdArray in_idx);
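As a quick illustration of the local mapping (an assumption based on the kernel names, not copied from the kernel source): under round-robin ownership, part p owns global ids p, p + num_parts, p + 2 * num_parts, ..., so the local index is the integer quotient.

#include <cstdint>

// Hypothetical helper, assuming the round-robin remainder layout.
inline int64_t LocalFromGlobalByRemainder(int64_t global, int num_parts) {
  return global / num_parts;  // e.g. global 14 with 4 parts -> local 3 on part 2
}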
template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
const int num_parts,
IdArray local_idx,
const int part_id) {
CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
"/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
"/" << num_parts;
const int num_parts, IdArray local_idx, const int part_id) {
CHECK_LT(part_id, num_parts)
<< "Invalid partition id " << part_id << "/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
<< num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
sizeof(IdType)*8);
IdArray global_idx =
aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128);
const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);
const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_MapGlobalIndexByRemainderKernel,
grid,
block,
0,
stream,
static_cast<const IdType*>(local_idx->data),
part_id,
global_idx->shape[0],
num_parts,
_MapGlobalIndexByRemainderKernel, grid, block, 0, stream,
static_cast<const IdType*>(local_idx->data), part_id,
global_idx->shape[0], num_parts,
static_cast<IdType*>(global_idx->data));
return global_idx;
......@@ -467,27 +408,16 @@ IdArray MapToGlobalFromRemainder(
}
}
template IdArray
MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
int num_parts,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
int num_parts,
IdArray in_idx,
int part_id);
template IdArray MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
int num_parts, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
int num_parts, IdArray in_idx, int part_id);
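The global mapping is the inverse under the same assumed layout: the part id supplies the low "digit" and the local index the multiple of num_parts.

#include <cstdint>

// Hypothetical helper; inverts the local mapping sketched above.
inline int64_t GlobalFromLocalByRemainder(
    int64_t local, int num_parts, int part_id) {
  return local * num_parts + part_id;  // local 3, 4 parts, part 2 -> global 14
}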
// Range Based Partition Operations
template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray>
GeneratePermutationFromRange(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx) {
std::pair<IdArray, NDArray> GeneratePermutationFromRange(
int64_t array_size, int num_parts, IdArray range, IdArray in_idx) {
std::pair<IdArray, NDArray> result;
const auto& ctx = in_idx->ctx;
......@@ -496,19 +426,19 @@ GeneratePermutationFromRange(
const int64_t num_in = in_idx->shape[0];
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
") must be at least 1.";
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
<< ") must be at least 1.";
if (num_parts == 1) {
// no permutation
result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx);
result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);
return result;
}
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
int64_t * out_counts = static_cast<int64_t*>(result.second->data);
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
int64_t* out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) {
// now that we've zero'd out_counts, nothing left to do for an empty
// mapping
......@@ -522,31 +452,32 @@ GeneratePermutationFromRange(
Workspace<IdType> proc_id_in(device, ctx, num_in);
{
const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x);
const dim3 grid((num_in + block.x - 1) / block.x);
CUDA_KERNEL_CALL(_MapProcByRangeKernel, grid, block, 0, stream,
CUDA_KERNEL_CALL(
_MapProcByRangeKernel, grid, block, 0, stream,
static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(in_idx->data),
num_in,
num_parts,
static_cast<const IdType*>(in_idx->data), num_in, num_parts,
proc_id_in.get());
}
// then create a permutation array that groups processors together by
// performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in);
IdType * perm_out = static_cast<IdType*>(result.first->data);
IdType* perm_out = static_cast<IdType*>(result.first->data);
{
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
size_t sort_workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size,
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
stream));
Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(sort_workspace.get(), sort_workspace_size,
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
CUDA_CALL(cub::DeviceRadixSort::SortPairs(
sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
}
// explicitly free so workspace can be re-used
......@@ -556,98 +487,68 @@ GeneratePermutationFromRange(
// Count the number of values to be sent to each processor
{
using AtomicCount = unsigned long long; // NOLINT
static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
using AtomicCount = unsigned long long; // NOLINT
static_assert(
sizeof(AtomicCount) == sizeof(*out_counts),
"AtomicCount must be the same width as int64_t for atomicAdd "
"in cub::DeviceHistogram::HistogramEven() to work");
// TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
// add a compile time check against the cub version to allow
// num_in > (2 << 31).
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
"number of values to insert into histogram must be less than max "
"value of int.";
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
<< "number of values to insert into histogram must be less than max "
"value of int.";
size_t hist_workspace_size;
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr,
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
nullptr, hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<int>(num_in), stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(),
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
static_cast<int>(num_in), stream));
}
return result;
}
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
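For the range-based variant, ownership comes from an exclusive prefix sum `range` of length num_parts + 1: part p owns [range[p], range[p + 1]). The standalone sketch below shows that lookup as assumed behaviour, mirroring the intent of _MapProcByRangeKernel rather than its source.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper; range[] is an exclusive prefix sum of part sizes.
inline int OwnerFromRange(const std::vector<int64_t>& range, int64_t idx) {
  // The first entry strictly greater than idx, minus one, is the owner.
  auto it = std::upper_bound(range.begin(), range.end(), idx);
  return static_cast<int>(it - range.begin()) - 1;
}
// E.g. range = {0, 4, 10, 16}: idx 3 -> part 0, idx 4 -> part 1, idx 12 -> part 2.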
template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
const int num_parts,
IdArray range,
IdArray global_idx) {
const int num_parts, IdArray range, IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && global_idx->shape[0] > 0) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
sizeof(IdType)*8);
IdArray local_idx =
aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128);
const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);
const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_MapLocalIndexByRangeKernel,
grid,
block,
0,
stream,
_MapLocalIndexByRangeKernel, grid, block, 0, stream,
static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(global_idx->data),
global_idx->shape[0],
num_parts,
static_cast<IdType*>(local_idx->data));
static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
num_parts, static_cast<IdType*>(local_idx->data));
return local_idx;
} else {
......@@ -656,60 +557,38 @@ IdArray MapToLocalFromRange(
}
}
template IdArray
MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts, IdArray range, IdArray in_idx);
template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
const int num_parts,
IdArray range,
IdArray local_idx,
const int part_id) {
CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
"/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
"/" << num_parts;
const int num_parts, IdArray range, IdArray local_idx, const int part_id) {
CHECK_LT(part_id, num_parts)
<< "Invalid partition id " << part_id << "/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
<< num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_parts > 1 && local_idx->shape[0] > 0) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
sizeof(IdType)*8);
IdArray global_idx =
aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);
const dim3 block(128);
const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);
const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_MapGlobalIndexByRangeKernel,
grid,
block,
0,
stream,
_MapGlobalIndexByRangeKernel, grid, block, 0, stream,
static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(local_idx->data),
part_id,
global_idx->shape[0],
num_parts,
static_cast<const IdType*>(local_idx->data), part_id,
global_idx->shape[0], num_parts,
static_cast<IdType*>(global_idx->data));
return global_idx;
......@@ -719,31 +598,14 @@ IdArray MapToGlobalFromRange(
}
}
template IdArray
MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
int num_parts, IdArray range, IdArray in_idx, int part_id);
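Under the same assumed contiguous layout, the local/global conversions for a range partition are plain offsets from the owning part's start; the helpers below are hypothetical illustrations.

#include <cstdint>
#include <vector>

// Part p owns the block [range[p], range[p + 1]), so:
inline int64_t LocalFromRange(
    const std::vector<int64_t>& range, int part_id, int64_t global) {
  return global - range[part_id];  // range {0, 4, 10, 16}: global 12 -> local 2 on part 2
}

inline int64_t GlobalFromRange(
    const std::vector<int64_t>& range, int part_id, int64_t local) {
  return range[part_id] + local;   // inverse of LocalFromRange
}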
} // namespace impl
} // namespace partition
......
......@@ -6,10 +6,11 @@
#include "ndarray_partition.h"
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <utility>
#include <dgl/runtime/registry.h>
#include <memory>
#include <utility>
#include "partition_op.h"
......@@ -19,30 +20,21 @@ namespace dgl {
namespace partition {
NDArrayPartition::NDArrayPartition(
const int64_t array_size, const int num_parts) :
array_size_(array_size),
num_parts_(num_parts) {
}
const int64_t array_size, const int num_parts)
: array_size_(array_size), num_parts_(num_parts) {}
int64_t NDArrayPartition::ArraySize() const {
return array_size_;
}
int NDArrayPartition::NumParts() const {
return num_parts_;
}
int64_t NDArrayPartition::ArraySize() const { return array_size_; }
int NDArrayPartition::NumParts() const { return num_parts_; }
class RemainderPartition : public NDArrayPartition {
public:
RemainderPartition(
const int64_t array_size, const int num_parts) :
NDArrayPartition(array_size, num_parts) {
RemainderPartition(const int64_t array_size, const int num_parts)
: NDArrayPartition(array_size, num_parts) {
// do nothing
}
std::pair<IdArray, NDArray>
GeneratePermutation(
std::pair<IdArray, NDArray> GeneratePermutation(
IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
......@@ -55,13 +47,12 @@ class RemainderPartition : public NDArrayPartition {
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
"implemented.";
// should be unreachable
return std::pair<IdArray, NDArray>{};
}
IdArray MapToLocal(
IdArray in_idx) const override {
IdArray MapToLocal(IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
......@@ -73,14 +64,12 @@ class RemainderPartition : public NDArrayPartition {
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
"implemented.";
// should be unreachable
return IdArray{};
}
IdArray MapToGlobal(
IdArray in_idx,
const int part_id) const override {
IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
......@@ -92,41 +81,39 @@ class RemainderPartition : public NDArrayPartition {
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
"implemented.";
// should be unreachable
return IdArray{};
}
int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for "
"partition of size " << NumParts() << ".";
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id
<< ") for "
"partition of size "
<< NumParts() << ".";
return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());
}
};
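The PartSize() arithmetic above spreads the remainder over the lowest-numbered parts, one extra row each; a small self-contained check:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t array_size = 10;
  const int num_parts = 4;
  for (int p = 0; p < num_parts; ++p) {
    const int64_t size = array_size / num_parts + (p < array_size % num_parts);
    std::cout << "part " << p << ": " << size << "\n";  // prints 3, 3, 2, 2
  }
  return 0;
}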
class RangePartition : public NDArrayPartition {
public:
RangePartition(
const int64_t array_size,
const int num_parts,
IdArray range) :
NDArrayPartition(array_size, num_parts),
range_(range),
// We also need a copy of the range on the CPU, to compute partition
// sizes. We require the input range on the GPU, as if we have multiple
// GPUs, we can't know which is the proper one to copy the array to, but we
// have only one CPU context, and can safely copy the array to that.
range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
RangePartition(const int64_t array_size, const int num_parts, IdArray range)
: NDArrayPartition(array_size, num_parts),
range_(range),
// We also need a copy of the range on the CPU, to compute partition
// sizes. We require the input range on the GPU, as if we have multiple
// GPUs, we can't know which is the proper one to copy the array to, but
// we have only one CPU context, and can safely copy the array to that.
range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
auto ctx = range->ctx;
if (ctx.device_type != kDGLCUDA) {
LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before "
"creating the partition.";
LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before "
"creating the partition.";
}
}
std::pair<IdArray, NDArray>
GeneratePermutation(
std::pair<IdArray, NDArray> GeneratePermutation(
IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
......@@ -134,11 +121,13 @@ class RangePartition : public NDArrayPartition {
if (ctx.device_type != range_->ctx.device_type ||
ctx.device_id != range_->ctx.device_id) {
LOG(FATAL) << "The range for the NDArrayPartition and the input "
"array must be on the same device: " << ctx << " vs. " << range_->ctx;
"array must be on the same device: "
<< ctx << " vs. " << range_->ctx;
}
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::GeneratePermutationFromRange<kDGLCUDA, IdType, RangeType>(
return impl::GeneratePermutationFromRange<
kDGLCUDA, IdType, RangeType>(
ArraySize(), NumParts(), range_, in_idx);
});
});
......@@ -146,13 +135,12 @@ class RangePartition : public NDArrayPartition {
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
"implemented.";
// should be unreachable
return std::pair<IdArray, NDArray>{};
}
IdArray MapToLocal(
IdArray in_idx) const override {
IdArray MapToLocal(IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
......@@ -166,14 +154,12 @@ class RangePartition : public NDArrayPartition {
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
"implemented.";
// should be unreachable
return IdArray{};
}
IdArray MapToGlobal(
IdArray in_idx,
const int part_id) const override {
IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
......@@ -187,17 +173,20 @@ class RangePartition : public NDArrayPartition {
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
"implemented.";
// should be unreachable
return IdArray{};
}
int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for "
"partition of size " << NumParts() << ".";
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id
<< ") for "
"partition of size "
<< NumParts() << ".";
ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, {
const RangeType * const ptr = static_cast<const RangeType*>(range_cpu_->data);
return ptr[part_id+1]-ptr[part_id];
const RangeType* const ptr =
static_cast<const RangeType*>(range_cpu_->data);
return ptr[part_id + 1] - ptr[part_id];
});
}
......@@ -207,66 +196,58 @@ class RangePartition : public NDArrayPartition {
};
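RangePartition::PartSize() is simply a difference of adjacent prefix sums; a self-contained check using a made-up range:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t range[] = {0, 4, 10, 16};  // 3 parts, 16 rows total
  for (int p = 0; p < 3; ++p) {
    std::cout << "part " << p << ": " << range[p + 1] - range[p] << "\n";  // 4, 6, 6
  }
  return 0;
}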
NDArrayPartitionRef CreatePartitionRemainderBased(
const int64_t array_size,
const int num_parts) {
return NDArrayPartitionRef(std::make_shared<RemainderPartition>(
array_size, num_parts));
const int64_t array_size, const int num_parts) {
return NDArrayPartitionRef(
std::make_shared<RemainderPartition>(array_size, num_parts));
}
NDArrayPartitionRef CreatePartitionRangeBased(
const int64_t array_size,
const int num_parts,
IdArray range) {
return NDArrayPartitionRef(std::make_shared<RangePartition>(
array_size,
num_parts,
range));
const int64_t array_size, const int num_parts, IdArray range) {
return NDArrayPartitionRef(
std::make_shared<RangePartition>(array_size, num_parts, range));
}
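A hedged sketch of how these two factories might be called from C++ inside DGL; only the factory signatures and PartSize() come from the code above, everything else (the caller, the numbers, the way the range array is produced) is illustrative.

#include <cstdint>

#include "ndarray_partition.h"

// Not a standalone program; assumes the normal DGL build context.
void MakePartitionsSketch(dgl::IdArray range_on_gpu) {
  using namespace dgl::partition;
  // 1000 rows split round-robin across 4 parts.
  NDArrayPartitionRef by_remainder =
      CreatePartitionRemainderBased(/*array_size=*/1000, /*num_parts=*/4);
  // The same rows split by an explicit prefix-sum range (already on the GPU).
  NDArrayPartitionRef by_range =
      CreatePartitionRangeBased(/*array_size=*/1000, /*num_parts=*/4, range_on_gpu);
  const int64_t rows_on_part0 = by_remainder->PartSize(0);  // 250 for this split
  (void)by_range;
  (void)rows_on_part0;
}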
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t array_size = args[0];
int num_parts = args[1];
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t array_size = args[0];
int num_parts = args[1];
*rv = CreatePartitionRemainderBased(array_size, num_parts);
});
*rv = CreatePartitionRemainderBased(array_size, num_parts);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRangeBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int64_t array_size = args[0];
const int num_parts = args[1];
IdArray range = args[2];
*rv = CreatePartitionRangeBased(array_size, num_parts, range);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int64_t array_size = args[0];
const int num_parts = args[1];
IdArray range = args[2];
*rv = CreatePartitionRangeBased(array_size, num_parts, range);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
int part_id = args[1];
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
int part_id = args[1];
*rv = part->PartSize(part_id);
});
*rv = part->PartSize(part_id);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
IdArray idxs = args[1];
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
IdArray idxs = args[1];
*rv = part->MapToLocal(idxs);
});
*rv = part->MapToLocal(idxs);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
IdArray idxs = args[1];
const int part_id = args[2];
*rv = part->MapToGlobal(idxs, part_id);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
IdArray idxs = args[1];
const int part_id = args[2];
*rv = part->MapToGlobal(idxs, part_id);
});
} // namespace partition
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file ndarray_partition.h
* \brief DGL utilities for working with the partitioned NDArrays
* \file ndarray_partition.h
* \brief DGL utilities for working with the partitioned NDArrays
*/
#ifndef DGL_PARTITION_NDARRAY_PARTITION_H_
#define DGL_PARTITION_NDARRAY_PARTITION_H_
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/object.h>
#include <utility>
namespace dgl {
......@@ -28,9 +28,7 @@ class NDArrayPartition : public runtime::Object {
* @param array_size The first dimension of the partitioned array.
* @param num_parts The number of parts the array is split into.
*/
NDArrayPartition(
int64_t array_size,
int num_parts);
NDArrayPartition(int64_t array_size, int num_parts);
virtual ~NDArrayPartition() = default;
......@@ -50,8 +48,7 @@ class NDArrayPartition : public runtime::Object {
* @return A pair containing 0) the permutation to re-order the indices by
* partition, 1) the number of indices per partition (int64_t).
*/
virtual std::pair<IdArray, NDArray>
GeneratePermutation(
virtual std::pair<IdArray, NDArray> GeneratePermutation(
IdArray in_idx) const = 0;
/**
......@@ -62,8 +59,7 @@ class NDArrayPartition : public runtime::Object {
*
* @return The local indices.
*/
virtual IdArray MapToLocal(
IdArray in_idx) const = 0;
virtual IdArray MapToLocal(IdArray in_idx) const = 0;
/**
* @brief Generate the global indices (the numbering unique across all
......@@ -74,9 +70,7 @@ class NDArrayPartition : public runtime::Object {
*
* @return The global indices.
*/
virtual IdArray MapToGlobal(
IdArray in_idx,
int part_id) const = 0;
virtual IdArray MapToGlobal(IdArray in_idx, int part_id) const = 0;
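A sketch of the contract the three mapping methods are meant to satisfy (stated here for orientation only; it is an assumption and is not asserted anywhere in this header):

// Hypothetical invariant check, pseudocode-style:
//   IdArray owned = ...;                      // indices owned by part p
//   IdArray local = part->MapToLocal(owned);
//   IdArray back  = part->MapToGlobal(local, p);
//   // `back` is expected to equal `owned` element-wise, and
//   // GeneratePermutation() groups indices by owner in part order.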
/**
* @brief Get the number of rows/items assigned to the given part.
......@@ -85,8 +79,7 @@ class NDArrayPartition : public runtime::Object {
*
* @return The size.
*/
virtual int64_t PartSize(
int part_id) const = 0;
virtual int64_t PartSize(int part_id) const = 0;
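Taken together, GeneratePermutation() and PartSize() describe the same split: summing PartSize(p) over all parts should give ArraySize(), and the per-part counts returned by GeneratePermutation() should never exceed the corresponding PartSize(). This is a consistency expectation inferred from the implementations above, not a documented guarantee.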
/**
* @brief Get the first dimension of the partitioned array.
......@@ -119,9 +112,7 @@ DGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition);
* @return The partition object.
*/
NDArrayPartitionRef CreatePartitionRemainderBased(
int64_t array_size,
int num_parts);
int64_t array_size, int num_parts);
/**
* @brief Create a new partition object, using the range (exclusive prefix-sum)
......@@ -136,9 +127,7 @@ NDArrayPartitionRef CreatePartitionRemainderBased(
* @return The partition object.
*/
NDArrayPartitionRef CreatePartitionRangeBased(
int64_t array_size,
int num_parts,
IdArray range);
int64_t array_size, int num_parts, IdArray range);
} // namespace partition
} // namespace dgl
......
......@@ -4,11 +4,11 @@
* \brief DGL utilities for working with the partitioned NDArrays
*/
#ifndef DGL_PARTITION_PARTITION_OP_H_
#define DGL_PARTITION_PARTITION_OP_H_
#include <dgl/array.h>
#include <utility>
namespace dgl {
......@@ -33,11 +33,8 @@ namespace impl {
* indices in each part.
*/
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder(
int64_t array_size,
int num_parts,
IdArray in_idx);
std::pair<IdArray, IdArray> GeneratePermutationFromRemainder(
int64_t array_size, int num_parts, IdArray in_idx);
/**
* @brief Generate the set of local indices from the global indices, using
......@@ -52,9 +49,7 @@ GeneratePermutationFromRemainder(
* @return The array of local indices.
*/
template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
int num_parts,
IdArray global_idx);
IdArray MapToLocalFromRemainder(int num_parts, IdArray global_idx);
/**
* @brief Generate the set of global indices from the local indices, using
......@@ -70,10 +65,7 @@ IdArray MapToLocalFromRemainder(
* @return The array of global indices.
*/
template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
int num_parts,
IdArray local_idx,
int part_id);
IdArray MapToGlobalFromRemainder(int num_parts, IdArray local_idx, int part_id);
/**
* @brief Create a permutation that groups indices by the part id when used for
......@@ -96,12 +88,8 @@ IdArray MapToGlobalFromRemainder(
* indices in each part.
*/
template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, IdArray>
GeneratePermutationFromRange(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
std::pair<IdArray, IdArray> GeneratePermutationFromRange(
int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
/**
* @brief Generate the set of local indices from the global indices, using
......@@ -119,10 +107,7 @@ GeneratePermutationFromRange(
* @return The array of local indices.
*/
template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
int num_parts,
IdArray range,
IdArray global_idx);
IdArray MapToLocalFromRange(int num_parts, IdArray range, IdArray global_idx);
/**
* @brief Generate the set of global indices from the local indices, using
......@@ -142,12 +127,7 @@ IdArray MapToLocalFromRange(
*/
template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
int num_parts,
IdArray range,
IdArray local_idx,
int part_id);
int num_parts, IdArray range, IdArray local_idx, int part_id);
} // namespace impl
} // namespace partition
......
......@@ -4,12 +4,12 @@
* \brief Random number generator interfaces
*/
#include <dmlc/omp.h>
#include <dgl/runtime/registry.h>
#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/random.h>
#include <dgl/array.h>
#include <dgl/runtime/registry.h>
#include <dmlc/omp.h>
#ifdef DGL_USE_CUDA
#include "../runtime/cuda/cuda_common.h"
......@@ -20,53 +20,57 @@ using namespace dgl::runtime;
namespace dgl {
DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
const int seed = args[0];
.set_body([](DGLArgs args, DGLRetValue *rv) {
const int seed = args[0];
runtime::parallel_for(0, omp_get_max_threads(), [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) {
RandomEngine::ThreadLocal()->SetSeed(seed);
}
});
runtime::parallel_for(0, omp_get_max_threads(), [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) {
RandomEngine::ThreadLocal()->SetSeed(seed);
}
});
#ifdef DGL_USE_CUDA
if (DeviceAPI::Get(kDGLCUDA)->IsAvailable()) {
auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
if (DeviceAPI::Get(kDGLCUDA)->IsAvailable()) {
auto *thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(
&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen, static_cast<uint64_t>(seed)));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
}
#endif // DGL_USE_CUDA
});
});
DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
const int64_t num = args[0];
const int64_t population = args[1];
const NDArray prob = args[2];
const bool replace = args[3];
const int bits = args[4];
CHECK(bits == 32 || bits == 64)
<< "Supported bit widths are 32 and 64, but got " << bits << ".";
if (aten::IsNullArray(prob)) {
if (bits == 32) {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(num, population, replace);
.set_body([](DGLArgs args, DGLRetValue *rv) {
const int64_t num = args[0];
const int64_t population = args[1];
const NDArray prob = args[2];
const bool replace = args[3];
const int bits = args[4];
CHECK(bits == 32 || bits == 64)
<< "Supported bit widths are 32 and 64, but got " << bits << ".";
if (aten::IsNullArray(prob)) {
if (bits == 32) {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(
num, population, replace);
} else {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(
num, population, replace);
}
} else {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(num, population, replace);
if (bits == 32) {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(
num, prob, replace);
});
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(
num, prob, replace);
});
}
}
} else {
if (bits == 32) {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(num, prob, replace);
});
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(num, prob, replace);
});
}
}
});
});
}; // namespace dgl
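A hedged C++ sketch of the RandomEngine calls these packed functions dispatch to; the includes and the surrounding context are assumptions, while the method names and template arguments come from the dispatch code above.

#include <dgl/random.h>
#include <dgl/runtime/ndarray.h>

// Sketch only: seed the thread-local engine, then draw uniform and weighted
// samples, exactly as the packed functions above do.
void SampleSketch(const dgl::runtime::NDArray& prob) {
  using dgl::RandomEngine;
  RandomEngine::ThreadLocal()->SetSeed(42);
  // 5 ids drawn uniformly from [0, 100) without replacement.
  auto uniform =
      RandomEngine::ThreadLocal()->UniformChoice<int64_t>(5, 100, false);
  // 5 ids drawn according to a float probability array.
  auto weighted =
      RandomEngine::ThreadLocal()->Choice<int64_t, float>(5, prob, false);
  (void)uniform;
  (void)weighted;
}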
......@@ -7,6 +7,7 @@
#define DGL_RPC_NET_TYPE_H_
#include <string>
#include "rpc_msg.h"
namespace dgl {
......@@ -29,11 +30,12 @@ struct RPCBase {
struct RPCSender : RPCBase {
/*!
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
*
*
* When there are multiple receivers to be connected, the application calls
* `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize`
* to confirm that all the connections were established successfully, or to
* learn that some of them failed.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
......@@ -49,13 +51,11 @@ struct RPCSender : RPCBase {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiverFinalize(const int max_try_times) {
return true;
}
virtual bool ConnectReceiverFinalize(const int max_try_times) { return true; }
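A hedged sketch of the connect-then-finalize pattern described above; the addresses and retry count are illustrative, and `sender` stands for any concrete RPCSender implementation (such as the SocketSender later in this change).

// Sketch only; assumes the declarations from this header are visible.
bool ConnectAllSketch(dgl::rpc::RPCSender* sender) {
  // Register every receiver first; none of these calls is final.
  bool ok = sender->ConnectReceiver("tcp://127.0.0.1:50091", /*recv_id=*/0);
  ok = ok && sender->ConnectReceiver("tcp://127.0.0.1:50092", /*recv_id=*/1);
  // Then make sure all connections were actually established (or give up).
  return ok && sender->ConnectReceiverFinalize(/*max_try_times=*/1024);
}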
/*!
* \brief Send RPCMessage to specified Receiver.
* \param msg data message
* \param msg data message
* \param recv_id receiver's ID
*/
virtual void Send(const RPCMessage &msg, int recv_id) = 0;
......@@ -71,13 +71,14 @@ struct RPCReceiver : RPCBase {
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Wait(const std::string &addr, int num_sender,
bool blocking = true) = 0;
virtual bool Wait(
const std::string &addr, int num_sender, bool blocking = true) = 0;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0;
......
......@@ -17,7 +17,8 @@ namespace network {
// In most cases, delim contains only one character. In this case, we
// use CalculateReserveForVector to count the number of elements should
// be reserved in result vector, and thus optimize SplitStringUsing.
static int CalculateReserveForVector(const std::string& full, const char* delim) {
static int CalculateReserveForVector(
const std::string& full, const char* delim) {
int count = 0;
if (delim[0] != '\0' && delim[1] == '\0') {
// Optimize the common case where delim is a single character.
......@@ -38,19 +39,18 @@ static int CalculateReserveForVector(const std::string& full, const char* delim)
return count;
}
void SplitStringUsing(const std::string& full,
const char* delim,
std::vector<std::string>* result) {
void SplitStringUsing(
const std::string& full, const char* delim,
std::vector<std::string>* result) {
CHECK(delim != NULL);
CHECK(result != NULL);
result->reserve(CalculateReserveForVector(full, delim));
back_insert_iterator< std::vector<std::string> > it(*result);
back_insert_iterator<std::vector<std::string> > it(*result);
SplitStringToIteratorUsing(full, delim, &it);
}
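A usage sketch for SplitStringUsing (assuming the declaration from the matching header is visible): splitting "a,b,c" on "," fills the vector with three tokens.

#include <string>
#include <vector>

void SplitSketch() {
  std::vector<std::string> parts;
  dgl::network::SplitStringUsing("a,b,c", ",", &parts);
  // parts now holds {"a", "b", "c"}.
}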
void SplitStringToSetUsing(const std::string& full,
const char* delim,
std::set<std::string>* result) {
void SplitStringToSetUsing(
const std::string& full, const char* delim, std::set<std::string>* result) {
CHECK(delim != NULL);
CHECK(result != NULL);
simple_insert_iterator<std::set<std::string> > it(result);
......
......@@ -33,19 +33,18 @@ namespace network {
// substrings[2] == "banana"
//------------------------------------------------------------------------------
void SplitStringUsing(const std::string& full,
const char* delim,
std::vector<std::string>* result);
void SplitStringUsing(
const std::string& full, const char* delim,
std::vector<std::string>* result);
// This function has the same semantics as SplitStringUsing. Results
// are saved in an STL set container.
void SplitStringToSetUsing(const std::string& full,
const char* delim,
std::set<std::string>* result);
void SplitStringToSetUsing(
const std::string& full, const char* delim, std::set<std::string>* result);
template <typename T>
struct simple_insert_iterator {
explicit simple_insert_iterator(T* t) : t_(t) { }
explicit simple_insert_iterator(T* t) : t_(t) {}
simple_insert_iterator<T>& operator=(const typename T::value_type& value) {
t_->insert(value);
......@@ -76,10 +75,8 @@ struct back_insert_iterator {
};
template <typename StringType, typename ITR>
static inline
void SplitStringToIteratorUsing(const StringType& full,
const char* delim,
ITR* result) {
static inline void SplitStringToIteratorUsing(
const StringType& full, const char* delim, ITR* result) {
CHECK_NOTNULL(delim);
// Optimize the common case where delim is a single character.
if (delim[0] != '\0' && delim[1] == '\0') {
......
......@@ -19,16 +19,17 @@ namespace network {
/*!
* \brief Network Sender for DGL distributed training.
*
* Sender is an abstract class that defines a set of APIs for sending binary
* data message over network. It can be implemented by different underlying
* networking libraries such TCP socket and MPI. One Sender can connect to
* multiple receivers and it can send data to specified receiver via receiver's ID.
* Sender is an abstract class that defines a set of APIs for sending binary
* data messages over the network. It can be implemented by different
* underlying networking libraries such as TCP sockets and MPI. One Sender can
* connect to multiple receivers, and it can send data to a specified receiver
* via the receiver's ID.
*/
class Sender : public rpc::RPCSender {
public:
/*!
* \brief Sender constructor
* \param queue_size size (bytes) of message queue.
* \param queue_size size (bytes) of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
*/
......@@ -47,12 +48,12 @@ class Sender : public rpc::RPCSender {
* \param recv_id receiver's ID
* \return Status code
*
* (1) The send is non-blocking. There is no guarantee that the message has been
* physically sent out when the function returns.
* (2) The communicator will assume the responsibility of the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
* (1) The send is non-blocking. There is no guarantee that the message has
*     been physically sent out when the function returns.
* (2) The communicator will assume the responsibility of the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in
*     the same order. There is no guarantee for messages sent to different
*     receivers.
*/
virtual STATUS Send(Message msg, int recv_id) = 0;
......@@ -70,10 +71,11 @@ class Sender : public rpc::RPCSender {
/*!
* \brief Network Receiver for DGL distributed training.
*
* Receiver is an abstract class that defines a set of APIs for receiving binary data
* message over network. It can be implemented by different underlying networking
* libraries such as TCP socket and MPI. One Receiver can connect with multiple Senders
* and it can receive data from multiple Senders concurrently.
* Receiver is an abstract class that defines a set of APIs for receiving binary
* data messages over the network. It can be implemented by different underlying
* networking libraries such as TCP socket and MPI. One Receiver can connect
* with multiple Senders and it can receive data from multiple Senders
* concurrently.
*/
class Receiver : public rpc::RPCReceiver {
public:
......@@ -98,11 +100,13 @@ class Receiver : public rpc::RPCReceiver {
* \brief Recv data from Sender
* \param msg pointer of data message
* \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code
*
* (1) The Recv() API is thread-safe.
* (2) Memory allocated by communicator but will not own it after the function returns.
* (2) Memory allocated by communicator but will not own it after the function
* returns.
*/
virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;
......@@ -110,11 +114,13 @@ class Receiver : public rpc::RPCReceiver {
* \brief Recv data from a specified Sender
* \param msg pointer of data message
* \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code
*
* (1) The RecvFrom() API is thread-safe.
* (2) Memory allocated by communicator but will not own it after the function returns.
* (2) Memory allocated by communicator but will not own it after the function
* returns.
*/
virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;
......
......@@ -3,10 +3,11 @@
* \file msg_queue.cc
* \brief Message queue for DGL distributed training.
*/
#include "msg_queue.h"
#include <dmlc/logging.h>
#include <cstring>
#include "msg_queue.h"
#include <cstring>
namespace dgl {
namespace network {
......@@ -38,9 +39,7 @@ STATUS MessageQueue::Add(Message msg, bool is_blocking) {
if (msg.size > free_size_ && !is_blocking) {
return QUEUE_FULL;
}
cond_not_full_.wait(lock, [&]() {
return msg.size <= free_size_;
});
cond_not_full_.wait(lock, [&]() { return msg.size <= free_size_; });
// Add data pointer to queue
queue_.push(msg);
free_size_ -= msg.size;
......@@ -61,9 +60,8 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
}
}
cond_not_empty_.wait(lock, [this] {
return !queue_.empty() || exit_flag_.load();
});
cond_not_empty_.wait(
lock, [this] { return !queue_.empty() || exit_flag_.load(); });
if (finished_producers_.size() >= num_producers_ && queue_.empty()) {
return QUEUE_CLOSE;
}
......@@ -98,8 +96,7 @@ bool MessageQueue::Empty() const {
bool MessageQueue::EmptyAndNoMoreAdd() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size() == 0 &&
finished_producers_.size() >= num_producers_;
return queue_.size() == 0 && finished_producers_.size() >= num_producers_;
}
} // namespace network
......
......@@ -8,14 +8,14 @@
#include <dgl/runtime/ndarray.h>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <utility> // for pair
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <functional>
namespace dgl {
namespace network {
......@@ -25,13 +25,13 @@ typedef int STATUS;
/*!
* \brief Status code of message queue
*/
#define ADD_SUCCESS 3400 // Add message successfully
#define MSG_GT_SIZE 3401 // Message size beyond queue size
#define MSG_LE_ZERO 3402 // Message size is not a positive number
#define QUEUE_CLOSE 3403 // Cannot add message when queue is closed
#define QUEUE_FULL 3404 // Cannot add message when queue is full
#define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
#define ADD_SUCCESS 3400 // Add message successfully
#define MSG_GT_SIZE 3401 // Message size beyond queue size
#define MSG_LE_ZERO 3402 // Message size is not a positive number
#define QUEUE_CLOSE 3403 // Cannot add message when queue is closed
#define QUEUE_FULL 3404 // Cannot add message when queue is full
#define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
/*!
* \brief Message used by network communicator and message queue.
......@@ -40,13 +40,13 @@ struct Message {
/*!
* \brief Constructor
*/
Message() { }
Message() {}
/*!
* \brief Constructor
*/
*/
Message(char* data_ptr, int64_t data_size)
: data(data_ptr), size(data_size) { }
: data(data_ptr), size(data_size) {}
/*!
* \brief message data
......@@ -69,22 +69,23 @@ struct Message {
/*!
* \brief Free memory buffer of message
*/
inline void DefaultMessageDeleter(Message* msg) { delete [] msg->data; }
inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
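A sketch of building a Message that owns its buffer and hands cleanup to the default deleter above; the payload is illustrative and the snippet assumes this header is included.

#include <cstring>

dgl::network::Message MakeOwnedMessage() {
  const char payload[] = "ping";
  char* buf = new char[sizeof(payload)];
  std::memcpy(buf, payload, sizeof(payload));
  dgl::network::Message msg(buf, sizeof(payload));
  msg.deallocator = dgl::network::DefaultMessageDeleter;  // frees buf later
  return msg;
}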
/*!
* \brief Message Queue for network communication.
*
* MessageQueue is FIFO queue that adopts producer/consumer model for data message.
* It supports one or more producer threads and one or more consumer threads.
* Producers invokes Add() to push data message into the queue, and consumers
* invokes Remove() to pop data message from queue. Add() and Remove() use two condition
* variables to synchronize producer threads and consumer threads. Each producer
* invokes SignalFinished(producer_id) to claim that it is about to finish, where
* producer_id is an integer uniquely identify a producer thread. This signaling mechanism
* prevents consumers from waiting after all producers have finished their jobs.
* MessageQueue is a FIFO queue that adopts the producer/consumer model for
* data messages. It supports one or more producer threads and one or more
* consumer threads. Producers invoke Add() to push data messages into the
* queue, and consumers invoke Remove() to pop data messages from the queue.
* Add() and Remove() use two condition variables to synchronize producer
* threads and consumer threads. Each producer invokes
* SignalFinished(producer_id) to claim that it is about to finish, where
* producer_id is an integer that uniquely identifies a producer thread. This
* signaling mechanism prevents consumers from waiting after all producers
* have finished their jobs.
*
* MessageQueue is thread-safe.
*
*
*/
class MessageQueue {
public:
......@@ -93,8 +94,8 @@ class MessageQueue {
* \param queue_size size (bytes) of message queue
* \param num_producers number of producers, use 1 by default
*/
explicit MessageQueue(int64_t queue_size /* in bytes */,
int num_producers = 1);
explicit MessageQueue(
int64_t queue_size /* in bytes */, int num_producers = 1);
/*!
* \brief MessageQueue deconstructor
......@@ -134,48 +135,48 @@ class MessageQueue {
bool EmptyAndNoMoreAdd() const;
protected:
/*!
* \brief message queue
/*!
* \brief message queue
*/
std::queue<Message> queue_;
/*!
* \brief Size of the queue in bytes
/*!
* \brief Size of the queue in bytes
*/
int64_t queue_size_;
/*!
* \brief Free size of the queue
/*!
* \brief Free size of the queue
*/
int64_t free_size_;
/*!
* \brief Used to check all producers will no longer produce anything
/*!
* \brief Used to check all producers will no longer produce anything
*/
size_t num_producers_;
/*!
* \brief Store finished producer id
/*!
* \brief Store finished producer id
*/
std::set<int /* producer_id */> finished_producers_;
/*!
* \brief Condition when consumer should wait
/*!
* \brief Condition when consumer should wait
*/
std::condition_variable cond_not_full_;
/*!
* \brief Condition when producer should wait
/*!
* \brief Condition when producer should wait
*/
std::condition_variable cond_not_empty_;
/*!
* \brief Signal for exit wait
/*!
* \brief Signal for exit wait
*/
std::atomic<bool> exit_flag_{false};
/*!
* \brief Protect all above data and conditions
/*!
* \brief Protect all above data and conditions
*/
mutable std::mutex mutex_;
};
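A hedged, single-threaded sketch of the queue protocol built only from the methods referenced in this diff (Add, Remove, SignalFinished); in real use producers and consumers run on separate threads.

void QueueSketch() {
  using dgl::network::Message;
  using dgl::network::MessageQueue;
  MessageQueue queue(/*queue_size=*/1024, /*num_producers=*/1);

  // Producer: push one message, then announce that this producer is done.
  static char payload[] = "ping";
  queue.Add(Message(payload, sizeof(payload)), /*is_blocking=*/true);
  queue.SignalFinished(/*producer_id=*/0);

  // Consumer: pop until the queue reports that it is closed.
  Message out;
  while (queue.Remove(&out, /*is_blocking=*/true) != QUEUE_CLOSE) {
    // out.data / out.size would be processed here.
  }
}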
......
......@@ -3,29 +3,29 @@
* \file communicator.cc
* \brief SocketCommunicator for DGL distributed training.
*/
#include <dmlc/logging.h>
#include "socket_communicator.h"
#include <string.h>
#include <dmlc/logging.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <memory>
#include "socket_communicator.h"
#include "../../c_api_common.h"
#include "socket_pool.h"
#ifdef _WIN32
#include <windows.h>
#else // !_WIN32
#else // !_WIN32
#include <unistd.h>
#endif // _WIN32
namespace dgl {
namespace network {
/////////////////////////////////////// SocketSender ///////////////////////////////////////////
/////////////////////////////////////// SocketSender
//////////////////////////////////////////////
bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
if (recv_id < 0) {
......@@ -92,9 +92,7 @@ bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
// Create a new thread for this socket connection
threads_.push_back(std::make_shared<std::thread>(
SendLoop,
sockets_[thread_id],
msg_queue_[thread_id]));
SendLoop, sockets_[thread_id], msg_queue_[thread_id]));
}
return true;
......@@ -105,14 +103,13 @@ void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
zc_write_strm.Write(msg);
int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
zerocopy_blob->append(reinterpret_cast<char*>(&nonempty_ndarray_count),
sizeof(int32_t));
zerocopy_blob->append(
reinterpret_cast<char*>(&nonempty_ndarray_count), sizeof(int32_t));
Message rpc_meta_msg;
rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
rpc_meta_msg.size = zerocopy_blob->size();
rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};
CHECK_EQ(Send(
rpc_meta_msg, recv_id), ADD_SUCCESS);
CHECK_EQ(Send(rpc_meta_msg, recv_id), ADD_SUCCESS);
// send real ndarray data
for (auto ptr : zc_write_strm.buffer_list()) {
Message ndarray_data_msg;
......@@ -123,8 +120,7 @@ void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
ndarray_data_msg.size = ptr.size;
NDArray tensor = ptr.tensor;
ndarray_data_msg.deallocator = [tensor](Message*) {};
CHECK_EQ(Send(
ndarray_data_msg, recv_id), ADD_SUCCESS);
CHECK_EQ(Send(ndarray_data_msg, recv_id), ADD_SUCCESS);
}
}
......@@ -156,7 +152,7 @@ void SocketSender::Finalize() {
}
// Clear all sockets
for (auto& group_sockets_ : sockets_) {
for (auto &socket : group_sockets_) {
for (auto& socket : group_sockets_) {
socket.second->Close();
}
}
......@@ -168,9 +164,8 @@ void SendCore(Message msg, TCPSocket* socket) {
int64_t sent_bytes = 0;
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket->Send(
reinterpret_cast<char*>(&msg.size) + sent_bytes,
max_len);
int64_t tmp =
socket->Send(reinterpret_cast<char*>(&msg.size) + sent_bytes, max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
}
......@@ -178,7 +173,7 @@ void SendCore(Message msg, TCPSocket* socket) {
sent_bytes = 0;
while (sent_bytes < msg.size) {
int64_t max_len = msg.size - sent_bytes;
int64_t tmp = socket->Send(msg.data+sent_bytes, max_len);
int64_t tmp = socket->Send(msg.data + sent_bytes, max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
}
......@@ -189,8 +184,8 @@ void SendCore(Message msg, TCPSocket* socket) {
}
void SocketSender::SendLoop(
std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue) {
std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue) {
for (;;) {
Message msg;
STATUS code = queue->Remove(&msg);
......@@ -205,8 +200,10 @@ void SocketSender::SendLoop(
}
}
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
/////////////////////////////////////// SocketReceiver
//////////////////////////////////////////////
bool SocketReceiver::Wait(
const std::string& addr, int num_sender, bool blocking) {
CHECK_GT(num_sender, 0);
CHECK_EQ(blocking, true);
std::vector<std::string> substring;
......@@ -231,7 +228,7 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
num_sender_ = num_sender;
#ifdef USE_EPOLL
if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
max_thread_count_ = num_sender_;
max_thread_count_ = num_sender_;
}
#else
max_thread_count_ = num_sender_;
......@@ -256,7 +253,8 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
auto socket = std::make_shared<TCPSocket>();
sockets_[thread_id][i] = socket;
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) {
if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) ==
false) {
LOG(WARNING) << "Error on accept socket.";
return false;
}
......@@ -266,10 +264,7 @@ bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking
for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
// create new thread for each socket
threads_.push_back(std::make_shared<std::thread>(
RecvLoop,
sockets_[thread_id],
msg_queue_,
&queue_sem_));
RecvLoop, sockets_[thread_id], msg_queue_, &queue_sem_));
}
return true;
......@@ -285,7 +280,7 @@ rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
return rpc::kRPCTimeOut;
}
CHECK_EQ(status, REMOVE_SUCCESS);
char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t);
char* count_ptr = rpc_meta_msg.data + rpc_meta_msg.size - sizeof(int32_t);
int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
// Recv real ndarray data
std::vector<void*> buffer_list(nonempty_ndarray_count);
......@@ -305,7 +300,8 @@ rpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {
CHECK_EQ(status, REMOVE_SUCCESS);
buffer_list[i] = ndarray_data_msg.data;
}
StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list);
StreamWithBuffer zc_read_strm(
rpc_meta_msg.data, rpc_meta_msg.size - sizeof(int32_t), buffer_list);
zc_read_strm.Read(msg);
rpc_meta_msg.deallocator(&rpc_meta_msg);
return rpc::kRPCSuccess;
......@@ -352,7 +348,7 @@ void SocketReceiver::Finalize() {
for (auto& mq : msg_queue_) {
// wait until queue is empty
while (mq.second->Empty() == false) {
std::this_thread::sleep_for(std::chrono::seconds(1));
std::this_thread::sleep_for(std::chrono::seconds(1));
}
mq.second->SignalFinished(mq.first);
}
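Finalize() above drains each per-sender queue before signalling it finished, polling once per second until the queue reports empty. A compact sketch of that drain-then-signal step with a generic queue type; only Empty() and SignalFinished() are taken from the code above, the rest of the interface is assumed.

#include <chrono>
#include <thread>

// Poll (with 1 s sleeps) until the queue is empty, then notify it that no
// further messages will arrive for this sender.
template <typename Queue>
void DrainThenSignal(Queue& q, int sender_id) {
  while (!q.Empty()) {
    std::this_thread::sleep_for(std::chrono::seconds(1));
  }
  q.SignalFinished(sender_id);
}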
......@@ -376,8 +372,7 @@ int64_t RecvDataSize(TCPSocket* socket) {
while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - received_bytes;
int64_t tmp = socket->Receive(
reinterpret_cast<char*>(&data_size) + received_bytes,
max_len);
reinterpret_cast<char*>(&data_size) + received_bytes, max_len);
if (tmp == -1) {
if (received_bytes > 0) {
// We want to finish reading full data_size
......@@ -390,8 +385,9 @@ int64_t RecvDataSize(TCPSocket* socket) {
return data_size;
}
void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
int64_t *received_bytes) {
void RecvData(
TCPSocket* socket, char* buffer, const int64_t& data_size,
int64_t* received_bytes) {
while (*received_bytes < data_size) {
int64_t max_len = data_size - *received_bytes;
int64_t tmp = socket->Receive(buffer + *received_bytes, max_len);
......@@ -404,15 +400,17 @@ void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
}
void SocketReceiver::RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> queues,
runtime::Semaphore *queue_sem) {
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
sockets,
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
queues,
runtime::Semaphore* queue_sem) {
std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
SocketPool socket_pool;
for (auto& socket : sockets) {
auto &sender_id = socket.first;
auto& sender_id = socket.first;
socket_pool.AddSocket(socket.second, sender_id);
recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
}
......@@ -432,9 +430,9 @@ void SocketReceiver::RecvLoop(
// Nonblocking socket might be interrupted at any point. So we need to
// store the partially received data
std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id];
int64_t &data_size = ctx->data_size;
int64_t &received_bytes = ctx->received_bytes;
std::unique_ptr<RecvContext>& ctx = recv_contexts[sender_id];
int64_t& data_size = ctx->data_size;
int64_t& received_bytes = ctx->received_bytes;
char*& buffer = ctx->buffer;
if (data_size == -1) {
......@@ -443,7 +441,7 @@ void SocketReceiver::RecvLoop(
if (data_size > 0) {
try {
buffer = new char[data_size];
} catch(const std::bad_alloc&) {
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Cannot allocate enough memory for message, "
<< "(message size: " << data_size << ")";
}
......
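Because the receiver sockets are nonblocking, RecvLoop above cannot assume a whole message arrives in one Receive() call; it parks partial progress in a per-sender RecvContext {data_size, received_bytes, buffer} and resumes when the socket becomes readable again. Below is a hedged, standalone sketch of that resumable state machine; TryReceive stands in for TCPSocket::Receive, the header handling is simplified, and none of this is DGL's exact control flow.

#include <cstdint>
#include <functional>

// Stand-in for a nonblocking TCPSocket::Receive (assumption for this sketch):
// returns the number of bytes read, or <= 0 when no data is available yet.
using TryReceive = std::function<int64_t(char*, int64_t)>;

struct RecvContext {       // mirrors the struct in socket_communicator.h
  int64_t data_size = -1;  // -1 until the 8-byte size header has been read
  int64_t received_bytes = 0;
  char* buffer = nullptr;
};

// Advance one sender's context as far as the socket allows.
// Returns true exactly when a full message has been assembled.
bool Advance(const TryReceive& recv, RecvContext* ctx) {
  if (ctx->data_size == -1) {
    // Simplification: assume the size header arrives in one piece; the real
    // loop also has to cope with a partially received header.
    int64_t size = 0;
    if (recv(reinterpret_cast<char*>(&size), sizeof(size)) <
        static_cast<int64_t>(sizeof(size)))
      return false;
    ctx->data_size = size;
    ctx->buffer = new char[size];
    ctx->received_bytes = 0;
  }
  while (ctx->received_bytes < ctx->data_size) {
    int64_t n = recv(ctx->buffer + ctx->received_bytes,
                     ctx->data_size - ctx->received_bytes);
    if (n <= 0) return false;  // resume on the next readiness event
    ctx->received_bytes += n;
  }
  return true;
}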
......@@ -6,22 +6,23 @@
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#include <thread>
#include <vector>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <memory>
#include <vector>
#include "../../runtime/semaphore_wrapper.h"
#include "common.h"
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
#include "common.h"
namespace dgl {
namespace network {
static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kTimeOut =
10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024
/*!
......@@ -41,19 +42,20 @@ class SocketSender : public Sender {
public:
/*!
* \brief Sender constructor
* \param queue_size size of message queue
* \param queue_size size of message queue
* \param max_thread_count size of thread pool. 0 for no limit
*/
SocketSender(int64_t queue_size, int max_thread_count)
: Sender(queue_size, max_thread_count) {}
: Sender(queue_size, max_thread_count) {}
/*!
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
*
*
* When there are multiple receivers to be connected, application will call
* `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
* sure that either all the connections are successfully established or some
* of them fail.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
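ConnectReceiver() and Wait() both take an address string of the form 'tcp://127.0.0.1:50091', which the implementation splits into IP and port. A hedged sketch of that split; ParseTcpAddr is an invented helper and DGL's own parsing (and its error handling) may differ.

#include <cstddef>
#include <stdexcept>
#include <string>

struct ParsedAddr {
  std::string ip;
  int port;
};

// Split 'tcp://<ip>:<port>' into its two components.
ParsedAddr ParseTcpAddr(const std::string& addr) {
  const std::string prefix = "tcp://";
  if (addr.compare(0, prefix.size(), prefix) != 0)
    throw std::invalid_argument("expected a tcp:// address: " + addr);
  const std::string rest = addr.substr(prefix.size());
  const std::size_t colon = rest.rfind(':');
  if (colon == std::string::npos)
    throw std::invalid_argument("missing port in address: " + addr);
  return {rest.substr(0, colon), std::stoi(rest.substr(colon + 1))};
}
// e.g. ParseTcpAddr("tcp://127.0.0.1:50091") -> {"127.0.0.1", 50091}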
......@@ -73,7 +75,7 @@ class SocketSender : public Sender {
/*!
* \brief Send RPCMessage to specified Receiver.
* \param msg data message
* \param msg data message
* \param recv_id receiver's ID
*/
void Send(const rpc::RPCMessage& msg, int recv_id) override;
......@@ -86,60 +88,63 @@ class SocketSender : public Sender {
/*!
* \brief Communicator type: 'socket'
*/
const std::string &NetType() const override {
const std::string& NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
/*!
* \brief Send data to specified Receiver. Actually pushing message to message queue.
* \param msg data message
* \param recv_id receiver's ID
* \return Status code
* \brief Send data to specified Receiver. Actually pushing message to message
* queue.
* \param msg data message.
* \param recv_id receiver's ID.
* \return Status code.
*
* (1) The send is non-blocking. There is no guarantee that the message has been
* physically sent out when the function returns.
* (2) The communicator will assume the responsibility of the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
* (1) The send is non-blocking. There is no guarantee that the message has
* been physically sent out when the function returns. (2) The communicator
* will assume the responsibility of the given message. (3) The API is
* multi-thread safe. (4) Messages sent to the same receiver are guaranteed to
* be received in the same order. There is no guarantee for messages sent to
* different receivers.
*/
STATUS Send(Message msg, int recv_id) override;
private:
/*!
* \brief socket for each connection of receiver
*/
std::vector<std::unordered_map<int /* receiver ID */,
std::shared_ptr<TCPSocket>>> sockets_;
*/
std::vector<
std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>
sockets_;
/*!
* \brief receivers' address
*/
*/
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*!
* \brief message queue for each thread
*/
*/
std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
/*!
* \brief Independent thread
*/
*/
std::vector<std::shared_ptr<std::thread>> threads_;
/*!
* \brief Send-loop for each thread
* \param sockets TCPSockets for current thread
* \param queue message_queue for current thread
*
*
* Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
static void SendLoop(
std::unordered_map<int /* Receiver (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue);
std::unordered_map<
int /* Receiver (virtual) ID */, std::shared_ptr<TCPSocket>>
sockets,
std::shared_ptr<MessageQueue> queue);
};
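The comments above pin down the Send()/SendLoop contract: Send() only enqueues the message (and takes ownership of it) before returning, while per-receiver ordering follows from each receiver having its own queue drained by a single SendLoop thread, which exits once the queue has been signalled finished and emptied. A minimal single-queue sketch of that handoff; FinishableQueue and its method names are invented for illustration and are not DGL's MessageQueue API.

#include <condition_variable>
#include <mutex>
#include <queue>
#include <string>
#include <utility>

// A tiny blocking queue whose consumer exits once the producer calls Finish()
// and the backlog has been drained -- the shutdown pattern SendLoop follows.
class FinishableQueue {
 public:
  void Push(std::string msg) {  // Send(): enqueue and return immediately
    std::lock_guard<std::mutex> lk(mu_);
    items_.push(std::move(msg));
    cv_.notify_one();
  }
  void Finish() {  // Signal(): no more messages will be produced
    std::lock_guard<std::mutex> lk(mu_);
    finished_ = true;
    cv_.notify_all();
  }
  bool Pop(std::string* out) {  // SendLoop(): drain until finished and empty
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return finished_ || !items_.empty(); });
    if (items_.empty()) return false;  // finished and drained: exit the loop
    *out = std::move(items_.front());
    items_.pop();
    return true;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::string> items_;
  bool finished_ = false;
};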
/*!
......@@ -155,7 +160,7 @@ class SocketReceiver : public Receiver {
* \param max_thread_count size of thread pool. 0 for no limit
*/
SocketReceiver(int64_t queue_size, int max_thread_count)
: Receiver(queue_size, max_thread_count) {}
: Receiver(queue_size, max_thread_count) {}
/*!
* \brief Wait for all the Senders to connect
......@@ -166,13 +171,14 @@ class SocketReceiver : public Receiver {
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const std::string &addr, int num_sender,
bool blocking = true) override;
bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
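Recv() above turns a message-queue wait into an RPC status: the timeout is in milliseconds, zero means wait indefinitely, and an expired wait is reported as kRPCTimeOut rather than an error. A sketch of such a timed removal with std::condition_variable; TimedQueue is invented for illustration and the real MessageQueue::Remove has its own signature.

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <string>
#include <utility>

enum RPCStatus { kRPCSuccess = 0, kRPCTimeOut = 1 };  // mirrors rpc::RPCStatus

class TimedQueue {
 public:
  void Push(std::string m) {
    std::lock_guard<std::mutex> lk(mu_);
    items_.push(std::move(m));
    cv_.notify_one();
  }
  // timeout_ms == 0 blocks indefinitely, matching the documentation above.
  RPCStatus Pop(std::string* out, int timeout_ms) {
    std::unique_lock<std::mutex> lk(mu_);
    auto ready = [this] { return !items_.empty(); };
    if (timeout_ms == 0) {
      cv_.wait(lk, ready);
    } else if (!cv_.wait_for(lk, std::chrono::milliseconds(timeout_ms), ready)) {
      return kRPCTimeOut;
    }
    *out = std::move(items_.front());
    items_.pop();
    return kRPCSuccess;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::string> items_;
};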
......@@ -181,23 +187,28 @@ class SocketReceiver : public Receiver {
* \brief Recv data from Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id which sender current msg comes from
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code
*
* (1) The Recv() API is thread-safe.
 * (2) The memory is allocated by the communicator, but the communicator does not own it after the function returns.
 * (2) The memory is allocated by the communicator, but the communicator does
 * not own it after the function returns.
*/
STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
/*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \brief Recv data from a specified Sender. Actually removing data from
* msg_queue.
* \param msg pointer of data message.
* \param send_id sender's ID
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* \return Status code
*
* (1) The RecvFrom() API is thread-safe.
 * (2) The memory is allocated by the communicator, but the communicator does not own it after the function returns.
 * (2) The memory is allocated by the communicator, but the communicator does
 * not own it after the function returns.
*/
STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
......@@ -211,7 +222,7 @@ class SocketReceiver : public Receiver {
/*!
* \brief Communicator type: 'socket'
*/
const std::string &NetType() const override {
const std::string& NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
......@@ -220,7 +231,7 @@ class SocketReceiver : public Receiver {
struct RecvContext {
int64_t data_size = -1;
int64_t received_bytes = 0;
char *buffer = nullptr;
char* buffer = nullptr;
};
/*!
* \brief number of sender
......@@ -229,25 +240,27 @@ class SocketReceiver : public Receiver {
/*!
* \brief server socket for listening connections
*/
*/
TCPSocket* server_socket_;
/*!
 * \brief socket for each client connection
*/
std::vector<std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>>> sockets_;
*/
std::vector<std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>>
sockets_;
/*!
* \brief Message queue for each socket connection
*/
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> msg_queue_;
*/
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
msg_queue_;
std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
/*!
 * \brief Independent thread
*/
*/
std::vector<std::shared_ptr<std::thread>> threads_;
/*!
......@@ -263,13 +276,15 @@ class SocketReceiver : public Receiver {
*
* Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
*/
static void RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> queues,
runtime::Semaphore *queue_sem);
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>
sockets,
std::unordered_map<
int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>
queues,
runtime::Semaphore* queue_sem);
};
} // namespace network
......
......@@ -6,6 +6,7 @@
#include "socket_pool.h"
#include <dmlc/logging.h>
#include "tcp_socket.h"
#ifdef USE_EPOLL
......@@ -24,8 +25,8 @@ SocketPool::SocketPool() {
#endif
}
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
int events) {
void SocketPool::AddSocket(
std::shared_ptr<TCPSocket> socket, int socket_id, int events) {
int fd = socket->Socket();
tcp_sockets_[fd] = socket;
socket_ids_[fd] = socket_id;
......@@ -47,7 +48,7 @@ void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
#else
if (tcp_sockets_.size() > 1) {
LOG(FATAL) << "SocketPool supports only one socket if not use epoll."
"Please turn on USE_EPOLL on building";
"Please turn on USE_EPOLL on building";
}
#endif
}
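AddSocket() above registers the descriptor with an epoll instance when USE_EPOLL is defined and otherwise degrades to a single-socket pool. A minimal Linux-only sketch of the underlying epoll calls; creating a fresh epoll instance per call is purely for illustration, whereas the real SocketPool keeps one instance with many registered sockets.

#include <sys/epoll.h>
#include <unistd.h>

// Register `fd` for read readiness and wait up to `timeout_ms` for one event.
// Returns the ready descriptor, or -1 on error or timeout.
int WaitReadable(int fd, int timeout_ms) {
  int epfd = epoll_create1(0);
  if (epfd == -1) return -1;
  struct epoll_event ev;
  ev.events = EPOLLIN;  // read readiness, like the default `events = READ` above
  ev.data.fd = fd;
  int ready_fd = -1;
  if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) == 0) {
    struct epoll_event out;
    if (epoll_wait(epfd, &out, 1, timeout_ms) == 1) ready_fd = out.data.fd;
  }
  close(epfd);
  return ready_fd;
}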
......
......@@ -6,9 +6,9 @@
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_
#include <unordered_map>
#include <queue>
#include <memory>
#include <queue>
#include <unordered_map>
namespace dgl {
namespace network {
......@@ -19,7 +19,7 @@ class TCPSocket;
* \brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification
* mechanism in Linux operating system.
* mechanism in Linux operating system.
*/
class SocketPool {
public:
......@@ -42,8 +42,8 @@ class SocketPool {
* \param socket_id receiver/sender id of the socket
* \param events READ, WRITE or READ + WRITE
*/
void AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
int events = READ);
void AddSocket(
std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);
/*!
* \brief Remove socket from SocketPool
......
......@@ -15,8 +15,8 @@
#include <sys/socket.h>
#include <unistd.h>
#endif // !_WIN32
#include <string.h>
#include <errno.h>
#include <string.h>
namespace dgl {
namespace network {
......@@ -31,28 +31,29 @@ TCPSocket::TCPSocket() {
LOG(FATAL) << "Can't create new socket. Error: " << strerror(errno);
}
#ifndef _WIN32
// This is to make sure the same port can be reused right after the socket is closed.
// This is to make sure the same port can be reused right after the socket is
// closed.
int enable = 1;
if (setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0) {
LOG(WARNING) << "cannot make the socket reusable. Error: " << strerror(errno);
LOG(WARNING) << "cannot make the socket reusable. Error: "
<< strerror(errno);
}
#endif // _WIN32
}
TCPSocket::~TCPSocket() {
Close();
}
TCPSocket::~TCPSocket() { Close(); }
bool TCPSocket::Connect(const char * ip, int port) {
bool TCPSocket::Connect(const char *ip, int port) {
SAI sa_server;
sa_server.sin_family = AF_INET;
sa_server.sin_port = htons(port);
sa_server.sin_family = AF_INET;
sa_server.sin_port = htons(port);
int retval = 0;
do { // retry if EINTR failure appears
if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
0 <= (retval = connect(socket_, reinterpret_cast<SA*>(&sa_server),
sizeof(sa_server)))) {
0 <= (retval = connect(
socket_, reinterpret_cast<SA *>(&sa_server),
sizeof(sa_server)))) {
return true;
}
} while (retval == -1 && errno == EINTR);
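Connect(), Bind(), Listen(), and Accept() all wrap their system calls in the same retry idiom: repeat while the call fails with EINTR, so a signal delivered mid-call does not surface as a spurious failure. A generic sketch of that idiom; RetryOnEintr is an invented helper name.

#include <cerrno>
#include <functional>

// Re-issue a syscall-style call while it is interrupted by a signal.
int RetryOnEintr(const std::function<int()>& call) {
  int rc;
  do {
    rc = call();
  } while (rc == -1 && errno == EINTR);
  return rc;
}
// Usage (given a valid listening fd): RetryOnEintr([&] { return listen(fd, 1024); });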
......@@ -60,10 +61,10 @@ bool TCPSocket::Connect(const char * ip, int port) {
return false;
}
bool TCPSocket::Bind(const char * ip, int port) {
bool TCPSocket::Bind(const char *ip, int port) {
SAI sa_server;
sa_server.sin_family = AF_INET;
sa_server.sin_port = htons(port);
sa_server.sin_family = AF_INET;
sa_server.sin_port = htons(port);
int ret = 0;
ret = inet_pton(AF_INET, ip, &sa_server.sin_addr);
if (ret == 0) {
......@@ -75,13 +76,15 @@ bool TCPSocket::Bind(const char * ip, int port) {
return false;
}
do { // retry if EINTR failure appears
if (0 <= (ret = bind(socket_, reinterpret_cast<SA *>(&sa_server),
sizeof(sa_server)))) {
if (0 <=
(ret = bind(
socket_, reinterpret_cast<SA *>(&sa_server), sizeof(sa_server)))) {
return true;
}
} while (ret == -1 && errno == EINTR);
LOG(ERROR) << "Failed bind on " << ip << ":" << port << " , error: " << strerror(errno);
LOG(ERROR) << "Failed bind on " << ip << ":" << port
<< " , error: " << strerror(errno);
return false;
}
......@@ -93,17 +96,18 @@ bool TCPSocket::Listen(int max_connection) {
}
} while (retval == -1 && errno == EINTR);
LOG(ERROR) << "Failed listen on socket fd: " << socket_ << " , error: " << strerror(errno);
LOG(ERROR) << "Failed listen on socket fd: " << socket_
<< " , error: " << strerror(errno);
return false;
}
bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
bool TCPSocket::Accept(TCPSocket *socket, std::string *ip, int *port) {
int sock_client;
SAI sa_client;
socklen_t len = sizeof(sa_client);
do { // retry if EINTR failure appears
sock_client = accept(socket_, reinterpret_cast<SA*>(&sa_client), &len);
sock_client = accept(socket_, reinterpret_cast<SA *>(&sa_client), &len);
} while (sock_client == -1 && errno == EINTR);
if (sock_client < 0) {
......@@ -114,10 +118,8 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
}
char tmp[INET_ADDRSTRLEN];
const char * ip_client = inet_ntop(AF_INET,
&sa_client.sin_addr,
tmp,
sizeof(tmp));
const char *ip_client =
inet_ntop(AF_INET, &sa_client.sin_addr, tmp, sizeof(tmp));
CHECK(ip_client != nullptr);
ip->assign(ip_client);
*port = ntohs(sa_client.sin_port);
......@@ -166,22 +168,20 @@ bool TCPSocket::SetNonBlocking(bool flag) {
#endif // _WIN32
void TCPSocket::SetTimeout(int timeout) {
#ifdef _WIN32
timeout = timeout * 1000;  // WIN API accepts milliseconds
setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char*>(&timeout), sizeof(timeout));
#else // !_WIN32
struct timeval tv;
tv.tv_sec = timeout;
tv.tv_usec = 0;
setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
&tv, sizeof(tv));
#endif // _WIN32
#ifdef _WIN32
timeout = timeout * 1000;  // WIN API accepts milliseconds
setsockopt(
socket_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&timeout),
sizeof(timeout));
#else // !_WIN32
struct timeval tv;
tv.tv_sec = timeout;
tv.tv_usec = 0;
setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
#endif // _WIN32
}
bool TCPSocket::ShutDown(int ways) {
return 0 == shutdown(socket_, ways);
}
bool TCPSocket::ShutDown(int ways) { return 0 == shutdown(socket_, ways); }
void TCPSocket::Close() {
if (socket_ >= 0) {
......@@ -194,7 +194,7 @@ void TCPSocket::Close() {
}
}
int64_t TCPSocket::Send(const char * data, int64_t len_data) {
int64_t TCPSocket::Send(const char *data, int64_t len_data) {
int64_t number_send;
do { // retry if EINTR failure appears
......@@ -207,7 +207,7 @@ int64_t TCPSocket::Send(const char * data, int64_t len_data) {
return number_send;
}
int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
int64_t TCPSocket::Receive(char *buffer, int64_t size_buffer) {
int64_t number_recv;
do { // retry if EINTR failure appears
......@@ -220,9 +220,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
return number_recv;
}
int TCPSocket::Socket() const {
return socket_;
}
int TCPSocket::Socket() const { return socket_; }
} // namespace network
} // namespace dgl