/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.h
 * \brief Common headers for remote process call (RPC).
 */
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_

#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "./tensorpipe/tp_communicator.h"

namespace dgl {
namespace rpc {

struct RPCContext;
// Communicator handler type
typedef void* CommunicatorHandle;

/*! \brief Context information for RPC communication */
struct RPCContext {
  /*!
   * \brief Rank of this process.
   *
39
40
   * If the process is a client, this is equal to client ID. Otherwise, the
   * process is a server and this is equal to server ID.
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
   */
  int32_t rank = -1;

  /*!
   * \brief Cuurent machine ID
   */
  int32_t machine_id = -1;

  /*!
   * \brief Total number of machines.
   */
  int32_t num_machines = 0;

  /*!
   * \brief Message sequence number.
   */
57
  std::atomic<int64_t> msg_seq{0};
58
59
60
61
62
63

  /*!
   * \brief Total number of server.
   */
  int32_t num_servers = 0;

64
65
66
67
68
  /*!
   * \brief Total number of client.
   */
  int32_t num_clients = 0;

69
70
71
  /*!
   * \brief Current barrier count
   */
72
  std::unordered_map<int32_t, int32_t> barrier_count;
73

74
75
76
77
78
  /*!
   * \brief Total number of server per machine.
   */
  int32_t num_servers_per_machine = 0;

79
80
81
  /*!
   * \brief Sender communicator.
   */
82
  std::shared_ptr<TPSender> sender;
83
84
85
86

  /*!
   * \brief Receiver communicator.
   */
87
88
89
90
91
92
  std::shared_ptr<TPReceiver> receiver;

  /*!
   * \brief Tensorpipe global context
   */
  std::shared_ptr<tensorpipe::Context> ctx;
93
94
95
96
97
98
99
100
101
102
103
104

  /*!
   * \brief Server state data.
   *
   * If the process is a server, this stores necessary
   * server-side data. Otherwise, the process is a client and it stores a cache
   * of the server co-located with the client (if available). When the client
   * invokes a RPC to the co-located server, it can thus perform computation
   * locally without an actual remote call.
   */
  std::shared_ptr<ServerState> server_state;

105
106
107
108
109
110
111
  /*!
   * \brief Cuurent group ID
   */
  int32_t group_id = -1;
  int32_t curr_client_id = -1;
  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;

112
113
114
115
  /*! \brief Get the RPC context singleton */
  static RPCContext* getInstance() {
    static RPCContext ctx;
    return &ctx;
116
  }
117
118
119

  /*! \brief Reset the RPC context */
  static void Reset() {
120
    auto* t = getInstance();
121
122
123
    t->rank = -1;
    t->machine_id = -1;
    t->num_machines = 0;
124
125
    t->msg_seq = 0;
    t->num_servers = 0;
126
    t->num_clients = 0;
127
    t->barrier_count.clear();
128
    t->num_servers_per_machine = 0;
129
130
131
    t->sender.reset();
    t->receiver.reset();
    t->ctx.reset();
132
    t->server_state.reset();
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
    t->group_id = -1;
    t->curr_client_id = -1;
    t->clients_.clear();
  }

  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
    auto &&m = clients_[group_id];
    if (m.find(client_id) != m.end()) {
      return -1;
    }
    m[client_id] = ++curr_client_id;
    return curr_client_id;
  }

  int32_t GetClient(int32_t client_id, int32_t group_id) const {
    if (clients_.find(group_id) == clients_.end()) {
      return -1;
    }
    const auto &m = clients_.at(group_id);
    if (m.find(client_id) == m.end()) {
      return -1;
    }
    return m.at(client_id);
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
  }
};

/*! \brief RPC status flag */
enum RPCStatus {
  kRPCSuccess = 0,  // operation completed successfully
  kRPCTimeOut,      // operation did not complete within the timeout
};

/*!
 * \brief Send out one RPC message.
 *
 * The operation is non-blocking -- it does not guarantee the payloads have
 * reached the target or even have left the sender process. However,
 * all the payloads (i.e., data and arrays) can be safely freed after this
 * function returns.
 *
 * The data buffer in the request will be copied to internal buffer for actual
 * transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).
 * The underlying sending threads will hold references to the tensors until
 * the contents have been transmitted.
 *
 * \param msg RPC message to send
 * \return status flag
 */
RPCStatus SendRPCMessage(const RPCMessage& msg);

/*!
 * \brief Receive one RPC message.
 *
 * The operation is blocking -- it returns when it receives any message
 *
 * \param msg The received message
 * \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
 * \return status flag
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);

}  // namespace rpc
}  // namespace dgl

namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
}  // namespace dmlc

#endif  // DGL_RPC_RPC_H_