/*!
 *  Copyright (c) 2019 by Contributors
 * \file tp_communicator.h
 * \brief Tensorpipe Communicator for DGL distributed training.
 */
#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_

#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "../net_type.h"
#include "./queue.h"

namespace dgl {
namespace rpc {

typedef Queue<RPCMessage> RPCMessageQueue;

/*!
 * \brief TPSender for DGL distributed training.
 *
 * TPSender is the communicator implemented on top of tensorpipe.
 */
class TPSender : public RPCSender {
 public:
  /*!
   * \brief Sender constructor
   * \param ctx shared pointer to the global tensorpipe context
   */
  explicit TPSender(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
  }
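
  // Example (illustrative): constructing a context for the sender. The
  // "tcp"/uv transport and "basic" channel registrations are assumptions;
  // register the backends your build actually provides.
  //
  //   auto ctx = std::make_shared<tensorpipe::Context>();
  //   ctx->registerTransport(0, "tcp", tensorpipe::transport::uv::create());
  //   ctx->registerChannel(0, "basic", tensorpipe::channel::basic::create());
  //   TPSender sender(ctx);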

  /*!
   * \brief Sender destructor
   */
  ~TPSender() { Finalize(); }

  /*!
   * \brief Connect to a receiver.
   *
   * When there are multiple receivers to connect to, the application calls
   * `ConnectReceiver` once per receiver and then calls
   * `ConnectReceiverFinalize` to check whether all of the connections were
   * successfully established or some of them failed.
   *
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   * \return True on success, False on failure
   *
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
  bool ConnectReceiver(const std::string& addr, int recv_id) override;
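
  // Example (illustrative addresses and IDs; ConnectReceiverFinalize is
  // declared on the RPCSender base class, and the retry count shown is an
  // assumption):
  //
  //   sender.ConnectReceiver("tcp://127.0.0.1:50091", /*recv_id=*/0);
  //   sender.ConnectReceiver("tcp://127.0.0.1:50092", /*recv_id=*/1);
  //   sender.ConnectReceiverFinalize(/*max_try_times=*/1024);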

  /*!
   * \brief Send an RPCMessage to the specified receiver.
   * \param msg data message
   * \param recv_id receiver's ID
   */
  void Send(const RPCMessage& msg, int recv_id) override;
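
  // Example (illustrative; the `data` payload field is an assumption, see
  // the RPCMessage definition for the actual layout):
  //
  //   RPCMessage msg;
  //   msg.data = "serialized-payload";
  //   sender.Send(msg, /*recv_id=*/0);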

  /*!
   * \brief Finalize TPSender
   */
  void Finalize() override;

  /*!
   * \brief Communicator type: 'tensorpipe'
   */
  const std::string& NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }

 private:
  /*!
   * \brief global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each receiver connection
   */
  std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
      pipes_;

  /*!
   * \brief receivers' listening addresses
   */
  std::unordered_map<int /* receiver ID */, std::string> receiver_addrs_;
};

/*!
 * \brief TPReceiver for DGL distributed training.
 *
 * TPReceiver is the communicator implemented on top of tensorpipe.
 */
class TPReceiver : public RPCReceiver {
 public:
  /*!
   * \brief Receiver constructor
   * \param ctx shared pointer to the global tensorpipe context
   */
  explicit TPReceiver(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
    queue_ = std::make_shared<RPCMessageQueue>();
  }

  /*!
   * \brief Receiver destructor
   */
  ~TPReceiver() { Finalize(); }

  /*!
   * \brief Wait for all the senders to connect
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
   * \param num_sender total number of senders
   * \param blocking whether to block until the connections are established
   * \return True on success, False on failure
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
  bool Wait(
      const std::string& addr, int num_sender, bool blocking = true) override;
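
  // Example (illustrative address and sender count):
  //
  //   TPReceiver receiver(ctx);
  //   receiver.Wait("tcp://127.0.0.1:50051", /*num_sender=*/2);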

  /*!
   * \brief Recv an RPCMessage from a sender; this removes the message from
   * the internal queue.
   * \param msg pointer to the received RPCMessage
   * \param timeout The timeout value in milliseconds. If zero, wait
   * indefinitely.
   * \return RPCStatus: kRPCSuccess or kRPCTimeOut.
   */
  RPCStatus Recv(RPCMessage* msg, int timeout) override;
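
  // Example of a receive loop body (illustrative timeout):
  //
  //   RPCMessage msg;
  //   RPCStatus status = receiver.Recv(&msg, /*timeout=*/1000);
  //   if (status == kRPCTimeOut) {
  //     // No message arrived within 1000 ms; the caller may retry.
  //   }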

  /*!
   * \brief Finalize TPReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
   */
  void Finalize() override;

  /*!
   * \brief Communicator type: 'tensorpipe'
   */
  const std::string& NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }

  /*!
   * \brief Issue a receive request on the pipe and push the received message
   * into the queue.
   */
  static void ReceiveFromPipe(
      std::shared_ptr<tensorpipe::Pipe> pipe,
      std::shared_ptr<RPCMessageQueue> queue);
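
  // Note: in the usual tensorpipe pattern (the actual chaining lives in
  // tp_communicator.cc), the read callback pushes the completed RPCMessage
  // into `queue` and then re-invokes ReceiveFromPipe on the same pipe, so
  // that one receive request is always outstanding per connection.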

 private:
  /*!
   * \brief Callback invoked when a new connection is accepted.
   */
  void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);

  /*!
   * \brief number of senders
   */
  int num_sender_;

  /*!
   * \brief global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each client connection
   */
  std::unordered_map<
      int /* Sender (virtual) ID */, std::shared_ptr<tensorpipe::Pipe>>
      pipes_;

  /*!
   * \brief RPCMessage queue
   */
  std::shared_ptr<RPCMessageQueue> queue_;

  /*!
   * \brief number of accepted connections
   */
  std::atomic<int32_t> num_connected_{0};

  /*!
   * \brief listener to build pipes from accepted connections
   */
  std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
};

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_