"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ca60ad8e55e8c2c43c3b88279fd3351918af8c39"
Unverified Commit 8f0df39e authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4810)



* [Misc] clang-format auto fix.

* manual

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 401e1278
...@@ -4,15 +4,17 @@ ...@@ -4,15 +4,17 @@
* \brief Convert multigraphs to simple graphs * \brief Convert multigraphs to simple graphs
*/ */
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <vector> #include <dgl/transform.h>
#include <utility> #include <utility>
#include <vector>
#include "../../c_api_common.h"
#include "../heterograph.h" #include "../heterograph.h"
#include "../unit_graph.h" #include "../unit_graph.h"
#include "../../c_api_common.h"
namespace dgl { namespace dgl {
...@@ -25,7 +27,8 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ...@@ -25,7 +27,8 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToSimpleGraph(const HeteroGraphPtr graph) { ToSimpleGraph(const HeteroGraphPtr graph) {
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
const auto metagraph = graph->meta_graph(); const auto metagraph = graph->meta_graph();
const auto &ugs = std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs(); const auto &ugs =
std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes); std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);
std::vector<HeteroGraphPtr> rel_graphs(num_etypes); std::vector<HeteroGraphPtr> rel_graphs(num_etypes);
...@@ -35,31 +38,31 @@ ToSimpleGraph(const HeteroGraphPtr graph) { ...@@ -35,31 +38,31 @@ ToSimpleGraph(const HeteroGraphPtr graph) {
std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result; std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;
} }
const HeteroGraphPtr result = CreateHeteroGraph( const HeteroGraphPtr result =
metagraph, rel_graphs, graph->NumVerticesPerType()); CreateHeteroGraph(metagraph, rel_graphs, graph->NumVerticesPerType());
return std::make_tuple(result, counts, edge_maps); return std::make_tuple(result, counts, edge_maps);
} }
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero") DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0]; const HeteroGraphRef graph_ref = args[0];
const auto result = ToSimpleGraph(graph_ref.sptr()); const auto result = ToSimpleGraph(graph_ref.sptr());
List<Value> counts, edge_maps; List<Value> counts, edge_maps;
for (const IdArray &count : std::get<1>(result)) for (const IdArray &count : std::get<1>(result))
counts.push_back(Value(MakeValue(count))); counts.push_back(Value(MakeValue(count)));
for (const IdArray &edge_map : std::get<2>(result)) for (const IdArray &edge_map : std::get<2>(result))
edge_maps.push_back(Value(MakeValue(edge_map))); edge_maps.push_back(Value(MakeValue(edge_map)));
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(std::get<0>(result))); ret.push_back(HeteroGraphRef(std::get<0>(result)));
ret.push_back(counts); ret.push_back(counts);
ret.push_back(edge_maps); ret.push_back(edge_maps);
*rv = ret; *rv = ret;
}); });
}; // namespace transform }; // namespace transform
......
This diff is collapsed.
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
* \file graph/traversal.cc * \file graph/traversal.cc
* \brief Graph traversal implementation * \brief Graph traversal implementation
*/ */
#include "./traversal.h"
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <algorithm> #include <algorithm>
#include <queue> #include <queue>
#include "./traversal.h"
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -15,46 +18,36 @@ namespace dgl { ...@@ -15,46 +18,36 @@ namespace dgl {
namespace traverse { namespace traverse {
namespace { namespace {
// A utility view class to wrap a vector into a queue. // A utility view class to wrap a vector into a queue.
template<typename DType> template <typename DType>
struct VectorQueueWrapper { struct VectorQueueWrapper {
std::vector<DType>* vec; std::vector<DType>* vec;
size_t head = 0; size_t head = 0;
explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {} explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}
void push(const DType& elem) { void push(const DType& elem) { vec->push_back(elem); }
vec->push_back(elem);
}
DType top() const { DType top() const { return vec->operator[](head); }
return vec->operator[](head);
}
void pop() { void pop() { ++head; }
++head;
}
bool empty() const { bool empty() const { return head == vec->size(); }
return head == vec->size();
}
size_t size() const { size_t size() const { return vec->size() - head; }
return vec->size() - head;
}
}; };
// Internal function to merge multiple traversal traces into one ndarray. // Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together. // It is similar to zip the vectors together.
template<typename DType> template <typename DType>
IdArray MergeMultipleTraversals( IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0, total_len = 0; int64_t max_len = 0, total_len = 0;
for (size_t i = 0; i < traces.size(); ++i) { for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
total_len += traces[i].size(); total_len += traces[i].size();
} }
IdArray ret = IdArray::Empty({total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty(
{total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) { for (size_t j = 0; j < traces.size(); ++j) {
...@@ -70,15 +63,15 @@ IdArray MergeMultipleTraversals( ...@@ -70,15 +63,15 @@ IdArray MergeMultipleTraversals(
// Internal function to compute sections if multiple traversal traces // Internal function to compute sections if multiple traversal traces
// are merged into one ndarray. // are merged into one ndarray.
template<typename DType> template <typename DType>
IdArray ComputeMergedSections( IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0; int64_t max_len = 0;
for (size_t i = 0; i < traces.size(); ++i) { for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
} }
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty(
{max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0; int64_t sec_len = 0;
...@@ -99,7 +92,8 @@ IdArray ComputeMergedSections( ...@@ -99,7 +92,8 @@ IdArray ComputeMergedSections(
* \brief Class for representing frontiers. * \brief Class for representing frontiers.
* *
* Each frontier is a list of nodes/edges (specified by their ids). * Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int value). * An optional tag can be specified on each node/edge (represented by an int
* value).
*/ */
struct Frontiers { struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */ /*!\brief a vector store for the nodes/edges in all the frontiers */
...@@ -112,142 +106,145 @@ struct Frontiers { ...@@ -112,142 +106,145 @@ struct Frontiers {
std::vector<int64_t> sections; std::vector<int64_t> sections;
}; };
Frontiers BFSNodesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) { Frontiers BFSNodesFrontiers(
const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front; Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids); VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { }; auto visit = [&](const dgl_id_t v) {};
auto make_frontier = [&] () { auto make_frontier = [&]() {
if (!queue.empty()) { if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
front.sections.push_back(queue.size()); front.sections.push_back(queue.size());
} }
}; };
BFSNodes(graph, source, reversed, &queue, visit, make_frontier); BFSNodes(graph, source, reversed, &queue, visit, make_frontier);
return front; return front;
} }
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed); const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);
IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids); IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections); IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
Frontiers BFSEdgesFrontiers(const GraphInterface& graph, IdArray source, bool reversed) { Frontiers BFSEdgesFrontiers(
const GraphInterface& graph, IdArray source, bool reversed) {
Frontiers front; Frontiers front;
// NOTE: std::queue has no top() method. // NOTE: std::queue has no top() method.
std::vector<dgl_id_t> nodes; std::vector<dgl_id_t> nodes;
VectorQueueWrapper<dgl_id_t> queue(&nodes); VectorQueueWrapper<dgl_id_t> queue(&nodes);
auto visit = [&] (const dgl_id_t e) { front.ids.push_back(e); }; auto visit = [&](const dgl_id_t e) { front.ids.push_back(e); };
bool first_frontier = true; bool first_frontier = true;
auto make_frontier = [&] { auto make_frontier = [&] {
if (first_frontier) { if (first_frontier) {
first_frontier = false; // do not push the first section when doing edges first_frontier = false; // do not push the first section when doing edges
} else if (!queue.empty()) { } else if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
front.sections.push_back(queue.size()); front.sections.push_back(queue.size());
} }
}; };
BFSEdges(graph, source, reversed, &queue, visit, make_frontier); BFSEdges(graph, source, reversed, &queue, visit, make_frontier);
return front; return front;
} }
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed); const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);
IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids); IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections); IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});
}); });
Frontiers TopologicalNodesFrontiers(const GraphInterface& graph, bool reversed) { Frontiers TopologicalNodesFrontiers(
const GraphInterface& graph, bool reversed) {
Frontiers front; Frontiers front;
VectorQueueWrapper<dgl_id_t> queue(&front.ids); VectorQueueWrapper<dgl_id_t> queue(&front.ids);
auto visit = [&] (const dgl_id_t v) { }; auto visit = [&](const dgl_id_t v) {};
auto make_frontier = [&] () { auto make_frontier = [&]() {
if (!queue.empty()) { if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
front.sections.push_back(queue.size()); front.sections.push_back(queue.size());
} }
}; };
TopologicalNodes(graph, reversed, &queue, visit, make_frontier); TopologicalNodes(graph, reversed, &queue, visit, make_frontier);
return front; return front;
} }
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
bool reversed = args[1]; bool reversed = args[1];
const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed); const auto& front = TopologicalNodesFrontiers(*g.sptr(), reversed);
IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids); IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);
IdArray sections = CopyVectorToNDArray<int64_t>(front.sections); IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);
*rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array."; CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len); std::vector<std::vector<dgl_id_t>> edges(len);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { edges[i].push_back(e); }; auto visit = [&](dgl_id_t e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit); DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);
} }
IdArray ids = MergeMultipleTraversals(edges); IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges); IdArray sections = ComputeMergedSections(edges);
*rv = ConvertNDArrayVectorToPackedFunc({ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
const bool has_reverse_edge = args[3]; const bool has_reverse_edge = args[3];
const bool has_nontree_edge = args[4]; const bool has_nontree_edge = args[4];
const bool return_labels = args[5]; const bool return_labels = args[5];
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array."; CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const int64_t* src_data = static_cast<int64_t*>(source->data); const int64_t* src_data = static_cast<int64_t*>(source->data);
std::vector<std::vector<dgl_id_t>> edges(len); std::vector<std::vector<dgl_id_t>> edges(len);
std::vector<std::vector<int64_t>> tags; std::vector<std::vector<int64_t>> tags;
if (return_labels) { if (return_labels) {
tags.resize(len); tags.resize(len);
} }
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (dgl_id_t e, int tag) { auto visit = [&](dgl_id_t e, int tag) {
edges[i].push_back(e); edges[i].push_back(e);
if (return_labels) { if (return_labels) {
tags[i].push_back(tag); tags[i].push_back(tag);
} }
}; };
DFSLabeledEdges(*g.sptr(), src_data[i], reversed, DFSLabeledEdges(
has_reverse_edge, has_nontree_edge, visit); *g.sptr(), src_data[i], reversed, has_reverse_edge,
} has_nontree_edge, visit);
}
IdArray ids = MergeMultipleTraversals(edges); IdArray ids = MergeMultipleTraversals(edges);
IdArray sections = ComputeMergedSections(edges); IdArray sections = ComputeMergedSections(edges);
if (return_labels) { if (return_labels) {
IdArray labels = MergeMultipleTraversals(tags); IdArray labels = MergeMultipleTraversals(tags);
*rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections}); *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});
} else { } else {
*rv = ConvertNDArrayVectorToPackedFunc({ids, sections}); *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});
} }
}); });
} // namespace traverse } // namespace traverse
} // namespace dgl } // namespace dgl
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
/*! /*!
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* \file ndarray_partition.h * \file ndarray_partition.h
* \brief DGL utilities for working with the partitioned NDArrays * \brief DGL utilities for working with the partitioned NDArrays
*/ */
#ifndef DGL_PARTITION_NDARRAY_PARTITION_H_ #ifndef DGL_PARTITION_NDARRAY_PARTITION_H_
#define DGL_PARTITION_NDARRAY_PARTITION_H_ #define DGL_PARTITION_NDARRAY_PARTITION_H_
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/object.h>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
...@@ -28,9 +28,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -28,9 +28,7 @@ class NDArrayPartition : public runtime::Object {
* @param array_size The first dimension of the partitioned array. * @param array_size The first dimension of the partitioned array.
* @param num_parts The number parts to the array is split into. * @param num_parts The number parts to the array is split into.
*/ */
NDArrayPartition( NDArrayPartition(int64_t array_size, int num_parts);
int64_t array_size,
int num_parts);
virtual ~NDArrayPartition() = default; virtual ~NDArrayPartition() = default;
...@@ -50,8 +48,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -50,8 +48,7 @@ class NDArrayPartition : public runtime::Object {
* @return A pair containing 0) the permutation to re-order the indices by * @return A pair containing 0) the permutation to re-order the indices by
* partition, 1) the number of indices per partition (int64_t). * partition, 1) the number of indices per partition (int64_t).
*/ */
virtual std::pair<IdArray, NDArray> virtual std::pair<IdArray, NDArray> GeneratePermutation(
GeneratePermutation(
IdArray in_idx) const = 0; IdArray in_idx) const = 0;
/** /**
...@@ -62,8 +59,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -62,8 +59,7 @@ class NDArrayPartition : public runtime::Object {
* *
* @return The local indices. * @return The local indices.
*/ */
virtual IdArray MapToLocal( virtual IdArray MapToLocal(IdArray in_idx) const = 0;
IdArray in_idx) const = 0;
/** /**
* @brief Generate the global indices (the numbering unique across all * @brief Generate the global indices (the numbering unique across all
...@@ -74,9 +70,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -74,9 +70,7 @@ class NDArrayPartition : public runtime::Object {
* *
* @return The global indices. * @return The global indices.
*/ */
virtual IdArray MapToGlobal( virtual IdArray MapToGlobal(IdArray in_idx, int part_id) const = 0;
IdArray in_idx,
int part_id) const = 0;
/** /**
* @brief Get the number of rows/items assigned to the given part. * @brief Get the number of rows/items assigned to the given part.
...@@ -85,8 +79,7 @@ class NDArrayPartition : public runtime::Object { ...@@ -85,8 +79,7 @@ class NDArrayPartition : public runtime::Object {
* *
* @return The size. * @return The size.
*/ */
virtual int64_t PartSize( virtual int64_t PartSize(int part_id) const = 0;
int part_id) const = 0;
/** /**
* @brief Get the first dimension of the partitioned array. * @brief Get the first dimension of the partitioned array.
...@@ -119,9 +112,7 @@ DGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition); ...@@ -119,9 +112,7 @@ DGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition);
* @return The partition object. * @return The partition object.
*/ */
NDArrayPartitionRef CreatePartitionRemainderBased( NDArrayPartitionRef CreatePartitionRemainderBased(
int64_t array_size, int64_t array_size, int num_parts);
int num_parts);
/** /**
* @brief Create a new partition object, using the range (exclusive prefix-sum) * @brief Create a new partition object, using the range (exclusive prefix-sum)
...@@ -136,9 +127,7 @@ NDArrayPartitionRef CreatePartitionRemainderBased( ...@@ -136,9 +127,7 @@ NDArrayPartitionRef CreatePartitionRemainderBased(
* @return The partition object. * @return The partition object.
*/ */
NDArrayPartitionRef CreatePartitionRangeBased( NDArrayPartitionRef CreatePartitionRangeBased(
int64_t array_size, int64_t array_size, int num_parts, IdArray range);
int num_parts,
IdArray range);
} // namespace partition } // namespace partition
} // namespace dgl } // namespace dgl
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "socket_pool.h" #include "socket_pool.h"
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include "tcp_socket.h" #include "tcp_socket.h"
#ifdef USE_EPOLL #ifdef USE_EPOLL
...@@ -24,8 +25,8 @@ SocketPool::SocketPool() { ...@@ -24,8 +25,8 @@ SocketPool::SocketPool() {
#endif #endif
} }
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id, void SocketPool::AddSocket(
int events) { std::shared_ptr<TCPSocket> socket, int socket_id, int events) {
int fd = socket->Socket(); int fd = socket->Socket();
tcp_sockets_[fd] = socket; tcp_sockets_[fd] = socket;
socket_ids_[fd] = socket_id; socket_ids_[fd] = socket_id;
...@@ -47,7 +48,7 @@ void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id, ...@@ -47,7 +48,7 @@ void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
#else #else
if (tcp_sockets_.size() > 1) { if (tcp_sockets_.size() > 1) {
LOG(FATAL) << "SocketPool supports only one socket if not use epoll." LOG(FATAL) << "SocketPool supports only one socket if not use epoll."
"Please turn on USE_EPOLL on building"; "Please turn on USE_EPOLL on building";
} }
#endif #endif
} }
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment