tp_communicator.h 4.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
/*!
 *  Copyright (c) 2019 by Contributors
 * \file tp_communicator.h
 * \brief Tensorpipe Communicator for DGL distributed training.
 */
#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_
#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_

#include <dmlc/logging.h>
#include <tensorpipe/tensorpipe.h>

#include <deque>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "./queue.h"

namespace dgl {
namespace rpc {

// Forward declaration only; RPCMessage is defined outside this header.
class RPCMessage;

/*! \brief Message queue type used by TPReceiver to buffer incoming RPCMessages. */
using RPCMessageQueue = Queue<RPCMessage>;

/*!
 * \brief TPSender for DGL distributed training.
 *
 * TPSender is the sending side of the communicator, built on tensorpipe.
 * It keeps one tensorpipe::Pipe per receiver ID.
 */
class TPSender {
 public:
  /*!
   * \brief Sender constructor
   * \param ctx global tensorpipe context used to create pipes; must be non-null
   */
  explicit TPSender(std::shared_ptr<tensorpipe::Context> ctx)
    : context(std::move(ctx)) {
    // Validate the member after the move; the moved-from ctx is now empty.
    CHECK(context) << "Context is not initialized";
  }

  /*!
   * \brief Add receiver's address and ID to the sender's namebook
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
   * \param recv_id receiver's ID
   *
   * AddReceiver() is not thread-safe and only one thread can invoke this API.
   */
  void AddReceiver(const std::string& addr, int recv_id);

  /*!
   * \brief Connect with all the Receivers
   * \return True for success and False for fail
   *
   * Connect() is not thread-safe and only one thread can invoke this API.
   */
  bool Connect();

  /*!
   * \brief Send RPCMessage to specified Receiver.
   * \param msg data message
   * \param recv_id receiver's ID
   */
  void Send(const RPCMessage& msg, int recv_id);

  /*!
   * \brief Finalize TPSender
   */
  void Finalize();

  /*!
   * \brief Communicator type: 'tp' (tensorpipe)
   */
  inline std::string Type() const { return std::string("tp"); }

 private:
  /*!
   * \brief global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each connection of receiver, keyed by receiver ID
   */
  std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
    pipes_;

  /*!
   * \brief receivers' listening address, keyed by receiver ID
   */
  std::unordered_map<int /* receiver ID */, std::string> receiver_addrs_;
};

/*!
 * \brief TPReceiver for DGL distributed training.
 *
 * TPReceiver is the receiving side of the communicator, built on tensorpipe.
 * Messages received on pipes are pushed into an internal RPCMessageQueue
 * and handed out by Recv().
 */
class TPReceiver {
 public:
  /*!
   * \brief Receiver constructor
   * \param ctx global tensorpipe context used by the listener and pipes;
   *        must be non-null
   */
  explicit TPReceiver(std::shared_ptr<tensorpipe::Context> ctx)
    : context(std::move(ctx)) {
    // Validate the member after the move; the moved-from ctx is now empty.
    CHECK(context) << "Context is not initialized";
    queue_ = std::make_shared<RPCMessageQueue>();
  }

  /*!
   * \brief Wait for all the Senders to connect
   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051'
   * \param num_sender total number of Senders
   * \return True for success and False for fail
   *
   * Wait() is not thread-safe and only one thread can invoke this API.
   */
  bool Wait(const std::string& addr, int num_sender);

  /*!
   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
   * \param msg pointer to the RPCMessage that receives the dequeued message
   *
   * (1) The Recv() API is blocking, which will not
   *     return until getting data from message queue.
   * (2) The Recv() API is thread-safe.
   * (3) Memory allocated by communicator but will not own it after the function
   *     returns.
   */
  void Recv(RPCMessage* msg);

  /*!
   * \brief Finalize TPReceiver
   *
   * Finalize() is not thread-safe and only one thread can invoke this API.
   */
  void Finalize();

  /*!
   * \brief Communicator type: 'tp' (tensorpipe)
   */
  inline std::string Type() const { return std::string("tp"); }

  /*!
   * \brief Issue a receive request on pipe, and push the result into queue
   * \param pipe pipe to issue the receive request on
   * \param queue queue the received message is pushed into
   */
  static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
                              std::shared_ptr<RPCMessageQueue> queue);

 private:
  /*!
   * \brief number of sender
   *
   * Zero-initialized so the value is well-defined before Wait() runs.
   */
  int num_sender_ = 0;

  /*!
   * \brief listener to build pipe
   */
  std::shared_ptr<tensorpipe::Listener> listener;

  /*!
   * \brief global context of tensorpipe
   */
  std::shared_ptr<tensorpipe::Context> context;

  /*!
   * \brief pipe for each client connections
   */
  std::unordered_map<int /* Sender (virtual) ID */,
                     std::shared_ptr<tensorpipe::Pipe>>
    pipes_;

  /*!
   * \brief RPCMessage queue
   */
  std::shared_ptr<RPCMessageQueue> queue_;
};

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_