/**
 *  Copyright (c) 2019 by Contributors
 * @file tp_communicator.h
 * @brief Tensorpipe Communicator for DGL distributed training.
 */
#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_

#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "../net_type.h"
#include "./queue.h"

namespace dgl {
namespace rpc {

typedef Queue<RPCMessage> RPCMessageQueue;

/**
 * @brief TPSender for DGL distributed training.
 *
 * TPSender is the sender-side RPC communicator implemented on top of
 * tensorpipe.
 */
class TPSender : public RPCSender {
 public:
  /**
   * @brief Sender constructor
   * @param ctx shared tensorpipe context
   */
  explicit TPSender(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
  }
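
  // A minimal construction sketch (illustrative only; the transport/channel
  // registration that tensorpipe requires before a context is usable is
  // assumed to happen elsewhere):
  //
  //   auto ctx = std::make_shared<tensorpipe::Context>();
  //   TPSender sender(ctx);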

  /**
   * @brief Sender destructor
   */
  ~TPSender() { Finalize(); }

  /**
   * @brief Connect to a receiver.
   *
   * When there are multiple receivers to connect to, the application calls
   * `ConnectReceiver` for each of them and then calls
   * `ConnectReceiverFinalize` to verify that all the connections were
   * successfully established.
   *
   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * @param recv_id receiver's ID
   * @return True for success, False for failure
   *
   * This function is *not* thread-safe; only one thread may invoke it.
   */
  bool ConnectReceiver(const std::string& addr, int recv_id) override;

  /**
   * @brief Send an RPCMessage to the specified receiver.
   * @param msg data message
   * @param recv_id receiver's ID
   */
  void Send(const RPCMessage& msg, int recv_id) override;

  /**
   * @brief Finalize TPSender
   */
  void Finalize() override;

  /**
   * @brief Communicator type: 'tensorpipe'
   */
  const std::string& NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }

 private:
  /**
   * @brief Global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /**
   * @brief One pipe per connected receiver
   */
  std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
      pipes_;

  /**
   * @brief Receivers' listening addresses
   */
  std::unordered_map<int /* receiver ID */, std::string> receiver_addrs_;
};
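
// A minimal sender-side usage sketch (illustrative only): it assumes a
// receiver is already listening at the given address, and the argument to
// `ConnectReceiverFinalize` (mentioned in the ConnectReceiver docs above)
// is a guessed retry budget, not a documented value.
//
//   TPSender sender(ctx);
//   sender.ConnectReceiver("tcp://127.0.0.1:50091", /*recv_id=*/0);
//   sender.ConnectReceiverFinalize(/*max_try_times=*/3);  // assumed signature
//   RPCMessage msg;  // populated by the caller
//   sender.Send(msg, /*recv_id=*/0);
//   sender.Finalize();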

/**
 * @brief TPReceiver for DGL distributed training.
 *
 * TPReceiver is the receiver-side RPC communicator implemented on top of
 * tensorpipe.
 */
class TPReceiver : public RPCReceiver {
 public:
  /**
   * @brief Receiver constructor
   * @param ctx shared tensorpipe context
   */
  explicit TPReceiver(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
    queue_ = std::make_shared<RPCMessageQueue>();
  }

  /**
   * @brief Receiver destructor
   */
  ~TPReceiver() { Finalize(); }

  /**
   * @brief Wait for all the senders to connect.
   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
   * @param num_sender total number of senders
   * @param blocking whether to wait blockingly
   * @return True for success, False for failure
   *
   * Wait() is not thread-safe; only one thread may invoke it.
   */
  bool Wait(
      const std::string& addr, int num_sender, bool blocking = true) override;

  /**
   * @brief Receive an RPCMessage from a sender; this removes the message
   * from the internal queue.
   * @param msg pointer to the received RPCMessage
   * @param timeout timeout value in milliseconds; if zero, wait indefinitely
   * @return RPCStatus: kRPCSuccess or kRPCTimeOut
   */
  RPCStatus Recv(RPCMessage* msg, int timeout) override;

  /**
   * @brief Finalize TPReceiver
   *
   * Finalize() is not thread-safe; only one thread may invoke it.
   */
  void Finalize() override;

  /**
   * @brief Communicator type: 'tensorpipe'
   */
  const std::string& NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }

  /**
   * @brief Issue a receive request on the pipe and push the result into
   * the queue.
   */
  static void ReceiveFromPipe(
      std::shared_ptr<tensorpipe::Pipe> pipe,
      std::shared_ptr<RPCMessageQueue> queue);
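
  // A sketch of the intended pattern (an assumption, not the actual
  // implementation; the `push` call on the queue is hypothetical): each
  // completed read pushes the decoded RPCMessage into `queue` and re-issues
  // ReceiveFromPipe on the same pipe, keeping one receive request
  // outstanding per accepted pipe.
  //
  //   ReceiveFromPipe(pipe, queue);        // arm the first read
  //   // ...inside the read-completion callback:
  //   //   queue->push(std::move(msg));    // hand the message to Recv()
  //   //   ReceiveFromPipe(pipe, queue);   // rearm for the next message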

 private:
  /**
   * @brief Callback invoked when a new connection is accepted.
   */
  void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);

 private:
  /**
   * @brief Number of senders
   */
  int num_sender_;

  /**
   * @brief Listener to build pipes
   */
  std::shared_ptr<tensorpipe::Listener> listener;

  /**
   * @brief Global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /**
   * @brief One pipe per client connection
   */
  std::unordered_map<
      int /* Sender (virtual) ID */, std::shared_ptr<tensorpipe::Pipe>>
      pipes_;

  /**
   * @brief RPCMessage queue
   */
  std::shared_ptr<RPCMessageQueue> queue_;

  /**
   * @brief Number of accepted connections
   */
  std::atomic<int32_t> num_connected_{0};

  /**
   * @brief Listener for incoming connections
   */
  std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
};
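
// A minimal receiver-side usage sketch (illustrative only; `ctx` is a shared
// tensorpipe context as above, and the address is a placeholder):
//
//   TPReceiver receiver(ctx);
//   if (receiver.Wait("tcp://127.0.0.1:50051", /*num_sender=*/1)) {
//     RPCMessage msg;
//     if (receiver.Recv(&msg, /*timeout=*/0) == kRPCSuccess) {
//       // process msg...
//     }
//   }
//   receiver.Finalize();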

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_