Unverified commit 55425584, authored by Rhett Ying and committed by GitHub
Browse files

[DistDGL] remove RPCSender/Receiver base classes (#5879)

parent 4e5780e3
/**
* Copyright (c) 2022 by Contributors
* @file net_type.h
* @brief Base communicator for DGL distributed training.
*/
#ifndef DGL_RPC_NET_TYPE_H_
#define DGL_RPC_NET_TYPE_H_
#include <string>
#include "rpc_msg.h"
namespace dgl {
namespace rpc {
/**
 * @brief Interface shared by every RPC communicator endpoint.
 *
 * RPCSender and RPCReceiver both derive from this type, so it only carries
 * the operations that are common to the two directions.
 */
struct RPCBase {
  /**
   * @brief Virtual destructor.
   *
   * Implementations are handled through pointers to this interface (or to an
   * interface derived from it); without a virtual destructor, deleting such a
   * pointer would be undefined behavior and would skip the implementation's
   * own cleanup.
   */
  virtual ~RPCBase() = default;

  /**
   * @brief Finalize the communicator.
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
   */
  virtual void Finalize() = 0;

  /**
   * @brief Communicator type such as 'socket'.
   * @return Reference to the name of the communicator type.
   */
  virtual const std::string &NetType() const = 0;
};
/**
 * @brief Sending half of the RPC communicator interface.
 *
 * Extends RPCBase with the operations needed to connect to receivers and to
 * deliver RPCMessage objects to them.
 */
struct RPCSender : public RPCBase {
  /**
   * @brief Register a connection to one receiver.
   *
   * With several receivers, the application calls `ConnectReceiver` once per
   * receiver and afterwards calls `ConnectReceiverFinalize`, which makes sure
   * that either every connection is established or the failure is reported.
   *
   * Not thread-safe: only one thread may invoke this API.
   *
   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * @param recv_id receiver's ID
   * @return True for success and False for fail
   */
  virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0;

  /**
   * @brief Complete the connection phase started by `ConnectReceiver`.
   *
   * Ensures that either all connections are successfully established or the
   * connection attempt as a whole fails. The default implementation performs
   * no work and reports success.
   *
   * Not thread-safe: only one thread may invoke this API.
   *
   * @param max_try_times presumably the upper bound on connection attempts
   * (ignored by the default implementation) -- confirm against implementers
   * @return True for success and False for fail
   */
  virtual bool ConnectReceiverFinalize(const int max_try_times) {
    static_cast<void>(max_try_times);  // unused by the default implementation
    return true;
  }

  /**
   * @brief Deliver an RPCMessage to the given receiver.
   * @param msg data message
   * @param recv_id receiver's ID
   */
  virtual void Send(const RPCMessage &msg, int recv_id) = 0;
};
/**
 * @brief Receiving half of the RPC communicator interface.
 *
 * Extends RPCBase with the operations needed to accept sender connections
 * and to take RPCMessage objects off the receive queue.
 */
struct RPCReceiver : public RPCBase {
  /**
   * @brief Wait until all senders have connected.
   *
   * Wait() is not thread-safe: only one thread may invoke this API.
   *
   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
   * @param num_sender total number of Senders
   * @param blocking whether to wait in a blocking manner (defaults to true)
   * @return True for success and False for fail
   */
  virtual bool Wait(
      const std::string &addr, int num_sender, bool blocking = true) = 0;

  /**
   * @brief Receive one RPCMessage from a sender, removing it from the queue.
   * @param msg pointer to the RPCMessage that receives the data
   * @param timeout The timeout value in milliseconds. If zero, wait
   * indefinitely.
   * @return RPCStatus: kRPCSuccess or kRPCTimeOut.
   */
  virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0;
};
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_NET_TYPE_H_
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include <string> #include <string>
#include "../net_type.h"
#include "msg_queue.h" #include "msg_queue.h"
namespace dgl { namespace dgl {
...@@ -25,7 +24,7 @@ namespace network { ...@@ -25,7 +24,7 @@ namespace network {
* multiple receivers and it can send data to specified receiver via receiver's * multiple receivers and it can send data to specified receiver via receiver's
* ID. * ID.
*/ */
class Sender : public rpc::RPCSender { class Sender {
public: public:
/** /**
* @brief Sender constructor * @brief Sender constructor
...@@ -77,7 +76,7 @@ class Sender : public rpc::RPCSender { ...@@ -77,7 +76,7 @@ class Sender : public rpc::RPCSender {
* with multiple Senders and it can receive data from multiple Senders * with multiple Senders and it can receive data from multiple Senders
* concurrently. * concurrently.
*/ */
class Receiver : public rpc::RPCReceiver { class Receiver {
public: public:
/** /**
* @brief Receiver constructor * @brief Receiver constructor
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <vector> #include <vector>
#include "../../runtime/semaphore_wrapper.h" #include "../../runtime/semaphore_wrapper.h"
#include "../rpc_msg.h"
#include "common.h" #include "common.h"
#include "communicator.h" #include "communicator.h"
#include "msg_queue.h" #include "msg_queue.h"
...@@ -62,7 +63,7 @@ class SocketSender : public Sender { ...@@ -62,7 +63,7 @@ class SocketSender : public Sender {
* *
* The function is *not* thread-safe; only one thread can invoke this API. * The function is *not* thread-safe; only one thread can invoke this API.
*/ */
bool ConnectReceiver(const std::string& addr, int recv_id) override; bool ConnectReceiver(const std::string& addr, int recv_id);
/** /**
* @brief Finalize the action to connect to receivers. Make sure that either * @brief Finalize the action to connect to receivers. Make sure that either
...@@ -71,27 +72,19 @@ class SocketSender : public Sender { ...@@ -71,27 +72,19 @@ class SocketSender : public Sender {
* *
* The function is *not* thread-safe; only one thread can invoke this API. * The function is *not* thread-safe; only one thread can invoke this API.
*/ */
bool ConnectReceiverFinalize(const int max_try_times) override; bool ConnectReceiverFinalize(const int max_try_times);
/** /**
* @brief Send RPCMessage to specified Receiver. * @brief Send RPCMessage to specified Receiver.
* @param msg data message * @param msg data message
* @param recv_id receiver's ID * @param recv_id receiver's ID
*/ */
void Send(const rpc::RPCMessage& msg, int recv_id) override; void Send(const rpc::RPCMessage& msg, int recv_id);
/** /**
* @brief Finalize TPSender * @brief Finalize TPSender
*/ */
void Finalize() override; void Finalize();
/**
* @brief Communicator type: 'socket'
*/
const std::string& NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
/** /**
* @brief Send data to specified Receiver. Actually pushing message to message * @brief Send data to specified Receiver. Actually pushing message to message
...@@ -171,8 +164,7 @@ class SocketReceiver : public Receiver { ...@@ -171,8 +164,7 @@ class SocketReceiver : public Receiver {
* *
* Wait() is not thread-safe and only one thread can invoke this API. * Wait() is not thread-safe and only one thread can invoke this API.
*/ */
bool Wait( bool Wait(const std::string& addr, int num_sender, bool blocking = true);
const std::string& addr, int num_sender, bool blocking = true) override;
/** /**
* @brief Recv RPCMessage from Sender. Actually removing data from queue. * @brief Recv RPCMessage from Sender. Actually removing data from queue.
...@@ -181,7 +173,7 @@ class SocketReceiver : public Receiver { ...@@ -181,7 +173,7 @@ class SocketReceiver : public Receiver {
* indefinitely. * indefinitely.
* @return RPCStatus: kRPCSuccess or kRPCTimeOut. * @return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/ */
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override; rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout);
/** /**
* @brief Recv data from Sender. Actually removing data from msg_queue. * @brief Recv data from Sender. Actually removing data from msg_queue.
...@@ -217,15 +209,7 @@ class SocketReceiver : public Receiver { ...@@ -217,15 +209,7 @@ class SocketReceiver : public Receiver {
* *
* Finalize() is not thread-safe and only one thread can invoke this API. * Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
void Finalize() override; void Finalize();
/**
* @brief Communicator type: 'socket'
*/
const std::string& NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
private: private:
struct RecvContext { struct RecvContext {
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "./network/common.h" #include "./network/common.h"
#include "./rpc_msg.h" #include "./rpc_msg.h"
#include "./server_state.h" #include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h" #include "network/socket_communicator.h"
namespace dgl { namespace dgl {
...@@ -81,12 +80,12 @@ struct RPCContext { ...@@ -81,12 +80,12 @@ struct RPCContext {
/** /**
* @brief Sender communicator. * @brief Sender communicator.
*/ */
std::shared_ptr<RPCSender> sender; std::shared_ptr<network::SocketSender> sender;
/** /**
* @brief Receiver communicator. * @brief Receiver communicator.
*/ */
std::shared_ptr<RPCReceiver> receiver; std::shared_ptr<network::SocketReceiver> receiver;
/** /**
* @brief Server state data. * @brief Server state data.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment