rpc.h 3.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.h
 * \brief Common headers for remote process call (RPC).
 */
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_

#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>
#include <cstdint>
#include <memory>
15
#include <deque>
16
17
#include <vector>
#include <string>
18
19
20
21
#include <mutex>

#include "./rpc_msg.h"
#include "./tensorpipe/tp_communicator.h"
22
23
24
25
26
27
#include "./network/common.h"
#include "./server_state.h"

namespace dgl {
namespace rpc {

28
29
struct RPCContext;

30
31
32
33
34
35
36
37
// Communicator handler type
typedef void* CommunicatorHandle;

/*! \brief Context information for RPC communication */
struct RPCContext {
  /*!
   * \brief Rank of this process.
   *
38
39
   * If the process is a client, this is equal to client ID. Otherwise, the
   * process is a server and this is equal to server ID.
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
   */
  int32_t rank = -1;

  /*!
   * \brief Cuurent machine ID
   */
  int32_t machine_id = -1;

  /*!
   * \brief Total number of machines.
   */
  int32_t num_machines = 0;

  /*!
   * \brief Message sequence number.
   */
56
  std::atomic<int64_t> msg_seq{0};
57
58
59
60
61
62

  /*!
   * \brief Total number of server.
   */
  int32_t num_servers = 0;

63
64
65
66
67
  /*!
   * \brief Total number of client.
   */
  int32_t num_clients = 0;

68
69
70
71
72
  /*!
   * \brief Current barrier count
   */
  int32_t barrier_count = 0;

73
74
75
76
77
  /*!
   * \brief Total number of server per machine.
   */
  int32_t num_servers_per_machine = 0;

78
79
80
  /*!
   * \brief Sender communicator.
   */
81
  std::shared_ptr<TPSender> sender;
82
83
84
85

  /*!
   * \brief Receiver communicator.
   */
86
87
88
89
90
91
  std::shared_ptr<TPReceiver> receiver;

  /*!
   * \brief Tensorpipe global context
   */
  std::shared_ptr<tensorpipe::Context> ctx;
92
93
94
95
96
97
98
99
100
101
102
103

  /*!
   * \brief Server state data.
   *
   * If the process is a server, this stores necessary
   * server-side data. Otherwise, the process is a client and it stores a cache
   * of the server co-located with the client (if available). When the client
   * invokes a RPC to the co-located server, it can thus perform computation
   * locally without an actual remote call.
   */
  std::shared_ptr<ServerState> server_state;

104
105
106
107
  /*! \brief Get the RPC context singleton */
  static RPCContext* getInstance() {
    static RPCContext ctx;
    return &ctx;
108
  }
109
110
111

  /*! \brief Reset the RPC context */
  static void Reset() {
112
    auto* t = getInstance();
113
114
115
116
117
118
    t->rank = -1;
    t->machine_id = -1;
    t->num_machines = 0;
    t->num_clients = 0;
    t->barrier_count = 0;
    t->num_servers_per_machine = 0;
119
120
121
    t->sender.reset();
    t->receiver.reset();
    t->ctx.reset();
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
  }
};

/*! \brief RPC status flag */
enum RPCStatus {
  kRPCSuccess = 0,
  kRPCTimeOut,
};

/*!
 * \brief Send out one RPC message.
 *
 * The operation is non-blocking -- it does not guarantee the payloads have
 * reached the target or even have left the sender process. However,
 * all the payloads (i.e., data and arrays) can be safely freed after this
 * function returns.
 *
 * The data buffer in the requst will be copied to internal buffer for actual
 * transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).
 * The underlying sending threads will hold references to the tensors until
 * the contents have been transmitted.
 *
 * \param msg RPC message to send
 * \return status flag
 */
RPCStatus SendRPCMessage(const RPCMessage& msg);

/*!
 * \brief Receive one RPC message.
 *
 * The operation is blocking -- it returns when it receives any message
 *
 * \param msg The received message
 * \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
 * \return status flag
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);

}  // namespace rpc
}  // namespace dgl

namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
}  // namespace dmlc

#endif  // DGL_RPC_RPC_H_