/*!
 *  Copyright (c) 2019 by Contributors
 * \file tp_communicator.h
 * \brief Tensorpipe Communicator for DGL distributed training.
 */
#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_

#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>

#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <atomic>
#include "./queue.h"

namespace dgl {
namespace rpc {

class RPCMessage;

typedef Queue<RPCMessage> RPCMessageQueue;

/*!
 * \brief TPSender for DGL distributed training.
 *
 * TPSender is the communicator implemented by tcp socket.
 */
class TPSender {
 public:
  /*!
   * \brief Sender constructor
   * \param queue_size size of message queue
   */
  explicit TPSender(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
  }

  /*!
45
   * \brief Sender destructor
46
   */
47
  ~TPSender() { Finalize(); }
48
49

  /*!
50
51
52
   * \brief Connect to receiver with address and ID
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
53
54
   * \return True for success and False for fail
   *
55
   * ConnectReceiver() is not thread-safe and only one thread can invoke this API.
56
   */
57
  bool ConnectReceiver(const std::string& addr, int recv_id);
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

  /*!
   * \brief Send RPCMessage to specified Receiver.
   * \param msg data message \param recv_id receiver's ID
   */
  void Send(const RPCMessage& msg, int recv_id);

  /*!
   * \brief Finalize TPSender
   */
  void Finalize();

  /*!
   * \brief Communicator type: 'tp'
   */
  inline std::string Type() const { return std::string("tp"); }

 private:
  /*!
   * \brief global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each connection of receiver
   */
  std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
    pipes_;

  /*!
   * \brief receivers' listening address
   */
  std::unordered_map<int /* receiver ID */, std::string> receiver_addrs_;
};

/*!
 * \brief TPReceiver for DGL distributed training.
 *
 * Tensorpipe Receiver is the communicator implemented by tcp socket.
 */
class TPReceiver {
 public:
  /*!
   * \brief Receiver constructor
   * \param queue_size size of message queue.
   */
  explicit TPReceiver(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
    queue_ = std::make_shared<RPCMessageQueue>();
  }

110
111
112
113
114
  /*!
   * \brief Receiver destructor
   */
  ~TPReceiver() { Finalize(); }

115
116
117
118
  /*!
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
   * \param num_sender total number of Senders
119
   * \param blocking whether to wait blockingly
120
121
122
123
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
124
  bool Wait(const std::string &addr, int num_sender, bool blocking = true);
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157

  /*!
   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
   * \param msg pointer of RPCmessage
   * \param send_id which sender current msg comes from
   * \return Status code
   *
   * (1) The Recv() API is blocking, which will not
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function
   * returns.
   */
  void Recv(RPCMessage* msg);

  /*!
   * \brief Finalize SocketReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
   */
  void Finalize();

  /*!
   * \brief Communicator type: 'tp' (tensorpipe)
   */
  inline std::string Type() const { return std::string("tp"); }

  /*!
   * \brief Issue a receive request on pipe, and push the result into queue
   */
  static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
                              std::shared_ptr<RPCMessageQueue> queue);

158
159
160
161
162
163
 private:
  /*!
   * \brief Callback for new connection is accepted.
   */
  void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
 private:
  /*!
   * \brief number of sender
   */
  int num_sender_;

  /*!
   * \brief listener to build pipe
   */
  std::shared_ptr<tensorpipe::Listener> listener;

  /*!
   * \brief global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each client connections
   */
  std::unordered_map<int /* Sender (virutal) ID */,
                     std::shared_ptr<tensorpipe::Pipe>>
    pipes_;

  /*!
   * \brief RPCMessage queue
   */
  std::shared_ptr<RPCMessageQueue> queue_;
191
192
193
194
195
196
197
198
199
200

  /*!
   * \brief number of accepted connections
   */
  std::atomic<int32_t> num_connected_{0};

  /*!
   * \brief listner
   */
  std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
201
202
203
204
205
206
};

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_