Unverified Commit 55425584 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistDGL] remove RPCSender/Receiver base classes (#5879)

parent 4e5780e3
/**
* Copyright (c) 2022 by Contributors
* @file net_type.h
* @brief Base communicator for DGL distributed training.
*/
#ifndef DGL_RPC_NET_TYPE_H_
#define DGL_RPC_NET_TYPE_H_
#include <string>
#include "rpc_msg.h"
namespace dgl {
namespace rpc {
struct RPCBase {
/**
* @brief Finalize Receiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/**
* @brief Communicator type such as 'socket'.
*/
virtual const std::string &NetType() const = 0;
};
struct RPCSender : RPCBase {
/**
* @brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call
* `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
* sure that either all the connections are successfully established or some
* of them fail.
*
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* @param recv_id receiver's ID
* @return True for success and False for fail
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0;
/**
* @brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails.
* @return True for success and False for fail
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiverFinalize(const int max_try_times) { return true; }
/**
* @brief Send RPCMessage to specified Receiver.
* @param msg data message
* @param recv_id receiver's ID
*/
virtual void Send(const RPCMessage &msg, int recv_id) = 0;
};
struct RPCReceiver : RPCBase {
/**
* @brief Wait for all the Senders to connect
* @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* @param num_sender total number of Senders
* @param blocking whether wait blockingly
* @return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Wait(
const std::string &addr, int num_sender, bool blocking = true) = 0;
/**
* @brief Recv RPCMessage from Sender. Actually removing data from queue.
* @param msg pointer of RPCmessage
* @param timeout The timeout value in milliseconds. If zero, wait
* indefinitely.
* @return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
virtual RPCStatus Recv(RPCMessage *msg, int timeout) = 0;
};
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_NET_TYPE_H_
......@@ -10,7 +10,6 @@
#include <string>
#include "../net_type.h"
#include "msg_queue.h"
namespace dgl {
......@@ -25,7 +24,7 @@ namespace network {
* multiple receivers and it can send data to specified receiver via receiver's
* ID.
*/
class Sender : public rpc::RPCSender {
class Sender {
public:
/**
* @brief Sender constructor
......@@ -77,7 +76,7 @@ class Sender : public rpc::RPCSender {
* with multiple Senders and it can receive data from multiple Senders
* concurrently.
*/
class Receiver : public rpc::RPCReceiver {
class Receiver {
public:
/**
* @brief Receiver constructor
......
......@@ -13,6 +13,7 @@
#include <vector>
#include "../../runtime/semaphore_wrapper.h"
#include "../rpc_msg.h"
#include "common.h"
#include "communicator.h"
#include "msg_queue.h"
......@@ -62,7 +63,7 @@ class SocketSender : public Sender {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiver(const std::string& addr, int recv_id) override;
bool ConnectReceiver(const std::string& addr, int recv_id);
/**
* @brief Finalize the action to connect to receivers. Make sure that either
......@@ -71,27 +72,19 @@ class SocketSender : public Sender {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiverFinalize(const int max_try_times) override;
bool ConnectReceiverFinalize(const int max_try_times);
/**
* @brief Send RPCMessage to specified Receiver.
* @param msg data message
* @param recv_id receiver's ID
*/
void Send(const rpc::RPCMessage& msg, int recv_id) override;
void Send(const rpc::RPCMessage& msg, int recv_id);
/**
* @brief Finalize TPSender
*/
void Finalize() override;
/**
* @brief Communicator type: 'socket'
*/
const std::string& NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
void Finalize();
/**
* @brief Send data to specified Receiver. Actually pushing message to message
......@@ -171,8 +164,7 @@ class SocketReceiver : public Receiver {
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(
const std::string& addr, int num_sender, bool blocking = true) override;
bool Wait(const std::string& addr, int num_sender, bool blocking = true);
/**
* @brief Recv RPCMessage from Sender. Actually removing data from queue.
......@@ -181,7 +173,7 @@ class SocketReceiver : public Receiver {
* indefinitely.
* @return RPCStatus: kRPCSuccess or kRPCTimeOut.
*/
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout) override;
rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout);
/**
* @brief Recv data from Sender. Actually removing data from msg_queue.
......@@ -217,15 +209,7 @@ class SocketReceiver : public Receiver {
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize() override;
/**
* @brief Communicator type: 'socket'
*/
const std::string& NetType() const override {
static const std::string net_type = "socket";
return net_type;
}
void Finalize();
private:
struct RecvContext {
......
......@@ -22,7 +22,6 @@
#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h"
namespace dgl {
......@@ -81,12 +80,12 @@ struct RPCContext {
/**
* @brief Sender communicator.
*/
std::shared_ptr<RPCSender> sender;
std::shared_ptr<network::SocketSender> sender;
/**
* @brief Receiver communicator.
*/
std::shared_ptr<RPCReceiver> receiver;
std::shared_ptr<network::SocketReceiver> receiver;
/**
* @brief Server state data.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment