/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.h
 * \brief Common headers for remote process call (RPC).
 */
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_

#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h"
#include "tensorpipe/tp_communicator.h"

namespace dgl {
namespace rpc {

32
33
// Forward declaration; the full definition of RPCContext appears later in
// this header.
struct RPCContext;

34
35
36
37
38
39
40
41
/*! \brief Opaque, untyped handle used to refer to a communicator. */
using CommunicatorHandle = void*;

/*! \brief Context information for RPC communication */
struct RPCContext {
  /*!
   * \brief Rank of this process.
   *
42
43
   * If the process is a client, this is equal to client ID. Otherwise, the
   * process is a server and this is equal to server ID.
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
   */
  int32_t rank = -1;

  /*!
   * \brief Cuurent machine ID
   */
  int32_t machine_id = -1;

  /*!
   * \brief Total number of machines.
   */
  int32_t num_machines = 0;

  /*!
   * \brief Message sequence number.
   */
60
  std::atomic<int64_t> msg_seq{0};
61
62
63
64
65
66

  /*!
   * \brief Total number of server.
   */
  int32_t num_servers = 0;

67
68
69
70
71
  /*!
   * \brief Total number of client.
   */
  int32_t num_clients = 0;

72
73
74
  /*!
   * \brief Current barrier count
   */
75
  std::unordered_map<int32_t, int32_t> barrier_count;
76

77
78
79
80
81
  /*!
   * \brief Total number of server per machine.
   */
  int32_t num_servers_per_machine = 0;

82
83
84
  /*!
   * \brief Sender communicator.
   */
85
  std::shared_ptr<RPCSender> sender;
86
87
88
89

  /*!
   * \brief Receiver communicator.
   */
90
  std::shared_ptr<RPCReceiver> receiver;
91
92
93
94
95

  /*!
   * \brief Tensorpipe global context
   */
  std::shared_ptr<tensorpipe::Context> ctx;
96
97
98
99
100
101
102
103
104
105
106
107

  /*!
   * \brief Server state data.
   *
   * If the process is a server, this stores necessary
   * server-side data. Otherwise, the process is a client and it stores a cache
   * of the server co-located with the client (if available). When the client
   * invokes a RPC to the co-located server, it can thus perform computation
   * locally without an actual remote call.
   */
  std::shared_ptr<ServerState> server_state;

108
109
110
111
112
113
114
  /*!
   * \brief Cuurent group ID
   */
  int32_t group_id = -1;
  int32_t curr_client_id = -1;
  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;

115
116
117
118
  /*! \brief Get the RPC context singleton */
  static RPCContext* getInstance() {
    static RPCContext ctx;
    return &ctx;
119
  }
120
121
122

  /*! \brief Reset the RPC context */
  static void Reset() {
123
    auto* t = getInstance();
124
125
126
    t->rank = -1;
    t->machine_id = -1;
    t->num_machines = 0;
127
128
    t->msg_seq = 0;
    t->num_servers = 0;
129
    t->num_clients = 0;
130
    t->barrier_count.clear();
131
    t->num_servers_per_machine = 0;
132
133
134
    t->sender.reset();
    t->receiver.reset();
    t->ctx.reset();
135
    t->server_state.reset();
136
137
138
139
140
141
    t->group_id = -1;
    t->curr_client_id = -1;
    t->clients_.clear();
  }

  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
142
    auto&& m = clients_[group_id];
143
144
145
146
147
148
149
150
151
152
153
    if (m.find(client_id) != m.end()) {
      return -1;
    }
    m[client_id] = ++curr_client_id;
    return curr_client_id;
  }

  int32_t GetClient(int32_t client_id, int32_t group_id) const {
    if (clients_.find(group_id) == clients_.end()) {
      return -1;
    }
154
    const auto& m = clients_.at(group_id);
155
156
157
158
    if (m.find(client_id) == m.end()) {
      return -1;
    }
    return m.at(client_id);
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
  }
};

/*!
 * \brief Send out one RPC message.
 *
 * The operation is non-blocking -- it does not guarantee the payloads have
 * reached the target or even have left the sender process. However,
 * all the payloads (i.e., data and arrays) can be safely freed after this
 * function returns.
 *
 * The data buffer in the request will be copied to an internal buffer for the
 * actual transmission, while tensor payloads are not copied (a.k.a.
 * zero-copy). The underlying sending threads will hold references to the
 * tensors until the contents have been transmitted.
 *
 * \param msg RPC message to send
 * \return status flag
 */
RPCStatus SendRPCMessage(const RPCMessage& msg);

/*!
 * \brief Receive one RPC message.
 *
 * The operation is blocking -- it returns once a message has been received.
 *
 * NOTE(review): what is returned when a non-zero timeout expires is not
 * visible in this header -- confirm against the implementation.
 *
 * \param msg Output location that receives the message
 * \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
 * \return status flag
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_RPC_H_