Unverified Commit bcd37684 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Replace /*! with /**. (#4823)



* replace

* blabla

* balbla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 619d735d
/*!
/**
* Copyright (c) 2019 by Contributors
* @file graph/unit_graph.cc
* @brief UnitGraph graph implementation
......@@ -161,17 +161,17 @@ class UnitGraph::COO : public BaseHeteroGraph {
}
/*! @brief Pin the adj_: COOMatrix of the COO graph. */
/** @brief Pin the adj_: COOMatrix of the COO graph. */
void PinMemory_() {
adj_.PinMemory_();
}
/*! @brief Unpin the adj_: COOMatrix of the COO graph. */
/** @brief Unpin the adj_: COOMatrix of the COO graph. */
void UnpinMemory_() {
adj_.UnpinMemory_();
}
/*! @brief Record stream for the adj_: COOMatrix of the COO graph. */
/** @brief Record stream for the adj_: COOMatrix of the COO graph. */
void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream);
}
......@@ -432,7 +432,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
return adj_;
}
/*!
/**
* @brief Determines whether the graph is "hypersparse", i.e. having significantly more
* nodes than edges.
*/
......@@ -457,7 +457,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
private:
friend class Serializer;
/*! @brief internal adjacency matrix. Data array is empty */
/** @brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_;
};
......@@ -467,7 +467,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
//
//////////////////////////////////////////////////////////
/*! @brief CSR graph */
/** @brief CSR graph */
class UnitGraph::CSR : public BaseHeteroGraph {
public:
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
......@@ -571,17 +571,17 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
}
/*! @brief Pin the adj_: CSRMatrix of the CSR graph. */
/** @brief Pin the adj_: CSRMatrix of the CSR graph. */
void PinMemory_() {
adj_.PinMemory_();
}
/*! @brief Unpin the adj_: CSRMatrix of the CSR graph. */
/** @brief Unpin the adj_: CSRMatrix of the CSR graph. */
void UnpinMemory_() {
adj_.UnpinMemory_();
}
/*! @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
/** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream);
}
......@@ -851,7 +851,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
private:
friend class Serializer;
/*! @brief internal adjacency matrix. Data array stores edge ids */
/** @brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_;
};
......@@ -1433,7 +1433,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
return ret;
}
/* !\brief Return out csr. If not exist, transpose the other one.*/
/** @brief Return out csr. If not exist, transpose the other one.*/
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace)
if (!(formats_ & CSR_CODE))
......@@ -1469,7 +1469,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
return ret;
}
/* !\brief Return coo. If not exist, create from csr.*/
/** @brief Return coo. If not exist, create from csr.*/
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace)
if (!(formats_ & COO_CODE))
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file graph/unit_graph.h
* @brief UnitGraph graph
......@@ -26,7 +26,7 @@ class HeteroGraph;
class UnitGraph;
typedef std::shared_ptr<UnitGraph> UnitGraphPtr;
/*!
/**
* @brief UnitGraph graph
*
* UnitGraph graph is a special type of heterograph which
......@@ -164,7 +164,7 @@ class UnitGraph : public BaseHeteroGraph {
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
// creators
/*! @brief Create a graph with no edges */
/** @brief Create a graph with no edges */
static HeteroGraphPtr Empty(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
DGLDataType dtype, DGLContext ctx) {
......@@ -173,7 +173,7 @@ class UnitGraph : public BaseHeteroGraph {
return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
}
/*! @brief Create a graph from COO arrays */
/** @brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, bool row_sorted = false,
......@@ -183,7 +183,7 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = ALL_CODE);
/*! @brief Create a graph from (out) CSR arrays */
/** @brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
......@@ -193,7 +193,7 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE);
/*! @brief Create a graph from (in) CSC arrays */
/** @brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
......@@ -203,13 +203,13 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE);
/*! @brief Convert the graph to use the given number of bits for storage */
/** @brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/*! @brief Copy the data to another context */
/** @brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx);
/*!
/**
* @brief Pin the in_csr_, out_scr_ and coo_ of the current graph.
* @note The graph will be pinned inplace. Behavior depends on the current context,
* kDGLCPU: will be pinned;
......@@ -219,7 +219,7 @@ class UnitGraph : public BaseHeteroGraph {
*/
void PinMemory_() override;
/*!
/**
* @brief Unpin the in_csr_, out_scr_ and coo_ of the current graph.
* @note The graph will be unpinned inplace. Behavior depends on the current context,
* IsPinned: will be unpinned;
......@@ -228,13 +228,13 @@ class UnitGraph : public BaseHeteroGraph {
*/
void UnpinMemory_();
/*!
/**
* @brief Record stream for this graph.
* @param stream The stream that is using the graph
*/
void RecordStream(DGLStreamHandle stream) override;
/*!
/**
* @brief Create in-edge CSR format of the unit graph.
* @param inplace if true and the in-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
......@@ -242,7 +242,7 @@ class UnitGraph : public BaseHeteroGraph {
*/
CSRPtr GetInCSR(bool inplace = true) const;
/*!
/**
* @brief Create out-edge CSR format of the unit graph.
* @param inplace if true and the out-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
......@@ -250,7 +250,7 @@ class UnitGraph : public BaseHeteroGraph {
*/
CSRPtr GetOutCSR(bool inplace = true) const;
/*!
/**
* @brief Create COO format of the unit graph.
* @param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted.
......@@ -258,20 +258,20 @@ class UnitGraph : public BaseHeteroGraph {
*/
COOPtr GetCOO(bool inplace = true) const;
/*! @return Return the COO matrix form */
/** @return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;
/*! @return Return the in-edge CSC in the matrix form */
/** @return Return the in-edge CSC in the matrix form */
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;
/*! @return Return the out-edge CSR in the matrix form */
/** @return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return SelectFormat(preferred_formats);
}
/*!
/**
* @brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist.
*
......@@ -285,19 +285,19 @@ class UnitGraph : public BaseHeteroGraph {
HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
/*! @return Load UnitGraph from stream, using CSRMatrix*/
/** @return Load UnitGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);
/*! @return Save UnitGraph to stream, using CSRMatrix */
/** @return Save UnitGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;
/*! @brief Creat a LineGraph of self */
/** @brief Creat a LineGraph of self */
HeteroGraphPtr LineGraph(bool backtracking) const;
/*! @return the reversed graph */
/** @return the reversed graph */
UnitGraphPtr Reverse() const;
/*! @return the simpled (no-multi-edge) graph
/** @return the simpled (no-multi-edge) graph
* the count recording the number of duplicated edges from the original graph.
* the edge mapping from the edge IDs of original graph to those of the
* returned graph.
......@@ -319,7 +319,7 @@ class UnitGraph : public BaseHeteroGraph {
// private empty constructor
UnitGraph() {}
/*!
/**
* @brief constructor
* @param metagraph metagraph
* @param in_csr in edge csr
......@@ -329,7 +329,7 @@ class UnitGraph : public BaseHeteroGraph {
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats = ALL_CODE);
/*!
/**
* @brief constructor
* @param num_vtypes number of vertex types (1 or 2)
* @param metagraph metagraph
......@@ -350,10 +350,10 @@ class UnitGraph : public BaseHeteroGraph {
bool has_coo,
dgl_format_code_t formats = ALL_CODE);
/*! @return Return any existing format. */
/** @return Return any existing format. */
HeteroGraphPtr GetAny() const;
/*!
/**
* @brief Determine which format to use with a preference.
*
* If the storage of unit graph is "locked", i.e. no conversion is allowed, then
......@@ -364,24 +364,24 @@ class UnitGraph : public BaseHeteroGraph {
*/
SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const;
/*! @return Whether the graph is hypersparse */
/** @return Whether the graph is hypersparse */
bool IsHypersparse() const;
GraphPtr AsImmutableGraph() const override;
// Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked.
/*! @brief CSR graph that stores reverse edges */
/** @brief CSR graph that stores reverse edges */
CSRPtr in_csr_;
/*! @brief CSR representation */
/** @brief CSR representation */
CSRPtr out_csr_;
/*! @brief COO representation */
/** @brief COO representation */
COOPtr coo_;
/*!
/**
* @brief Storage format restriction.
*/
dgl_format_code_t formats_;
/*! @brief which streams have recorded the graph */
/** @brief which streams have recorded the graph */
std::vector<DGLStreamHandle> recorded_streams;
};
......
/*!
/**
* Copyright (c) 2021 by Contributors
* @file ndarray_partition.h
* @brief Operations on partition implemented in CUDA.
......
/*!
/**
* Copyright (c) 2021 by Contributors
* @file ndarray_partition.cc
* @brief DGL utilities for working with the partitioned NDArrays
......
/*!
/**
* Copyright (c) 2021 by Contributors
* @file ndarray_partition.h
* @brief DGL utilities for working with the partitioned NDArrays
......
/*!
/**
* Copyright (c) 2021 by Contributors
* @file ndarray_partition.h
* @brief DGL utilities for working with the partitioned NDArrays
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file random/choice.cc
* @brief Non-uniform discrete sampling implementation
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file dgl/sample_utils.h
* @brief Sampling utilities
......@@ -20,12 +20,12 @@
namespace dgl {
namespace utils {
/*! @brief Base sampler class */
/** @brief Base sampler class */
template <typename Idx>
class BaseSampler {
public:
virtual ~BaseSampler() = default;
/*! @brief Draw one integer sample */
/** @brief Draw one integer sample */
virtual Idx Draw() {
LOG(INFO) << "Not implemented yet.";
return 0;
......@@ -37,7 +37,7 @@ class BaseSampler {
// probability 0. DType could be uint8 in this case, which will give incorrect arithmetic
// results due to overflowing and/or integer division.
/*
/**
* AliasSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method)
* Sampler building complexity: O(n)
......@@ -165,7 +165,7 @@ class AliasSampler: public BaseSampler<Idx> {
};
/*
/**
* CDFSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: create a cumulative distribution function and conduct binary search for sampling.
* Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
......@@ -256,7 +256,7 @@ class CDFSampler: public BaseSampler<Idx> {
};
/*
/**
* TreeSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: create a heap that stores accumulated likelihood of its leaf descendents.
* Reference: https://blog.smola.org/post/1016514759
......
/*!
/**
* Copyright (c) 2017 by Contributors
* @file random.cc
* @brief Random number generator interfaces
......
/*!
/**
* Copyright (c) 2022 by Contributors
* @file net_type.h
* @brief Base communicator for DGL distributed training.
......@@ -14,21 +14,21 @@ namespace dgl {
namespace rpc {
struct RPCBase {
/*!
/**
* @brief Finalize Receiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/*!
/**
* @brief Communicator type: 'socket', 'tensorpipe', etc
*/
virtual const std::string &NetType() const = 0;
};
struct RPCSender : RPCBase {
/*!
/**
* @brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call
......@@ -44,7 +44,7 @@ struct RPCSender : RPCBase {
*/
virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0;
/*!
/**
* @brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails.
* @return True for success and False for fail
......@@ -53,7 +53,7 @@ struct RPCSender : RPCBase {
*/
virtual bool ConnectReceiverFinalize(const int max_try_times) { return true; }
/*!
/**
* @brief Send RPCMessage to specified Receiver.
* @param msg data message
* @param recv_id receiver's ID
......@@ -62,7 +62,7 @@ struct RPCSender : RPCBase {
};
struct RPCReceiver : RPCBase {
/*!
/**
* @brief Wait for all the Senders to connect
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* @param num_sender total number of Senders
......@@ -74,7 +74,7 @@ struct RPCReceiver : RPCBase {
virtual bool Wait(
const std::string &addr, int num_sender, bool blocking = true) = 0;
/*!
/**
* @brief Recv RPCMessage from Sender. Actually removing data from queue.
* @param msg pointer of RPCmessage
* @param timeout The timeout value in milliseconds. If zero, wait
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file common.cc
* @brief This file provide basic facilities for string
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file common.h
* @brief This file provide basic facilities for string
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file communicator.h
* @brief Communicator for DGL distributed training.
......@@ -16,7 +16,7 @@
namespace dgl {
namespace network {
/*!
/**
* @brief Network Sender for DGL distributed training.
*
* Sender is an abstract class that defines a set of APIs for sending binary
......@@ -27,7 +27,7 @@ namespace network {
*/
class Sender : public rpc::RPCSender {
public:
/*!
/**
* @brief Sender constructor
* @param queue_size size (bytes) of message queue.
* @param max_thread_count size of thread pool. 0 for no limit
......@@ -42,7 +42,7 @@ class Sender : public rpc::RPCSender {
virtual ~Sender() {}
/*!
/**
* @brief Send data to specified Receiver.
* @param msg data message
* @param recv_id receiver's ID
......@@ -58,17 +58,17 @@ class Sender : public rpc::RPCSender {
virtual STATUS Send(Message msg, int recv_id) = 0;
protected:
/*!
/**
* @brief Size of message queue
*/
int64_t queue_size_;
/*!
/**
* @brief Size of thread pool. 0 for no limit
*/
int max_thread_count_;
};
/*!
/**
* @brief Network Receiver for DGL distributed training.
*
* Receiver is an abstract class that defines a set of APIs for receiving binary
......@@ -79,7 +79,7 @@ class Sender : public rpc::RPCSender {
*/
class Receiver : public rpc::RPCReceiver {
public:
/*!
/**
* @brief Receiver constructor
* @param queue_size size of message queue.
* @param max_thread_count size of thread pool. 0 for no limit
......@@ -96,7 +96,7 @@ class Receiver : public rpc::RPCReceiver {
virtual ~Receiver() {}
/*!
/**
* @brief Recv data from Sender
* @param msg pointer of data message
* @param send_id which sender current msg comes from
......@@ -110,7 +110,7 @@ class Receiver : public rpc::RPCReceiver {
*/
virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;
/*!
/**
* @brief Recv data from a specified Sender
* @param msg pointer of data message
* @param send_id sender's ID
......@@ -125,11 +125,11 @@ class Receiver : public rpc::RPCReceiver {
virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;
protected:
/*!
/**
* @brief Size of message queue
*/
int64_t queue_size_;
/*!
/**
* @brief Size of thread pool. 0 for no limit
*/
int max_thread_count_;
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file msg_queue.cc
* @brief Message queue for DGL distributed training.
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file msg_queue.h
* @brief Message queue for DGL distributed training.
......@@ -22,7 +22,7 @@ namespace network {
typedef int STATUS;
/*!
/**
* @brief Status code of message queue
*/
#define ADD_SUCCESS 3400 // Add message successfully
......@@ -33,45 +33,45 @@ typedef int STATUS;
#define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
/*!
/**
* @brief Message used by network communicator and message queue.
*/
struct Message {
/*!
/**
* @brief Constructor
*/
Message() {}
/*!
/**
* @brief Constructor
*/
Message(char* data_ptr, int64_t data_size)
: data(data_ptr), size(data_size) {}
/*!
/**
* @brief message data
*/
char* data;
/*!
/**
* @brief message size in bytes
*/
int64_t size;
/*!
/**
* @brief message receiver id
*/
int receiver_id = -1;
/*!
/**
* @brief user-defined deallocator, which can be nullptr
*/
std::function<void(Message*)> deallocator = nullptr;
};
/*!
/**
* @brief Free memory buffer of message
*/
inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
/*!
/**
* @brief Message Queue for network communication.
*
* MessageQueue is FIFO queue that adopts producer/consumer model for data
......@@ -89,7 +89,7 @@ inline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }
*/
class MessageQueue {
public:
/*!
/**
* @brief MessageQueue constructor
* @param queue_size size (bytes) of message queue
* @param num_producers number of producers, use 1 by default
......@@ -97,12 +97,12 @@ class MessageQueue {
explicit MessageQueue(
int64_t queue_size /* in bytes */, int num_producers = 1);
/*!
/**
* @brief MessageQueue deconstructor
*/
~MessageQueue() {}
/*!
/**
* @brief Add message to the queue
* @param msg data message
* @param is_blocking Blocking if cannot add, else return
......@@ -110,7 +110,7 @@ class MessageQueue {
*/
STATUS Add(Message msg, bool is_blocking = true);
/*!
/**
* @brief Remove message from the queue
* @param msg pointer of data msg
* @param is_blocking Blocking if cannot remove, else return
......@@ -118,64 +118,64 @@ class MessageQueue {
*/
STATUS Remove(Message* msg, bool is_blocking = true);
/*!
/**
* @brief Signal that producer producer_id will no longer produce anything
* @param producer_id An integer uniquely to identify a producer thread
*/
void SignalFinished(int producer_id);
/*!
/**
* @return true if queue is empty.
*/
bool Empty() const;
/*!
/**
* @return true if queue is empty and all num_producers have signaled.
*/
bool EmptyAndNoMoreAdd() const;
protected:
/*!
/**
* @brief message queue
*/
std::queue<Message> queue_;
/*!
/**
* @brief Size of the queue in bytes
*/
int64_t queue_size_;
/*!
/**
* @brief Free size of the queue
*/
int64_t free_size_;
/*!
/**
* @brief Used to check all producers will no longer produce anything
*/
size_t num_producers_;
/*!
/**
* @brief Store finished producer id
*/
std::set<int /* producer_id */> finished_producers_;
/*!
/**
* @brief Condition when consumer should wait
*/
std::condition_variable cond_not_full_;
/*!
/**
* @brief Condition when producer should wait
*/
std::condition_variable cond_not_empty_;
/*!
/**
* @brief Signal for exit wait
*/
std::atomic<bool> exit_flag_{false};
/*!
/**
* @brief Protect all above data and conditions
*/
mutable std::mutex mutex_;
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file communicator.cc
* @brief SocketCommunicator for DGL distributed training.
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file communicator.h
* @brief SocketCommunicator for DGL distributed training.
......@@ -25,7 +25,7 @@ static constexpr int kTimeOut =
10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024
/*!
/**
* @breif Networking address
*/
struct IPAddr {
......@@ -33,14 +33,14 @@ struct IPAddr {
int port;
};
/*!
/**
* @brief SocketSender for DGL distributed training.
*
* SocketSender is the communicator implemented by tcp socket.
*/
class SocketSender : public Sender {
public:
/*!
/**
* @brief Sender constructor
* @param queue_size size of message queue
* @param max_thread_count size of thread pool. 0 for no limit
......@@ -48,7 +48,7 @@ class SocketSender : public Sender {
SocketSender(int64_t queue_size, int max_thread_count)
: Sender(queue_size, max_thread_count) {}
/*!
/**
* @brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call
......@@ -64,7 +64,7 @@ class SocketSender : public Sender {
*/
bool ConnectReceiver(const std::string& addr, int recv_id) override;
/*!
/**
* @brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails.
* @return True for success and False for fail
......@@ -73,19 +73,19 @@ class SocketSender : public Sender {
*/
bool ConnectReceiverFinalize(const int max_try_times) override;
/*!
/**
* @brief Send RPCMessage to specified Receiver.
* @param msg data message
* @param recv_id receiver's ID
*/
void Send(const rpc::RPCMessage& msg, int recv_id) override;
/*!
/**
* @brief Finalize TPSender
*/
void Finalize() override;
/*!
/**
* @brief Communicator type: 'socket'
*/
const std::string& NetType() const override {
......@@ -93,7 +93,7 @@ class SocketSender : public Sender {
return net_type;
}
/*!
/**
* @brief Send data to specified Receiver. Actually pushing message to message
* queue.
* @param msg data message.
......@@ -110,29 +110,29 @@ class SocketSender : public Sender {
STATUS Send(Message msg, int recv_id) override;
private:
/*!
/**
* @brief socket for each connection of receiver
*/
std::vector<
std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>
sockets_;
/*!
/**
* @brief receivers' address
*/
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*!
/**
* @brief message queue for each thread
*/
std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
/*!
/**
* @brief Independent thread
*/
std::vector<std::shared_ptr<std::thread>> threads_;
/*!
/**
* @brief Send-loop for each thread
* @param sockets TCPSockets for current thread
* @param queue message_queue for current thread
......@@ -147,14 +147,14 @@ class SocketSender : public Sender {
std::shared_ptr<MessageQueue> queue);
};
/*!
/**
* @brief SocketReceiver for DGL distributed training.
*
* SocketReceiver is the communicator implemented by tcp socket.
*/
class SocketReceiver : public Receiver {
public:
/*!
/**
* @brief Receiver constructor
* @param queue_size size of message queue.
* @param max_thread_count size of thread pool. 0 for no limit
......@@ -162,7 +162,7 @@ class SocketReceiver : public Receiver {
SocketReceiver(int64_t queue_size, int max_thread_count)
: Receiver(queue_size, max_thread_count) {}
/*!
/**
* @brief Wait for all the Senders to connect
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* @param num_sender total number of Senders
......@@ -174,7 +174,7 @@ class SocketReceiver : public Receiver {
bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override;
/*!
/**
* @brief Recv RPCMessage from Sender. Actually removing data from queue.
* @param msg pointer of RPCmessage
* @param timeout The timeout value in milliseconds. If zero, wait
......@@ -183,7 +183,7 @@ class SocketReceiver : public Receiver {
*/
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
/*!
/**
* @brief Recv data from Sender. Actually removing data from msg_queue.
* @param msg pointer of data message
* @param send_id which sender current msg comes from
......@@ -197,7 +197,7 @@ class SocketReceiver : public Receiver {
*/
STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;
/*!
/**
* @brief Recv data from a specified Sender. Actually removing data from
* msg_queue.
* @param msg pointer of data message.
......@@ -212,14 +212,14 @@ class SocketReceiver : public Receiver {
*/
STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;
/*!
/**
* @brief Finalize SocketReceiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize() override;
/*!
/**
* @brief Communicator type: 'socket'
*/
const std::string& NetType() const override {
......@@ -233,24 +233,24 @@ class SocketReceiver : public Receiver {
int64_t received_bytes = 0;
char* buffer = nullptr;
};
/*!
/**
* @brief number of sender
*/
int num_sender_;
/*!
/**
* @brief server socket for listening connections
*/
TCPSocket* server_socket_;
/*!
/**
* @brief socket for each client connections
*/
std::vector<std::unordered_map<
int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>>>
sockets_;
/*!
/**
* @brief Message queue for each socket connection
*/
std::unordered_map<
......@@ -258,18 +258,18 @@ class SocketReceiver : public Receiver {
msg_queue_;
std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
/*!
/**
* @brief Independent thead
*/
std::vector<std::shared_ptr<std::thread>> threads_;
/*!
/**
* @brief queue_sem_ semphore to indicate number of messages in multiple
* message queues to prevent busy wait of Recv
*/
runtime::Semaphore queue_sem_;
/*!
/**
* @brief Recv-loop for each thread
* @param sockets client sockets of current thread
* @param queue message queues of current thread
......
/*!
/**
* Copyright (c) 2021 by Contributors
* @file socket_pool.cc
* @brief Socket pool of nonblocking sockets for DGL distributed training.
......
/*!
/**
* Copyright (c) 2021 by Contributors
* @file socket_pool.h
* @brief Socket pool of nonblocking sockets for DGL distributed training.
......@@ -15,7 +15,7 @@ namespace network {
class TCPSocket;
/*!
/**
* @brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification
......@@ -23,20 +23,20 @@ class TCPSocket;
*/
class SocketPool {
public:
/*!
/**
* @brief socket mode read/receive
*/
static const int READ = 1;
/*!
/**
* @brief socket mode write/send
*/
static const int WRITE = 2;
/*!
/**
* @brief SocketPool constructor
*/
SocketPool();
/*!
/**
* @brief Add a socket to SocketPool
* @param socket tcp socket to add
* @param socket_id receiver/sender id of the socket
......@@ -45,19 +45,19 @@ class SocketPool {
void AddSocket(
std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);
/*!
/**
* @brief Remove socket from SocketPool
* @param socket tcp socket to remove
* @return number of remaing sockets in the pool
*/
size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);
/*!
/**
* @brief SocketPool destructor
*/
~SocketPool();
/*!
/**
* @brief Get current active socket. This is a blocking method
* @param socket_id output parameter of the socket_id of active socket
* @return active TCPSocket
......@@ -65,27 +65,27 @@ class SocketPool {
std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);
private:
/*!
/**
* @brief Wait for event notification
*/
void Wait();
/*!
/**
* @brief map from fd to TCPSocket
*/
std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;
/*!
/**
* @brief map from fd to socket_id
*/
std::unordered_map<int, int> socket_ids_;
/*!
/**
* @brief fd for epoll base
*/
int epfd_;
/*!
/**
* @brief queue for current active fds
*/
std::queue<int> pending_fds_;
......
/*!
/**
* Copyright (c) 2019 by Contributors
* @file tcp_socket.cc
* @brief TCP socket for DGL distributed training.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment