communicator.h 5.4 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2019 by Contributors
 * \file communicator.h
 * \brief Communicator for DGL distributed training.
 */
6
7
#ifndef DGL_RPC_NETWORK_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_COMMUNICATOR_H_
8

9
10
#include <dmlc/logging.h>

11
12
#include <string>

13
14
#include "msg_queue.h"

15
16
17
18
namespace dgl {
namespace network {

/*!
19
 * \brief Network Sender for DGL distributed training.
20
 *
21
22
23
24
 * Sender is an abstract class that defines a set of APIs for sending binary 
 * data message over network. It can be implemented by different underlying 
 * networking libraries such TCP socket and MPI. One Sender can connect to 
 * multiple receivers and it can send data to specified receiver via receiver's ID.
25
 */
26
class Sender {
27
 public:
28
29
30
  /*!
   * \brief Sender constructor
   * \param queue_size size (bytes) of message queue. 
31
   * \param max_thread_count size of thread pool. 0 for no limit
32
33
   * Note that, the queue_size parameter is optional.
   */
34
  explicit Sender(int64_t queue_size = 0, int max_thread_count = 0) {
35
    CHECK_GE(queue_size, 0);
36
    CHECK_GE(max_thread_count, 0);
37
    queue_size_ = queue_size;
38
    max_thread_count_ = max_thread_count;
39
40
  }

41
  virtual ~Sender() {}
42
43

  /*!
44
45
   * \brief Add receiver's address and ID to the sender's namebook
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
46
   * \param id receiver's ID
47
48
   *
   * AddReceiver() is not thread-safe and only one thread can invoke this API.
49
   */
50
  virtual void AddReceiver(const char* addr, int id) = 0;
51

52
  /*!
53
   * \brief Connect with all the Receivers
54
55
56
   * \return True for success and False for fail
   *
   * Connect() is not thread-safe and only one thread can invoke this API.
57
   */
58
  virtual bool Connect() = 0;
59
60

  /*!
61
62
   * \brief Send data to specified Receiver.
   * \param msg data message
63
   * \param recv_id receiver's ID
64
65
66
67
68
69
70
71
   * \return Status code
   *
   * (1) The send is non-blocking. There is no guarantee that the message has been 
   *     physically sent out when the function returns.
   * (2) The communicator will assume the responsibility of the given message.
   * (3) The API is multi-thread safe.
   * (4) Messages sent to the same receiver are guaranteed to be received in the same order. 
   *     There is no guarantee for messages sent to different receivers.
72
   */
73
  virtual STATUS Send(Message msg, int recv_id) = 0;
74
75

  /*!
76
   * \brief Finalize Sender
77
78
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
79
80
   */
  virtual void Finalize() = 0;
81
82

  /*!
83
   * \brief Communicator type: 'socket', 'mpi', etc.
84
   */
85
  virtual std::string Type() const = 0;
86

87
 protected:
88
  /*!
89
   * \brief Size of message queue
90
   */
91
  int64_t queue_size_;
92
93
94
95
  /*!
   * \brief Size of thread pool. 0 for no limit
   */
  int max_thread_count_;
96
97
98
99
100
};

/*!
 * \brief Network Receiver for DGL distributed training.
 *
101
102
103
104
 * Receiver is an abstract class that defines a set of APIs for receiving binary data 
 * message over network. It can be implemented by different underlying networking 
 * libraries such as TCP socket and MPI. One Receiver can connect with multiple Senders 
 * and it can receive data from multiple Senders concurrently.
105
106
107
 */
class Receiver {
 public:
108
109
110
  /*!
   * \brief Receiver constructor
   * \param queue_size size of message queue.
111
   * \param max_thread_count size of thread pool. 0 for no limit
112
113
   * Note that, the queue_size parameter is optional.
   */
114
  explicit Receiver(int64_t queue_size = 0, int max_thread_count = 0) {
115
116
117
    if (queue_size < 0) {
      LOG(FATAL) << "queue_size cannot be a negative number.";
    }
118
    CHECK_GE(max_thread_count, 0);
119
    queue_size_ = queue_size;
120
    max_thread_count_ = max_thread_count;
121
122
  }

123
  virtual ~Receiver() {}
124
125

  /*!
126
127
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
128
   * \param num_sender total number of Senders
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
  virtual bool Wait(const char* addr, int num_sender) = 0;

  /*!
   * \brief Recv data from Sender
   * \param msg pointer of data message
   * \param send_id which sender current msg comes from
   * \return Status code
   *
   * (1) The Recv() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
145
   */
146
  virtual STATUS Recv(Message* msg, int* send_id) = 0;
147
148

  /*!
149
150
151
152
153
154
155
156
157
   * \brief Recv data from a specified Sender
   * \param msg pointer of data message
   * \param send_id sender's ID
   * \return Status code
   *
   * (1) The RecvFrom() API is blocking, which will not 
   *     return until getting data from message queue.
   * (2) The RecvFrom() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function returns.
158
   */
159
  virtual STATUS RecvFrom(Message* msg, int send_id) = 0;
160
161
162

  /*!
   * \brief Finalize Receiver
163
164
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
165
166
   */
  virtual void Finalize() = 0;
167
168

  /*!
169
   * \brief Communicator type: 'socket', 'mpi', etc
170
   */
171
  virtual std::string Type() const = 0;
172

173
 protected:
174
  /*!
175
   * \brief Size of message queue
176
   */
177
  int64_t queue_size_;
178
179
180
181
  /*!
   * \brief Size of thread pool. 0 for no limit
   */
  int max_thread_count_;
182
183
184
185
186
};

}  // namespace network
}  // namespace dgl

187
#endif  // DGL_RPC_NETWORK_COMMUNICATOR_H_