/*!
 *  Copyright (c) 2019 by Contributors
 * \file tp_communicator.h
 * \brief Tensorpipe Communicator for DGL distributed training.
 */
#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_

#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>

#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <atomic>

#include "./queue.h"
#include "../net_type.h"

namespace dgl {
namespace rpc {

// Queue of inbound RPCMessages shared between receive callbacks and Recv();
// modern alias form (`using`) per C++11+ convention, same name and semantics.
using RPCMessageQueue = Queue<RPCMessage>;

/*!
 * \brief TPSender for DGL distributed training.
 *
 * TPSender is the sending-side communicator implemented on top of tensorpipe.
 */
class TPSender : public RPCSender {
 public:
  /*!
   * \brief Sender constructor
   * \param ctx global tensorpipe context shared across communicators;
   *        must be non-null (CHECKed at construction)
   */
  explicit TPSender(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
  }

  /*!
   * \brief Sender destructor; tears down all connections via Finalize()
   */
  ~TPSender() { Finalize(); }

  /*!
   * \brief Connect to a receiver.
   *
   * When there are multiple receivers to be connected, the application calls
   * `ConnectReceiver` for each of them and then calls `ConnectReceiverFinalize`
   * to make sure that either all the connections are successfully established
   * or some of them fail.
   *
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   * \return True for success and False for fail
   *
   * The function is *not* thread-safe; only one thread can invoke this API.
   */
  bool ConnectReceiver(const std::string& addr, int recv_id) override;

  /*!
   * \brief Send RPCMessage to specified Receiver.
   * \param msg data message
   * \param recv_id receiver's ID
   */
  void Send(const RPCMessage& msg, int recv_id) override;

  /*!
   * \brief Finalize TPSender, closing all open pipes/connections.
   */
  void Finalize() override;

  /*!
   * \brief Communicator type: 'tensorpipe'
   */
  const std::string &NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }

 private:
  /*!
   * \brief global context of tensorpipe (shared, injected via constructor)
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each connection of receiver
   */
  std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
    pipes_;

  /*!
   * \brief receivers' listening addresses, keyed the same way as pipes_
   */
  std::unordered_map<int /* receiver ID */, std::string> receiver_addrs_;
};

/*!
 * \brief TPReceiver for DGL distributed training.
 *
 * TPReceiver is the receiving-side communicator implemented on top of
 * tensorpipe.
 */
class TPReceiver : public RPCReceiver {
 public:
  /*!
   * \brief Receiver constructor
   * \param ctx global tensorpipe context shared across communicators;
   *        must be non-null (CHECKed at construction)
   */
  explicit TPReceiver(std::shared_ptr<tensorpipe::Context> ctx) {
    CHECK(ctx) << "Context is not initialized";
    this->context = ctx;
    queue_ = std::make_shared<RPCMessageQueue>();
  }

  /*!
   * \brief Receiver destructor; tears down all connections via Finalize()
   */
  ~TPReceiver() { Finalize(); }

  /*!
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
   * \param num_sender total number of Senders
   * \param blocking whether to wait blockingly
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
  bool Wait(const std::string &addr, int num_sender,
            bool blocking = true) override;

  /*!
   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
   * \param msg pointer of RPCmessage
   */
  void Recv(RPCMessage* msg) override;

  /*!
   * \brief Finalize TPReceiver, closing the listener and all pipes.
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
   */
  void Finalize() override;

  /*!
   * \brief Communicator type: 'tensorpipe'
   */
  const std::string &NetType() const override {
    static const std::string net_type = "tensorpipe";
    return net_type;
  }

  /*!
   * \brief Issue a receive request on pipe, and push the result into queue
   */
  static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
                              std::shared_ptr<RPCMessageQueue> queue);

 private:
  /*!
   * \brief Callback invoked when a new connection is accepted.
   */
  void OnAccepted(const tensorpipe::Error&, std::shared_ptr<tensorpipe::Pipe>);

 private:
  /*!
   * \brief number of senders expected to connect (set by Wait)
   */
  int num_sender_;

  /*!
   * \brief listener to build pipe
   *
   * NOTE(review): this class declares two listener members (`listener` here
   * and `listener_` near the bottom) — looks like a merge leftover; confirm
   * which one the implementation file actually uses and remove the other.
   */
  std::shared_ptr<tensorpipe::Listener> listener;

  /*!
   * \brief global context of tensorpipe (shared, injected via constructor)
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each client connection
   */
  std::unordered_map<int /* Sender (virtual) ID */,
                     std::shared_ptr<tensorpipe::Pipe>>
    pipes_;

  /*!
   * \brief RPCMessage queue filled by ReceiveFromPipe and drained by Recv()
   */
  std::shared_ptr<RPCMessageQueue> queue_;

  /*!
   * \brief number of accepted connections
   */
  std::atomic<int32_t> num_connected_{0};

  /*!
   * \brief listener
   *
   * NOTE(review): duplicates `listener` declared above — see note there.
   */
  std::shared_ptr<tensorpipe::Listener> listener_{nullptr};
};

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_