/**
 *  Copyright (c) 2020 by Contributors
 * @file rpc/rpc.h
 * @brief Common headers for remote procedure call (RPC).
 */
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_

#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "./network/common.h"
#include "./rpc_msg.h"
#include "./server_state.h"
#include "network/socket_communicator.h"

namespace dgl {
namespace rpc {

/**
 * @brief Type-erased handle to a communicator object.
 */
using CommunicatorHandle = void*;

35
/** @brief Context information for RPC communication */
36
struct RPCContext {
37
  /**
38
   * @brief Rank of this process.
39
   *
40
41
   * If the process is a client, this is equal to client ID. Otherwise, the
   * process is a server and this is equal to server ID.
42
43
44
   */
  int32_t rank = -1;

45
  /**
46
   * @brief Cuurent machine ID
47
48
49
   */
  int32_t machine_id = -1;

50
  /**
51
   * @brief Total number of machines.
52
53
54
   */
  int32_t num_machines = 0;

55
  /**
56
   * @brief Message sequence number.
57
   */
58
  std::atomic<int64_t> msg_seq{0};
59

60
  /**
61
   * @brief Total number of server.
62
63
64
   */
  int32_t num_servers = 0;

65
  /**
66
   * @brief Total number of client.
67
68
69
   */
  int32_t num_clients = 0;

70
  /**
71
   * @brief Current barrier count
72
   */
73
  std::unordered_map<int32_t, int32_t> barrier_count;
74

75
  /**
76
   * @brief Total number of server per machine.
77
78
79
   */
  int32_t num_servers_per_machine = 0;

80
  /**
81
   * @brief Sender communicator.
82
   */
83
  std::shared_ptr<network::SocketSender> sender;
84

85
  /**
86
   * @brief Receiver communicator.
87
   */
88
  std::shared_ptr<network::SocketReceiver> receiver;
89

90
  /**
91
   * @brief Server state data.
92
93
94
95
96
97
98
99
100
   *
   * If the process is a server, this stores necessary
   * server-side data. Otherwise, the process is a client and it stores a cache
   * of the server co-located with the client (if available). When the client
   * invokes a RPC to the co-located server, it can thus perform computation
   * locally without an actual remote call.
   */
  std::shared_ptr<ServerState> server_state;

101
  /**
102
   * @brief Cuurent group ID
103
104
105
106
107
   */
  int32_t group_id = -1;
  int32_t curr_client_id = -1;
  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;

108
  /** @brief Get the RPC context singleton */
109
110
111
  static RPCContext* getInstance() {
    static RPCContext ctx;
    return &ctx;
112
  }
113

114
  /** @brief Reset the RPC context */
115
  static void Reset() {
116
    auto* t = getInstance();
117
118
119
    t->rank = -1;
    t->machine_id = -1;
    t->num_machines = 0;
120
121
    t->msg_seq = 0;
    t->num_servers = 0;
122
    t->num_clients = 0;
123
    t->barrier_count.clear();
124
    t->num_servers_per_machine = 0;
125
126
    t->sender.reset();
    t->receiver.reset();
127
    t->server_state.reset();
128
129
130
131
132
133
    t->group_id = -1;
    t->curr_client_id = -1;
    t->clients_.clear();
  }

  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
134
    auto&& m = clients_[group_id];
135
136
137
138
139
140
141
142
143
144
145
    if (m.find(client_id) != m.end()) {
      return -1;
    }
    m[client_id] = ++curr_client_id;
    return curr_client_id;
  }

  int32_t GetClient(int32_t client_id, int32_t group_id) const {
    if (clients_.find(group_id) == clients_.end()) {
      return -1;
    }
146
    const auto& m = clients_.at(group_id);
147
148
149
150
    if (m.find(client_id) == m.end()) {
      return -1;
    }
    return m.at(client_id);
151
152
153
  }
};

154
/**
155
 * @brief Send out one RPC message.
156
157
158
159
160
161
162
163
164
165
166
 *
 * The operation is non-blocking -- it does not guarantee the payloads have
 * reached the target or even have left the sender process. However,
 * all the payloads (i.e., data and arrays) can be safely freed after this
 * function returns.
 *
 * The data buffer in the requst will be copied to internal buffer for actual
 * transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).
 * The underlying sending threads will hold references to the tensors until
 * the contents have been transmitted.
 *
167
168
 * @param msg RPC message to send
 * @return status flag
169
170
171
 */
RPCStatus SendRPCMessage(const RPCMessage& msg);

172
/**
173
 * @brief Receive one RPC message.
174
175
176
 *
 * The operation is blocking -- it returns when it receives any message
 *
177
178
179
 * @param msg The received message
 * @param timeout The timeout value in milliseconds. If zero, wait indefinitely.
 * @return status flag
180
181
182
183
184
185
186
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_RPC_H_