Unverified Commit bcd37684 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Replace /*! with /**. (#4823)



* replace

* blabla

* balbla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 619d735d
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file graph/unit_graph.cc * @file graph/unit_graph.cc
* @brief UnitGraph graph implementation * @brief UnitGraph graph implementation
...@@ -161,17 +161,17 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -161,17 +161,17 @@ class UnitGraph::COO : public BaseHeteroGraph {
} }
/*! @brief Pin the adj_: COOMatrix of the COO graph. */ /** @brief Pin the adj_: COOMatrix of the COO graph. */
void PinMemory_() { void PinMemory_() {
adj_.PinMemory_(); adj_.PinMemory_();
} }
/*! @brief Unpin the adj_: COOMatrix of the COO graph. */ /** @brief Unpin the adj_: COOMatrix of the COO graph. */
void UnpinMemory_() { void UnpinMemory_() {
adj_.UnpinMemory_(); adj_.UnpinMemory_();
} }
/*! @brief Record stream for the adj_: COOMatrix of the COO graph. */ /** @brief Record stream for the adj_: COOMatrix of the COO graph. */
void RecordStream(DGLStreamHandle stream) override { void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream); adj_.RecordStream(stream);
} }
...@@ -432,7 +432,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -432,7 +432,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
return adj_; return adj_;
} }
/*! /**
* @brief Determines whether the graph is "hypersparse", i.e. having significantly more * @brief Determines whether the graph is "hypersparse", i.e. having significantly more
* nodes than edges. * nodes than edges.
*/ */
...@@ -457,7 +457,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -457,7 +457,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
private: private:
friend class Serializer; friend class Serializer;
/*! @brief internal adjacency matrix. Data array is empty */ /** @brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_; aten::COOMatrix adj_;
}; };
...@@ -467,7 +467,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -467,7 +467,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
/*! @brief CSR graph */ /** @brief CSR graph */
class UnitGraph::CSR : public BaseHeteroGraph { class UnitGraph::CSR : public BaseHeteroGraph {
public: public:
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
...@@ -571,17 +571,17 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -571,17 +571,17 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
} }
/*! @brief Pin the adj_: CSRMatrix of the CSR graph. */ /** @brief Pin the adj_: CSRMatrix of the CSR graph. */
void PinMemory_() { void PinMemory_() {
adj_.PinMemory_(); adj_.PinMemory_();
} }
/*! @brief Unpin the adj_: CSRMatrix of the CSR graph. */ /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */
void UnpinMemory_() { void UnpinMemory_() {
adj_.UnpinMemory_(); adj_.UnpinMemory_();
} }
/*! @brief Record stream for the adj_: CSRMatrix of the CSR graph. */ /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
void RecordStream(DGLStreamHandle stream) override { void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream); adj_.RecordStream(stream);
} }
...@@ -851,7 +851,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -851,7 +851,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
private: private:
friend class Serializer; friend class Serializer;
/*! @brief internal adjacency matrix. Data array stores edge ids */ /** @brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
}; };
...@@ -1433,7 +1433,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1433,7 +1433,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
return ret; return ret;
} }
/* !\brief Return out csr. If not exist, transpose the other one.*/ /** @brief Return out csr. If not exist, transpose the other one.*/
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & CSR_CODE)) if (!(formats_ & CSR_CODE))
...@@ -1469,7 +1469,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1469,7 +1469,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
return ret; return ret;
} }
/* !\brief Return coo. If not exist, create from csr.*/ /** @brief Return coo. If not exist, create from csr.*/
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & COO_CODE)) if (!(formats_ & COO_CODE))
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file graph/unit_graph.h * @file graph/unit_graph.h
* @brief UnitGraph graph * @brief UnitGraph graph
...@@ -26,7 +26,7 @@ class HeteroGraph; ...@@ -26,7 +26,7 @@ class HeteroGraph;
class UnitGraph; class UnitGraph;
typedef std::shared_ptr<UnitGraph> UnitGraphPtr; typedef std::shared_ptr<UnitGraph> UnitGraphPtr;
/*! /**
* @brief UnitGraph graph * @brief UnitGraph graph
* *
* UnitGraph graph is a special type of heterograph which * UnitGraph graph is a special type of heterograph which
...@@ -164,7 +164,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -164,7 +164,7 @@ class UnitGraph : public BaseHeteroGraph {
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
// creators // creators
/*! @brief Create a graph with no edges */ /** @brief Create a graph with no edges */
static HeteroGraphPtr Empty( static HeteroGraphPtr Empty(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
DGLDataType dtype, DGLContext ctx) { DGLDataType dtype, DGLContext ctx) {
...@@ -173,7 +173,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -173,7 +173,7 @@ class UnitGraph : public BaseHeteroGraph {
return CreateFromCOO(num_vtypes, num_src, num_dst, row, col); return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
} }
/*! @brief Create a graph from COO arrays */ /** @brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, bool row_sorted = false, IdArray row, IdArray col, bool row_sorted = false,
...@@ -183,7 +183,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -183,7 +183,7 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! @brief Create a graph from (out) CSR arrays */ /** @brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
...@@ -193,7 +193,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -193,7 +193,7 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! @brief Create a graph from (in) CSC arrays */ /** @brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
...@@ -203,13 +203,13 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -203,13 +203,13 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! @brief Convert the graph to use the given number of bits for storage */ /** @brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! @brief Copy the data to another context */ /** @brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx);
/*! /**
* @brief Pin the in_csr_, out_scr_ and coo_ of the current graph. * @brief Pin the in_csr_, out_scr_ and coo_ of the current graph.
* @note The graph will be pinned inplace. Behavior depends on the current context, * @note The graph will be pinned inplace. Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
...@@ -219,7 +219,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -219,7 +219,7 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
void PinMemory_() override; void PinMemory_() override;
/*! /**
* @brief Unpin the in_csr_, out_scr_ and coo_ of the current graph. * @brief Unpin the in_csr_, out_scr_ and coo_ of the current graph.
* @note The graph will be unpinned inplace. Behavior depends on the current context, * @note The graph will be unpinned inplace. Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
...@@ -228,13 +228,13 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -228,13 +228,13 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
void UnpinMemory_(); void UnpinMemory_();
/*! /**
* @brief Record stream for this graph. * @brief Record stream for this graph.
* @param stream The stream that is using the graph * @param stream The stream that is using the graph
*/ */
void RecordStream(DGLStreamHandle stream) override; void RecordStream(DGLStreamHandle stream) override;
/*! /**
* @brief Create in-edge CSR format of the unit graph. * @brief Create in-edge CSR format of the unit graph.
* @param inplace if true and the in-edge CSR format does not exist, the created * @param inplace if true and the in-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted. * format will be cached in this object unless the format is restricted.
...@@ -242,7 +242,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -242,7 +242,7 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
CSRPtr GetInCSR(bool inplace = true) const; CSRPtr GetInCSR(bool inplace = true) const;
/*! /**
* @brief Create out-edge CSR format of the unit graph. * @brief Create out-edge CSR format of the unit graph.
* @param inplace if true and the out-edge CSR format does not exist, the created * @param inplace if true and the out-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted. * format will be cached in this object unless the format is restricted.
...@@ -250,7 +250,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -250,7 +250,7 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
CSRPtr GetOutCSR(bool inplace = true) const; CSRPtr GetOutCSR(bool inplace = true) const;
/*! /**
* @brief Create COO format of the unit graph. * @brief Create COO format of the unit graph.
* @param inplace if true and the COO format does not exist, the created * @param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted. * format will be cached in this object unless the format is restricted.
...@@ -258,20 +258,20 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -258,20 +258,20 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
COOPtr GetCOO(bool inplace = true) const; COOPtr GetCOO(bool inplace = true) const;
/*! @return Return the COO matrix form */ /** @return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override; aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;
/*! @return Return the in-edge CSC in the matrix form */ /** @return Return the in-edge CSC in the matrix form */
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override; aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;
/*! @return Return the out-edge CSR in the matrix form */ /** @return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override; aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return SelectFormat(preferred_formats); return SelectFormat(preferred_formats);
} }
/*! /**
* @brief Return the graph in the given format. Perform format conversion if the * @brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist. * requested format does not exist.
* *
...@@ -285,19 +285,19 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -285,19 +285,19 @@ class UnitGraph : public BaseHeteroGraph {
HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override; HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
/*! @return Load UnitGraph from stream, using CSRMatrix*/ /** @return Load UnitGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs); bool Load(dmlc::Stream* fs);
/*! @return Save UnitGraph to stream, using CSRMatrix */ /** @return Save UnitGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
/*! @brief Creat a LineGraph of self */ /** @brief Creat a LineGraph of self */
HeteroGraphPtr LineGraph(bool backtracking) const; HeteroGraphPtr LineGraph(bool backtracking) const;
/*! @return the reversed graph */ /** @return the reversed graph */
UnitGraphPtr Reverse() const; UnitGraphPtr Reverse() const;
/*! @return the simpled (no-multi-edge) graph /** @return the simpled (no-multi-edge) graph
* the count recording the number of duplicated edges from the original graph. * the count recording the number of duplicated edges from the original graph.
* the edge mapping from the edge IDs of original graph to those of the * the edge mapping from the edge IDs of original graph to those of the
* returned graph. * returned graph.
...@@ -319,7 +319,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -319,7 +319,7 @@ class UnitGraph : public BaseHeteroGraph {
// private empty constructor // private empty constructor
UnitGraph() {} UnitGraph() {}
/*! /**
* @brief constructor * @brief constructor
* @param metagraph metagraph * @param metagraph metagraph
* @param in_csr in edge csr * @param in_csr in edge csr
...@@ -329,7 +329,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -329,7 +329,7 @@ class UnitGraph : public BaseHeteroGraph {
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief constructor * @brief constructor
* @param num_vtypes number of vertex types (1 or 2) * @param num_vtypes number of vertex types (1 or 2)
* @param metagraph metagraph * @param metagraph metagraph
...@@ -350,10 +350,10 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -350,10 +350,10 @@ class UnitGraph : public BaseHeteroGraph {
bool has_coo, bool has_coo,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! @return Return any existing format. */ /** @return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
/*! /**
* @brief Determine which format to use with a preference. * @brief Determine which format to use with a preference.
* *
* If the storage of unit graph is "locked", i.e. no conversion is allowed, then * If the storage of unit graph is "locked", i.e. no conversion is allowed, then
...@@ -364,24 +364,24 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -364,24 +364,24 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const; SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const;
/*! @return Whether the graph is hypersparse */ /** @return Whether the graph is hypersparse */
bool IsHypersparse() const; bool IsHypersparse() const;
GraphPtr AsImmutableGraph() const override; GraphPtr AsImmutableGraph() const override;
// Graph stored in different format. We use an on-demand strategy: the format is // Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked. // only materialized if the operation that suitable for it is invoked.
/*! @brief CSR graph that stores reverse edges */ /** @brief CSR graph that stores reverse edges */
CSRPtr in_csr_; CSRPtr in_csr_;
/*! @brief CSR representation */ /** @brief CSR representation */
CSRPtr out_csr_; CSRPtr out_csr_;
/*! @brief COO representation */ /** @brief COO representation */
COOPtr coo_; COOPtr coo_;
/*! /**
* @brief Storage format restriction. * @brief Storage format restriction.
*/ */
dgl_format_code_t formats_; dgl_format_code_t formats_;
/*! @brief which streams have recorded the graph */ /** @brief which streams have recorded the graph */
std::vector<DGLStreamHandle> recorded_streams; std::vector<DGLStreamHandle> recorded_streams;
}; };
......
/*! /**
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* @file ndarray_partition.h * @file ndarray_partition.h
* @brief Operations on partition implemented in CUDA. * @brief Operations on partition implemented in CUDA.
......
/*! /**
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* @file ndarray_partition.cc * @file ndarray_partition.cc
* @brief DGL utilities for working with the partitioned NDArrays * @brief DGL utilities for working with the partitioned NDArrays
......
/*! /**
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* @file ndarray_partition.h * @file ndarray_partition.h
* @brief DGL utilities for working with the partitioned NDArrays * @brief DGL utilities for working with the partitioned NDArrays
......
/*! /**
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* @file ndarray_partition.h * @file ndarray_partition.h
* @brief DGL utilities for working with the partitioned NDArrays * @brief DGL utilities for working with the partitioned NDArrays
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file random/choice.cc * @file random/choice.cc
* @brief Non-uniform discrete sampling implementation * @brief Non-uniform discrete sampling implementation
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file dgl/sample_utils.h * @file dgl/sample_utils.h
* @brief Sampling utilities * @brief Sampling utilities
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
namespace dgl { namespace dgl {
namespace utils { namespace utils {
/*! @brief Base sampler class */ /** @brief Base sampler class */
template <typename Idx> template <typename Idx>
class BaseSampler { class BaseSampler {
public: public:
virtual ~BaseSampler() = default; virtual ~BaseSampler() = default;
/*! @brief Draw one integer sample */ /** @brief Draw one integer sample */
virtual Idx Draw() { virtual Idx Draw() {
LOG(INFO) << "Not implemented yet."; LOG(INFO) << "Not implemented yet.";
return 0; return 0;
...@@ -37,7 +37,7 @@ class BaseSampler { ...@@ -37,7 +37,7 @@ class BaseSampler {
// probability 0. DType could be uint8 in this case, which will give incorrect arithmetic // probability 0. DType could be uint8 in this case, which will give incorrect arithmetic
// results due to overflowing and/or integer division. // results due to overflowing and/or integer division.
/* /**
* AliasSampler is used to sample elements from a given discrete categorical distribution. * AliasSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method) * Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method)
* Sampler building complexity: O(n) * Sampler building complexity: O(n)
...@@ -165,7 +165,7 @@ class AliasSampler: public BaseSampler<Idx> { ...@@ -165,7 +165,7 @@ class AliasSampler: public BaseSampler<Idx> {
}; };
/* /**
* CDFSampler is used to sample elements from a given discrete categorical distribution. * CDFSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: create a cumulative distribution function and conduct binary search for sampling. * Algorithm: create a cumulative distribution function and conduct binary search for sampling.
* Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804 * Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
...@@ -256,7 +256,7 @@ class CDFSampler: public BaseSampler<Idx> { ...@@ -256,7 +256,7 @@ class CDFSampler: public BaseSampler<Idx> {
}; };
/* /**
* TreeSampler is used to sample elements from a given discrete categorical distribution. * TreeSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: create a heap that stores accumulated likelihood of its leaf descendents. * Algorithm: create a heap that stores accumulated likelihood of its leaf descendents.
* Reference: https://blog.smola.org/post/1016514759 * Reference: https://blog.smola.org/post/1016514759
......
/*! /**
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* @file random.cc * @file random.cc
* @brief Random number generator interfaces * @brief Random number generator interfaces
......
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* @file net_type.h * @file net_type.h
* @brief Base communicator for DGL distributed training. * @brief Base communicator for DGL distributed training.
...@@ -14,21 +14,21 @@ namespace dgl { ...@@ -14,21 +14,21 @@ namespace dgl {
namespace rpc { namespace rpc {
struct RPCBase { struct RPCBase {
/*! /**
* @brief Finalize Receiver * @brief Finalize Receiver
* *
* Finalize() is not thread-safe and only one thread can invoke this API. * Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
virtual void Finalize() = 0; virtual void Finalize() = 0;
/*! /**
* @brief Communicator type: 'socket', 'tensorpipe', etc * @brief Communicator type: 'socket', 'tensorpipe', etc
*/ */
virtual const std::string &NetType() const = 0; virtual const std::string &NetType() const = 0;
}; };
struct RPCSender : RPCBase { struct RPCSender : RPCBase {
/*! /**
* @brief Connect to a receiver. * @brief Connect to a receiver.
* *
* When there are multiple receivers to be connected, application will call * When there are multiple receivers to be connected, application will call
...@@ -44,7 +44,7 @@ struct RPCSender : RPCBase { ...@@ -44,7 +44,7 @@ struct RPCSender : RPCBase {
*/ */
virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0; virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0;
/*! /**
* @brief Finalize the action to connect to receivers. Make sure that either * @brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails. * all connections are successfully established or connection fails.
* @return True for success and False for fail * @return True for success and False for fail
...@@ -53,7 +53,7 @@ struct RPCSender : RPCBase { ...@@ -53,7 +53,7 @@ struct RPCSender : RPCBase {
*/ */
virtual bool ConnectReceiverFinalize(const int max_try_times) { return true; } virtual bool ConnectReceiverFinalize(const int max_try_times) { return true; }
/*! /**
* @brief Send RPCMessage to specified Receiver. * @brief Send RPCMessage to specified Receiver.
* @param msg data message * @param msg data message
* @param recv_id receiver's ID * @param recv_id receiver's ID
...@@ -62,7 +62,7 @@ struct RPCSender : RPCBase { ...@@ -62,7 +62,7 @@ struct RPCSender : RPCBase {
}; };
struct RPCReceiver : RPCBase { struct RPCReceiver : RPCBase {
/*! /**
* @brief Wait for all the Senders to connect * @brief Wait for all the Senders to connect
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0' * @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* @param num_sender total number of Senders * @param num_sender total number of Senders
...@@ -74,7 +74,7 @@ struct RPCReceiver : RPCBase { ...@@ -74,7 +74,7 @@ struct RPCReceiver : RPCBase {
virtual bool Wait( virtual bool Wait(
const std::string &addr, int num_sender, bool blocking = true) = 0; const std::string &addr, int num_sender, bool blocking = true) = 0;
/*! /**
* @brief Recv RPCMessage from Sender. Actually removing data from queue. * @brief Recv RPCMessage from Sender. Actually removing data from queue.
* @param msg pointer of RPCmessage * @param msg pointer of RPCmessage
* @param timeout The timeout value in milliseconds. If zero, wait * @param timeout The timeout value in milliseconds. If zero, wait
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file common.cc * @file common.cc
* @brief This file provide basic facilities for string * @brief This file provide basic facilities for string
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file common.h * @file common.h
* @brief This file provide basic facilities for string * @brief This file provide basic facilities for string
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file communicator.h * @file communicator.h
* @brief Communicator for DGL distributed training. * @brief Communicator for DGL distributed training.
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace dgl { namespace dgl {
namespace network { namespace network {
/*! /**
* @brief Network Sender for DGL distributed training. * @brief Network Sender for DGL distributed training.
* *
* Sender is an abstract class that defines a set of APIs for sending binary * Sender is an abstract class that defines a set of APIs for sending binary
...@@ -27,7 +27,7 @@ namespace network { ...@@ -27,7 +27,7 @@ namespace network {
*/ */
class Sender : public rpc::RPCSender { class Sender : public rpc::RPCSender {
public: public:
/*! /**
* @brief Sender constructor * @brief Sender constructor
* @param queue_size size (bytes) of message queue. * @param queue_size size (bytes) of message queue.
* @param max_thread_count size of thread pool. 0 for no limit * @param max_thread_count size of thread pool. 0 for no limit
...@@ -42,7 +42,7 @@ class Sender : public rpc::RPCSender { ...@@ -42,7 +42,7 @@ class Sender : public rpc::RPCSender {
virtual ~Sender() {} virtual ~Sender() {}
/*! /**
* @brief Send data to specified Receiver. * @brief Send data to specified Receiver.
* @param msg data message * @param msg data message
* @param recv_id receiver's ID * @param recv_id receiver's ID
...@@ -58,17 +58,17 @@ class Sender : public rpc::RPCSender { ...@@ -58,17 +58,17 @@ class Sender : public rpc::RPCSender {
virtual STATUS Send(Message msg, int recv_id) = 0; virtual STATUS Send(Message msg, int recv_id) = 0;
protected: protected:
/*! /**
* @brief Size of message queue * @brief Size of message queue
*/ */
int64_t queue_size_; int64_t queue_size_;
/*! /**
* @brief Size of thread pool. 0 for no limit * @brief Size of thread pool. 0 for no limit
*/ */
int max_thread_count_; int max_thread_count_;
}; };
/*! /**
* @brief Network Receiver for DGL distributed training. * @brief Network Receiver for DGL distributed training.
* *
* Receiver is an abstract class that defines a set of APIs for receiving binary * Receiver is an abstract class that defines a set of APIs for receiving binary
...@@ -79,7 +79,7 @@ class Sender : public rpc::RPCSender { ...@@ -79,7 +79,7 @@ class Sender : public rpc::RPCSender {
*/ */
class Receiver : public rpc::RPCReceiver { class Receiver : public rpc::RPCReceiver {
public: public:
/*! /**
* @brief Receiver constructor * @brief Receiver constructor
* @param queue_size size of message queue. * @param queue_size size of message queue.
* @param max_thread_count size of thread pool. 0 for no limit * @param max_thread_count size of thread pool. 0 for no limit
...@@ -96,7 +96,7 @@ class Receiver : public rpc::RPCReceiver { ...@@ -96,7 +96,7 @@ class Receiver : public rpc::RPCReceiver {
virtual ~Receiver() {} virtual ~Receiver() {}
/*! /**
* @brief Recv data from Sender * @brief Recv data from Sender
* @param msg pointer of data message * @param msg pointer of data message
* @param send_id which sender current msg comes from * @param send_id which sender current msg comes from
...@@ -110,7 +110,7 @@ class Receiver : public rpc::RPCReceiver { ...@@ -110,7 +110,7 @@ class Receiver : public rpc::RPCReceiver {
*/ */
virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0; virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;
/*! /**
* @brief Recv data from a specified Sender * @brief Recv data from a specified Sender
* @param msg pointer of data message * @param msg pointer of data message
* @param send_id sender's ID * @param send_id sender's ID
...@@ -125,11 +125,11 @@ class Receiver : public rpc::RPCReceiver { ...@@ -125,11 +125,11 @@ class Receiver : public rpc::RPCReceiver {
virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0; virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;
protected: protected:
/*! /**
* @brief Size of message queue * @brief Size of message queue
*/ */
int64_t queue_size_; int64_t queue_size_;
/*! /**
* @brief Size of thread pool. 0 for no limit * @brief Size of thread pool. 0 for no limit
*/ */
int max_thread_count_; int max_thread_count_;
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file msg_queue.cc * @file msg_queue.cc
* @brief Message queue for DGL distributed training. * @brief Message queue for DGL distributed training.
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file msg_queue.h * @file msg_queue.h
* @brief Message queue for DGL distributed training. * @brief Message queue for DGL distributed training.
...@@ -22,7 +22,7 @@ namespace network { ...@@ -22,7 +22,7 @@ namespace network {
typedef int STATUS; typedef int STATUS;
/*! /**
* @brief Status code of message queue * @brief Status code of message queue
*/ */
#define ADD_SUCCESS 3400 // Add message successfully #define ADD_SUCCESS 3400 // Add message successfully
...@@ -33,45 +33,45 @@ typedef int STATUS; ...@@ -33,45 +33,45 @@ typedef int STATUS;
#define REMOVE_SUCCESS 3405 // Remove message successfully #define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty #define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
/*! /**
* @brief Message used by network communicator and message queue. * @brief Message used by network communicator and message queue.
*/ */
struct Message { struct Message {
/*! /**
* @brief Constructor * @brief Constructor
*/ */
Message() {} Message() {}
/*! /**
* @brief Constructor * @brief Constructor
*/ */
Message(char* data_ptr, int64_t data_size) Message(char* data_ptr, int64_t data_size)
: data(data_ptr), size(data_size) {} : data(data_ptr), size(data_size) {}
/*! /**
* @brief message data * @brief message data
*/ */
char* data; char* data;
/*! /**
* @brief message size in bytes * @brief message size in bytes
*/ */
int64_t size; int64_t size;
/*! /**
* @brief message receiver id * @brief message receiver id
*/ */
int receiver_id = -1; int receiver_id = -1;
/*! /**
* @brief user-defined deallocator, which can be nullptr * @brief user-defined deallocator, which can be nullptr
*/ */
std::function<void(Message*)> deallocator = nullptr; std::function<void(Message*)> deallocator = nullptr;
}; };
/*! /**
* @brief Free memory buffer of message * @brief Free memory buffer of message
*/ */
inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; } inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
/*! /**
* @brief Message Queue for network communication. * @brief Message Queue for network communication.
* *
* MessageQueue is FIFO queue that adopts producer/consumer model for data * MessageQueue is FIFO queue that adopts producer/consumer model for data
...@@ -89,7 +89,7 @@ inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; } ...@@ -89,7 +89,7 @@ inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
*/ */
class MessageQueue { class MessageQueue {
public: public:
/*! /**
* @brief MessageQueue constructor * @brief MessageQueue constructor
* @param queue_size size (bytes) of message queue * @param queue_size size (bytes) of message queue
* @param num_producers number of producers, use 1 by default * @param num_producers number of producers, use 1 by default
...@@ -97,12 +97,12 @@ class MessageQueue { ...@@ -97,12 +97,12 @@ class MessageQueue {
explicit MessageQueue( explicit MessageQueue(
int64_t queue_size /* in bytes */, int num_producers = 1); int64_t queue_size /* in bytes */, int num_producers = 1);
/*! /**
* @brief MessageQueue deconstructor * @brief MessageQueue deconstructor
*/ */
~MessageQueue() {} ~MessageQueue() {}
/*! /**
* @brief Add message to the queue * @brief Add message to the queue
* @param msg data message * @param msg data message
* @param is_blocking Blocking if cannot add, else return * @param is_blocking Blocking if cannot add, else return
...@@ -110,7 +110,7 @@ class MessageQueue { ...@@ -110,7 +110,7 @@ class MessageQueue {
*/ */
STATUS Add(Message msg, bool is_blocking = true); STATUS Add(Message msg, bool is_blocking = true);
/*! /**
* @brief Remove message from the queue * @brief Remove message from the queue
* @param msg pointer of data msg * @param msg pointer of data msg
* @param is_blocking Blocking if cannot remove, else return * @param is_blocking Blocking if cannot remove, else return
...@@ -118,64 +118,64 @@ class MessageQueue { ...@@ -118,64 +118,64 @@ class MessageQueue {
*/ */
STATUS Remove(Message* msg, bool is_blocking = true); STATUS Remove(Message* msg, bool is_blocking = true);
/*! /**
* @brief Signal that producer producer_id will no longer produce anything * @brief Signal that producer producer_id will no longer produce anything
* @param producer_id An integer uniquely to identify a producer thread * @param producer_id An integer uniquely to identify a producer thread
*/ */
void SignalFinished(int producer_id); void SignalFinished(int producer_id);
/*! /**
* @return true if queue is empty. * @return true if queue is empty.
*/ */
bool Empty() const; bool Empty() const;
/*! /**
* @return true if queue is empty and all num_producers have signaled. * @return true if queue is empty and all num_producers have signaled.
*/ */
bool EmptyAndNoMoreAdd() const; bool EmptyAndNoMoreAdd() const;
protected: protected:
/*! /**
* @brief message queue * @brief message queue
*/ */
std::queue<Message> queue_; std::queue<Message> queue_;
/*! /**
* @brief Size of the queue in bytes * @brief Size of the queue in bytes
*/ */
int64_t queue_size_; int64_t queue_size_;
/*! /**
* @brief Free size of the queue * @brief Free size of the queue
*/ */
int64_t free_size_; int64_t free_size_;
/*! /**
* @brief Used to check all producers will no longer produce anything * @brief Used to check all producers will no longer produce anything
*/ */
size_t num_producers_; size_t num_producers_;
/*! /**
* @brief Store finished producer id * @brief Store finished producer id
*/ */
std::set<int /* producer_id */> finished_producers_; std::set<int /* producer_id */> finished_producers_;
/*! /**
* @brief Condition when consumer should wait * @brief Condition when consumer should wait
*/ */
std::condition_variable cond_not_full_; std::condition_variable cond_not_full_;
/*! /**
* @brief Condition when producer should wait * @brief Condition when producer should wait
*/ */
std::condition_variable cond_not_empty_; std::condition_variable cond_not_empty_;
/*! /**
* @brief Signal for exit wait * @brief Signal for exit wait
*/ */
std::atomic<bool> exit_flag_{false}; std::atomic<bool> exit_flag_{false};
/*! /**
* @brief Protect all above data and conditions * @brief Protect all above data and conditions
*/ */
mutable std::mutex mutex_; mutable std::mutex mutex_;
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file communicator.cc * @file communicator.cc
* @brief SocketCommunicator for DGL distributed training. * @brief SocketCommunicator for DGL distributed training.
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file communicator.h * @file communicator.h
* @brief SocketCommunicator for DGL distributed training. * @brief SocketCommunicator for DGL distributed training.
...@@ -25,7 +25,7 @@ static constexpr int kTimeOut = ...@@ -25,7 +25,7 @@ static constexpr int kTimeOut =
10 * 60; // 10 minutes (in seconds) for socket timeout 10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024 static constexpr int kMaxConnection = 1024; // maximal connection: 1024
/*! /**
* @breif Networking address * @breif Networking address
*/ */
struct IPAddr { struct IPAddr {
...@@ -33,14 +33,14 @@ struct IPAddr { ...@@ -33,14 +33,14 @@ struct IPAddr {
int port; int port;
}; };
/*! /**
* @brief SocketSender for DGL distributed training. * @brief SocketSender for DGL distributed training.
* *
* SocketSender is the communicator implemented by tcp socket. * SocketSender is the communicator implemented by tcp socket.
*/ */
class SocketSender : public Sender { class SocketSender : public Sender {
public: public:
/*! /**
* @brief Sender constructor * @brief Sender constructor
* @param queue_size size of message queue * @param queue_size size of message queue
* @param max_thread_count size of thread pool. 0 for no limit * @param max_thread_count size of thread pool. 0 for no limit
...@@ -48,7 +48,7 @@ class SocketSender : public Sender { ...@@ -48,7 +48,7 @@ class SocketSender : public Sender {
SocketSender(int64_t queue_size, int max_thread_count) SocketSender(int64_t queue_size, int max_thread_count)
: Sender(queue_size, max_thread_count) {} : Sender(queue_size, max_thread_count) {}
/*! /**
* @brief Connect to a receiver. * @brief Connect to a receiver.
* *
* When there are multiple receivers to be connected, application will call * When there are multiple receivers to be connected, application will call
...@@ -64,7 +64,7 @@ class SocketSender : public Sender { ...@@ -64,7 +64,7 @@ class SocketSender : public Sender {
*/ */
bool ConnectReceiver(const std::string& addr, int recv_id) override; bool ConnectReceiver(const std::string& addr, int recv_id) override;
/*! /**
* @brief Finalize the action to connect to receivers. Make sure that either * @brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails. * all connections are successfully established or connection fails.
* @return True for success and False for fail * @return True for success and False for fail
...@@ -73,19 +73,19 @@ class SocketSender : public Sender { ...@@ -73,19 +73,19 @@ class SocketSender : public Sender {
*/ */
bool ConnectReceiverFinalize(const int max_try_times) override; bool ConnectReceiverFinalize(const int max_try_times) override;
/*! /**
* @brief Send RPCMessage to specified Receiver. * @brief Send RPCMessage to specified Receiver.
* @param msg data message * @param msg data message
* @param recv_id receiver's ID * @param recv_id receiver's ID
*/ */
void Send(const rpc::RPCMessage& msg, int recv_id) override; void Send(const rpc::RPCMessage& msg, int recv_id) override;
/*! /**
* @brief Finalize TPSender * @brief Finalize TPSender
*/ */
void Finalize() override; void Finalize() override;
/*! /**
* @brief Communicator type: 'socket' * @brief Communicator type: 'socket'
*/ */
const std::string& NetType() const override { const std::string& NetType() const override {
...@@ -93,7 +93,7 @@ class SocketSender : public Sender { ...@@ -93,7 +93,7 @@ class SocketSender : public Sender {
return net_type; return net_type;
} }
/*! /**
* @brief Send data to specified Receiver. Actually pushing message to message * @brief Send data to specified Receiver. Actually pushing message to message
* queue. * queue.
* @param msg data message. * @param msg data message.
...@@ -110,29 +110,29 @@ class SocketSender : public Sender { ...@@ -110,29 +110,29 @@ class SocketSender : public Sender {
STATUS Send(Message msg, int recv_id) override; STATUS Send(Message msg, int recv_id) override;
private: private:
/*! /**
* @brief socket for each connection of receiver * @brief socket for each connection of receiver
*/ */
std::vector< std::vector<
std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>> std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>
sockets_; sockets_;
/*! /**
* @brief receivers' address * @brief receivers' address
*/ */
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_; std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*! /**
* @brief message queue for each thread * @brief message queue for each thread
*/ */
std::vector<std::shared_ptr<MessageQueue>> msg_queue_; std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
/*! /**
* @brief Independent thread * @brief Independent thread
*/ */
std::vector<std::shared_ptr<std::thread>> threads_; std::vector<std::shared_ptr<std::thread>> threads_;
/*! /**
* @brief Send-loop for each thread * @brief Send-loop for each thread
* @param sockets TCPSockets for current thread * @param sockets TCPSockets for current thread
* @param queue message_queue for current thread * @param queue message_queue for current thread
...@@ -147,14 +147,14 @@ class SocketSender : public Sender { ...@@ -147,14 +147,14 @@ class SocketSender : public Sender {
std::shared_ptr<MessageQueue> queue); std::shared_ptr<MessageQueue> queue);
}; };
/*! /**
* @brief SocketReceiver for DGL distributed training. * @brief SocketReceiver for DGL distributed training.
* *
* SocketReceiver is the communicator implemented by tcp socket. * SocketReceiver is the communicator implemented by tcp socket.
*/ */
class SocketReceiver : public Receiver { class SocketReceiver : public Receiver {
public: public:
/*! /**
* @brief Receiver constructor * @brief Receiver constructor
* @param queue_size size of message queue. * @param queue_size size of message queue.
* @param max_thread_count size of thread pool. 0 for no limit * @param max_thread_count size of thread pool. 0 for no limit
...@@ -162,7 +162,7 @@ class SocketReceiver : public Receiver { ...@@ -162,7 +162,7 @@ class SocketReceiver : public Receiver {
SocketReceiver(int64_t queue_size, int max_thread_count) SocketReceiver(int64_t queue_size, int max_thread_count)
: Receiver(queue_size, max_thread_count) {} : Receiver(queue_size, max_thread_count) {}
/*! /**
* @brief Wait for all the Senders to connect * @brief Wait for all the Senders to connect
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0' * @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* @param num_sender total number of Senders * @param num_sender total number of Senders
...@@ -174,7 +174,7 @@ class SocketReceiver : public Receiver { ...@@ -174,7 +174,7 @@ class SocketReceiver : public Receiver {
bool Wait( bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override; const std::string& addr, int num_sender, bool blocking = true) override;
/*! /**
* @brief Recv RPCMessage from Sender. Actually removing data from queue. * @brief Recv RPCMessage from Sender. Actually removing data from queue.
* @param msg pointer of RPCmessage * @param msg pointer of RPCmessage
* @param timeout The timeout value in milliseconds. If zero, wait * @param timeout The timeout value in milliseconds. If zero, wait
...@@ -183,7 +183,7 @@ class SocketReceiver : public Receiver { ...@@ -183,7 +183,7 @@ class SocketReceiver : public Receiver {
*/ */
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override; rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
/*! /**
* @brief Recv data from Sender. Actually removing data from msg_queue. * @brief Recv data from Sender. Actually removing data from msg_queue.
* @param msg pointer of data message * @param msg pointer of data message
* @param send_id which sender current msg comes from * @param send_id which sender current msg comes from
...@@ -197,7 +197,7 @@ class SocketReceiver : public Receiver { ...@@ -197,7 +197,7 @@ class SocketReceiver : public Receiver {
*/ */
STATUS Recv(Message* msg, int* send_id, int timeout = 0) override; STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
/*! /**
* @brief Recv data from a specified Sender. Actually removing data from * @brief Recv data from a specified Sender. Actually removing data from
* msg_queue. * msg_queue.
* @param msg pointer of data message. * @param msg pointer of data message.
...@@ -212,14 +212,14 @@ class SocketReceiver : public Receiver { ...@@ -212,14 +212,14 @@ class SocketReceiver : public Receiver {
*/ */
STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override; STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
/*! /**
* @brief Finalize SocketReceiver * @brief Finalize SocketReceiver
* *
* Finalize() is not thread-safe and only one thread can invoke this API. * Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
void Finalize() override; void Finalize() override;
/*! /**
* @brief Communicator type: 'socket' * @brief Communicator type: 'socket'
*/ */
const std::string& NetType() const override { const std::string& NetType() const override {
...@@ -233,24 +233,24 @@ class SocketReceiver : public Receiver { ...@@ -233,24 +233,24 @@ class SocketReceiver : public Receiver {
int64_t received_bytes = 0; int64_t received_bytes = 0;
char* buffer = nullptr; char* buffer = nullptr;
}; };
/*! /**
* @brief number of sender * @brief number of sender
*/ */
int num_sender_; int num_sender_;
/*! /**
* @brief server socket for listening connections * @brief server socket for listening connections
*/ */
TCPSocket* server_socket_; TCPSocket* server_socket_;
/*! /**
* @brief socket for each client connections * @brief socket for each client connections
*/ */
std::vector<std::unordered_map< std::vector<std::unordered_map<
int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>>> int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>>>
sockets_; sockets_;
/*! /**
* @brief Message queue for each socket connection * @brief Message queue for each socket connection
*/ */
std::unordered_map< std::unordered_map<
...@@ -258,18 +258,18 @@ class SocketReceiver : public Receiver { ...@@ -258,18 +258,18 @@ class SocketReceiver : public Receiver {
msg_queue_; msg_queue_;
std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_; std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
/*! /**
* @brief Independent thead * @brief Independent thead
*/ */
std::vector<std::shared_ptr<std::thread>> threads_; std::vector<std::shared_ptr<std::thread>> threads_;
/*! /**
* @brief queue_sem_ semphore to indicate number of messages in multiple * @brief queue_sem_ semphore to indicate number of messages in multiple
* message queues to prevent busy wait of Recv * message queues to prevent busy wait of Recv
*/ */
runtime::Semaphore queue_sem_; runtime::Semaphore queue_sem_;
/*! /**
* @brief Recv-loop for each thread * @brief Recv-loop for each thread
* @param sockets client sockets of current thread * @param sockets client sockets of current thread
* @param queue message queues of current thread * @param queue message queues of current thread
......
/*! /**
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* @file socket_pool.cc * @file socket_pool.cc
* @brief Socket pool of nonblocking sockets for DGL distributed training. * @brief Socket pool of nonblocking sockets for DGL distributed training.
......
/*! /**
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* @file socket_pool.h * @file socket_pool.h
* @brief Socket pool of nonblocking sockets for DGL distributed training. * @brief Socket pool of nonblocking sockets for DGL distributed training.
...@@ -15,7 +15,7 @@ namespace network { ...@@ -15,7 +15,7 @@ namespace network {
class TCPSocket; class TCPSocket;
/*! /**
* @brief SocketPool maintains a group of nonblocking sockets, and can provide * @brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets. * active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification * Currently SocketPool is based on epoll, a scalable I/O event notification
...@@ -23,20 +23,20 @@ class TCPSocket; ...@@ -23,20 +23,20 @@ class TCPSocket;
*/ */
class SocketPool { class SocketPool {
public: public:
/*! /**
* @brief socket mode read/receive * @brief socket mode read/receive
*/ */
static const int READ = 1; static const int READ = 1;
/*! /**
* @brief socket mode write/send * @brief socket mode write/send
*/ */
static const int WRITE = 2; static const int WRITE = 2;
/*! /**
* @brief SocketPool constructor * @brief SocketPool constructor
*/ */
SocketPool(); SocketPool();
/*! /**
* @brief Add a socket to SocketPool * @brief Add a socket to SocketPool
* @param socket tcp socket to add * @param socket tcp socket to add
* @param socket_id receiver/sender id of the socket * @param socket_id receiver/sender id of the socket
...@@ -45,19 +45,19 @@ class SocketPool { ...@@ -45,19 +45,19 @@ class SocketPool {
void AddSocket( void AddSocket(
std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ); std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);
/*! /**
* @brief Remove socket from SocketPool * @brief Remove socket from SocketPool
* @param socket tcp socket to remove * @param socket tcp socket to remove
* @return number of remaing sockets in the pool * @return number of remaing sockets in the pool
*/ */
size_t RemoveSocket(std::shared_ptr<TCPSocket> socket); size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);
/*! /**
* @brief SocketPool destructor * @brief SocketPool destructor
*/ */
~SocketPool(); ~SocketPool();
/*! /**
* @brief Get current active socket. This is a blocking method * @brief Get current active socket. This is a blocking method
* @param socket_id output parameter of the socket_id of active socket * @param socket_id output parameter of the socket_id of active socket
* @return active TCPSocket * @return active TCPSocket
...@@ -65,27 +65,27 @@ class SocketPool { ...@@ -65,27 +65,27 @@ class SocketPool {
std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id); std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);
private: private:
/*! /**
* @brief Wait for event notification * @brief Wait for event notification
*/ */
void Wait(); void Wait();
/*! /**
* @brief map from fd to TCPSocket * @brief map from fd to TCPSocket
*/ */
std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_; std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;
/*! /**
* @brief map from fd to socket_id * @brief map from fd to socket_id
*/ */
std::unordered_map<int, int> socket_ids_; std::unordered_map<int, int> socket_ids_;
/*! /**
* @brief fd for epoll base * @brief fd for epoll base
*/ */
int epfd_; int epfd_;
/*! /**
* @brief queue for current active fds * @brief queue for current active fds
*/ */
std::queue<int> pending_fds_; std::queue<int> pending_fds_;
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file tcp_socket.cc * @file tcp_socket.cc
* @brief TCP socket for DGL distributed training. * @brief TCP socket for DGL distributed training.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment