rpc.h 4.55 KB
Newer Older
1
/**
2
 *  Copyright (c) 2020 by Contributors
3
4
 * @file rpc/rpc.h
 * @brief Common headers for remote procedure call (RPC).
5
6
7
8
9
 */
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_

#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "net_type.h"
#include "network/socket_communicator.h"
#include "tensorpipe/tp_communicator.h"
28
29
30
31

namespace dgl {
namespace rpc {

32
33
// Forward declaration of the RPC context; the full definition follows below.
struct RPCContext;

34
35
36
/** @brief Untyped pointer alias used as a communicator handle. */
using CommunicatorHandle = void*;

37
/** @brief Context information for RPC communication */
38
struct RPCContext {
39
  /**
40
   * @brief Rank of this process.
41
   *
42
43
   * If the process is a client, this is equal to client ID. Otherwise, the
   * process is a server and this is equal to server ID.
44
45
46
   */
  int32_t rank = -1;

47
  /**
48
   * @brief Cuurent machine ID
49
50
51
   */
  int32_t machine_id = -1;

52
  /**
53
   * @brief Total number of machines.
54
55
56
   */
  int32_t num_machines = 0;

57
  /**
58
   * @brief Message sequence number.
59
   */
60
  std::atomic<int64_t> msg_seq{0};
61

62
  /**
63
   * @brief Total number of server.
64
65
66
   */
  int32_t num_servers = 0;

67
  /**
68
   * @brief Total number of client.
69
70
71
   */
  int32_t num_clients = 0;

72
  /**
73
   * @brief Current barrier count
74
   */
75
  std::unordered_map<int32_t, int32_t> barrier_count;
76

77
  /**
78
   * @brief Total number of server per machine.
79
80
81
   */
  int32_t num_servers_per_machine = 0;

82
  /**
83
   * @brief Sender communicator.
84
   */
85
  std::shared_ptr<RPCSender> sender;
86

87
  /**
88
   * @brief Receiver communicator.
89
   */
90
  std::shared_ptr<RPCReceiver> receiver;
91

92
  /**
93
   * @brief Tensorpipe global context
94
95
   */
  std::shared_ptr<tensorpipe::Context> ctx;
96

97
  /**
98
   * @brief Server state data.
99
100
101
102
103
104
105
106
107
   *
   * If the process is a server, this stores necessary
   * server-side data. Otherwise, the process is a client and it stores a cache
   * of the server co-located with the client (if available). When the client
   * invokes a RPC to the co-located server, it can thus perform computation
   * locally without an actual remote call.
   */
  std::shared_ptr<ServerState> server_state;

108
  /**
109
   * @brief Cuurent group ID
110
111
112
113
114
   */
  int32_t group_id = -1;
  int32_t curr_client_id = -1;
  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;

115
  /** @brief Get the RPC context singleton */
116
117
118
  static RPCContext* getInstance() {
    static RPCContext ctx;
    return &ctx;
119
  }
120

121
  /** @brief Reset the RPC context */
122
  static void Reset() {
123
    auto* t = getInstance();
124
125
126
    t->rank = -1;
    t->machine_id = -1;
    t->num_machines = 0;
127
128
    t->msg_seq = 0;
    t->num_servers = 0;
129
    t->num_clients = 0;
130
    t->barrier_count.clear();
131
    t->num_servers_per_machine = 0;
132
133
134
    t->sender.reset();
    t->receiver.reset();
    t->ctx.reset();
135
    t->server_state.reset();
136
137
138
139
140
141
    t->group_id = -1;
    t->curr_client_id = -1;
    t->clients_.clear();
  }

  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
142
    auto&& m = clients_[group_id];
143
144
145
146
147
148
149
150
151
152
153
    if (m.find(client_id) != m.end()) {
      return -1;
    }
    m[client_id] = ++curr_client_id;
    return curr_client_id;
  }

  int32_t GetClient(int32_t client_id, int32_t group_id) const {
    if (clients_.find(group_id) == clients_.end()) {
      return -1;
    }
154
    const auto& m = clients_.at(group_id);
155
156
157
158
    if (m.find(client_id) == m.end()) {
      return -1;
    }
    return m.at(client_id);
159
160
161
  }
};

162
/**
163
 * @brief Send out one RPC message.
164
165
166
167
168
169
170
171
172
173
174
 *
 * The operation is non-blocking -- it does not guarantee the payloads have
 * reached the target or even have left the sender process. However,
 * all the payloads (i.e., data and arrays) can be safely freed after this
 * function returns.
 *
 * The data buffer in the request will be copied to internal buffer for actual
 * transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).
 * The underlying sending threads will hold references to the tensors until
 * the contents have been transmitted.
 *
175
176
 * @param msg RPC message to send
 * @return status flag
177
178
179
 */
RPCStatus SendRPCMessage(const RPCMessage& msg);

180
/**
181
 * @brief Receive one RPC message.
182
183
184
 *
 * The operation is blocking -- it returns when it receives any message or,
 * when a non-zero timeout is given, once that timeout elapses.
 *
185
186
187
 * @param msg The received message
 * @param timeout The timeout value in milliseconds. If zero, wait indefinitely.
 * @return status flag
188
189
190
191
192
193
194
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_RPC_H_