// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file rpc/rpc.h
 * @brief Common headers for remote procedure call (RPC).
 */
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_

#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "network/common.h"
#include "rpc_msg.h"
#include "server_state.h"
#include "network/socket_communicator.h"

namespace dgl {
namespace rpc {

31
32
// Forward declaration of the RPC context; the full definition appears later
// in this header.
struct RPCContext;

33
34
35
/** @brief Untyped (void*) handle used to refer to a communicator. */
using CommunicatorHandle = void*;

36
/** @brief Context information for RPC communication */
37
struct RPCContext {
38
  /**
39
   * @brief Rank of this process.
40
   *
41
42
   * If the process is a client, this is equal to client ID. Otherwise, the
   * process is a server and this is equal to server ID.
43
44
45
   */
  int32_t rank = -1;

46
  /**
47
   * @brief Cuurent machine ID
48
49
50
   */
  int32_t machine_id = -1;

51
  /**
52
   * @brief Total number of machines.
53
54
55
   */
  int32_t num_machines = 0;

56
  /**
57
   * @brief Message sequence number.
58
   */
59
  std::atomic<int64_t> msg_seq{0};
60

61
  /**
62
   * @brief Total number of server.
63
64
65
   */
  int32_t num_servers = 0;

66
  /**
67
   * @brief Total number of client.
68
69
70
   */
  int32_t num_clients = 0;

71
  /**
72
   * @brief Current barrier count
73
   */
74
  std::unordered_map<int32_t, int32_t> barrier_count;
75

76
  /**
77
   * @brief Total number of server per machine.
78
79
80
   */
  int32_t num_servers_per_machine = 0;

81
  /**
82
   * @brief Sender communicator.
83
   */
84
  std::shared_ptr<network::SocketSender> sender;
85

86
  /**
87
   * @brief Receiver communicator.
88
   */
89
  std::shared_ptr<network::SocketReceiver> receiver;
90

91
  /**
92
   * @brief Server state data.
93
94
95
96
97
98
99
100
101
   *
   * If the process is a server, this stores necessary
   * server-side data. Otherwise, the process is a client and it stores a cache
   * of the server co-located with the client (if available). When the client
   * invokes a RPC to the co-located server, it can thus perform computation
   * locally without an actual remote call.
   */
  std::shared_ptr<ServerState> server_state;

102
  /**
103
   * @brief Cuurent group ID
104
105
106
107
108
   */
  int32_t group_id = -1;
  int32_t curr_client_id = -1;
  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;

109
  /** @brief Get the RPC context singleton */
110
111
112
  static RPCContext* getInstance() {
    static RPCContext ctx;
    return &ctx;
113
  }
114

115
  /** @brief Reset the RPC context */
116
  static void Reset() {
117
    auto* t = getInstance();
118
119
120
    t->rank = -1;
    t->machine_id = -1;
    t->num_machines = 0;
121
122
    t->msg_seq = 0;
    t->num_servers = 0;
123
    t->num_clients = 0;
124
    t->barrier_count.clear();
125
    t->num_servers_per_machine = 0;
126
127
    t->sender.reset();
    t->receiver.reset();
128
    t->server_state.reset();
129
130
131
132
133
134
    t->group_id = -1;
    t->curr_client_id = -1;
    t->clients_.clear();
  }

  int32_t RegisterClient(int32_t client_id, int32_t group_id) {
135
    auto&& m = clients_[group_id];
136
137
138
139
140
141
142
143
144
145
146
    if (m.find(client_id) != m.end()) {
      return -1;
    }
    m[client_id] = ++curr_client_id;
    return curr_client_id;
  }

  int32_t GetClient(int32_t client_id, int32_t group_id) const {
    if (clients_.find(group_id) == clients_.end()) {
      return -1;
    }
147
    const auto& m = clients_.at(group_id);
148
149
150
151
    if (m.find(client_id) == m.end()) {
      return -1;
    }
    return m.at(client_id);
152
153
154
  }
};

155
/**
 * @brief Send out one RPC message.
 *
 * The operation is non-blocking -- it does not guarantee the payloads have
 * reached the target or even have left the sender process. However,
 * all the payloads (i.e., data and arrays) can be safely freed after this
 * function returns.
 *
 * The data buffer in the request will be copied to internal buffer for actual
 * transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).
 * The underlying sending threads will hold references to the tensors until
 * the contents have been transmitted.
 *
 * @param msg RPC message to send
 * @return status flag
 */
RPCStatus SendRPCMessage(const RPCMessage& msg);

173
/**
 * @brief Receive one RPC message.
 *
 * The operation is blocking -- it returns when it receives any message.
 *
 * @param msg The received message
 * @param timeout The timeout value in milliseconds. If zero, wait indefinitely.
 * @return status flag
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);

}  // namespace rpc
}  // namespace dgl

#endif  // DGL_RPC_RPC_H_