/**
 *  Copyright (c) 2020 by Contributors
 * @file rpc/rpc.cc
 * @brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
#include <unistd.h>

#include <csignal>
#include <future>

20
#include "../c_api_common.h"
21
#include "../runtime/resource_manager.h"
22
23
24
25
26
27
28

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

29
30
31
32
33
34
35
// Borrow from PyTorch
// NOTE(review): neither constant is referenced in the part of this file that
// was reviewed -- confirm they are still needed.

// Presumably the name of an environment variable that selects the socket
// network interface -- TODO confirm against the code that reads it.
const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
// Presumably a default (loopback) address -- TODO confirm against its users.
const char kDefaultUvAddress[] = "127.0.0.1";

/**
 * @brief Hand an RPC message to the sender held by the RPC context.
 * @param msg The message to send.
 * @param target_id Identifier forwarded to the sender's Send() call.
 * @return Always kRPCSuccess; the result of Send() is not inspected.
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto ctx = RPCContext::getInstance();
  auto& sender = ctx->sender;
  sender->Send(msg, target_id);
  return kRPCSuccess;
}

/**
 * @brief Receive one RPC message through the receiver held by the RPC context.
 *
 * @param msg Output message, filled in by the receiver's Recv() call.
 * @param timeout Timeout in milliseconds. A value of 0 means "retry until
 *        something other than a timeout is reported": the receiver is then
 *        polled repeatedly with a 5000 ms timeout per attempt.
 * @return The status of the last Recv() call. kRPCTimeOut can only be
 *         returned when @p timeout is non-zero.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds
  const bool retry_on_timeout = (timeout == 0);
  const int32_t real_timeout = retry_on_timeout ? retry_timeout : timeout;
  RPCStatus status;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      // Format the message on every timeout. Caching it in a function-local
      // static would freeze the text built from the first call's arguments
      // and report a stale timeout / retry suffix for later calls.
      DLOG(WARNING) << "Recv RPCMessage timeout in " << real_timeout << " ms."
                    << (retry_on_timeout ? " Retrying ..." : "");
    }
  } while (retry_on_timeout && status == kRPCTimeOut);
  return status;
}

//////////////////////////// C APIs ////////////////////////////
59
// Resets the RPC context by delegating to RPCContext::Reset(). Takes no
// arguments and leaves the return value untouched.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCContext::Reset();
    });
61
62

// Replaces the context's sender with a newly constructed SocketSender.
//   args[0]: message queue size, forwarded to the SocketSender constructor.
//   args[1]: maximum thread count, forwarded to the SocketSender constructor.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t queue_size = args[0];
      const int thread_limit = args[1];
      auto ctx = RPCContext::getInstance();
      ctx->sender.reset(new network::SocketSender(queue_size, thread_limit));
    });
69
70

// Replaces the context's receiver with a newly constructed SocketReceiver.
//   args[0]: message queue size, forwarded to the SocketReceiver constructor.
//   args[1]: maximum thread count, forwarded to the SocketReceiver constructor.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t queue_size = args[0];
      const int thread_limit = args[1];
      auto ctx = RPCContext::getInstance();
      ctx->receiver.reset(
          new network::SocketReceiver(queue_size, thread_limit));
    });
77
78

// Calls Finalize() on the sender currently held by the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      ctx->sender->Finalize();
    });
82
83

// Calls Finalize() on the receiver currently held by the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      ctx->receiver->Finalize();
    });
87

88
// Makes the context's receiver wait for senders on "tcp://<ip>:<port>".
//   args[0]: ip (string)      args[1]: port (int)
//   args[2]: number of senders, forwarded to Wait()
//   args[3]: blocking flag, forwarded to Wait()
// Logs a fatal error when Wait() reports failure.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::string host = args[0];
      int port_number = args[1];
      int sender_count = args[2];
      bool is_blocking = args[3];
      std::string endpoint =
          StringPrintf("tcp://%s:%d", host.c_str(), port_number);
      auto ctx = RPCContext::getInstance();
      const bool ready =
          ctx->receiver->Wait(endpoint, sender_count, is_blocking);
      if (!ready) {
        LOG(FATAL) << "Wait sender socket failed.";
      }
    });
101

102
// Asks the context's sender to connect to the receiver at "tcp://<ip>:<port>".
//   args[0]: ip (string)   args[1]: port (int)   args[2]: receiver id (int)
// Returns whatever ConnectReceiver() returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::string host = args[0];
      int port_number = args[1];
      int receiver_id = args[2];
      std::string endpoint =
          StringPrintf("tcp://%s:%d", host.c_str(), port_number);
      auto ctx = RPCContext::getInstance();
      *rv = ctx->sender->ConnectReceiver(endpoint, receiver_id);
    });
111

112
// Forwards to the sender's ConnectReceiverFinalize().
//   args[0]: maximum number of tries, passed through unchanged.
// Returns whatever ConnectReceiverFinalize() returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int retry_limit = args[0];
      auto ctx = RPCContext::getInstance();
      *rv = ctx->sender->ConnectReceiverFinalize(retry_limit);
    });
118

119
// Stores args[0] as the rank in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_rank = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->rank = new_rank;
    });
124
125

// Returns the rank stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->rank;
    });
129
130

// Stores args[0] as the number of servers in the RPC context and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t server_count = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->num_servers = server_count;
      *rv = ctx->num_servers;
    });
135
136

// Returns the number of servers stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_servers;
    });
140

141
// Stores args[0] as the number of clients in the RPC context and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_count = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->num_clients = client_count;
      *rv = ctx->num_clients;
    });
146
147

// Returns the number of clients stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_clients;
    });
151

152
// Stores args[0] as the number of servers per machine in the RPC context and
// returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t servers_per_machine = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->num_servers_per_machine = servers_per_machine;
      *rv = ctx->num_servers_per_machine;
    });
157
158

// Returns the number of servers per machine stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_servers_per_machine;
    });
162

163
// Post-increments the context's message sequence number: the value from
// before the increment is returned.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->msg_seq++;
    });
167
168

// Returns the message sequence number stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->msg_seq;
    });
172
173

// Overwrites the context's message sequence number with args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t new_seq = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->msg_seq = new_seq;
    });
178

179
// Returns the barrier count recorded for group args[0]. A group with no entry
// yet gets one initialised to 0, so the map is modified on first access.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t group_id = args[0];
      auto ctx = RPCContext::getInstance();
      auto& counts = ctx->barrier_count;
      // emplace() inserts only when the key is absent; in both cases the
      // returned iterator points at the entry for group_id.
      *rv = counts.emplace(group_id, 0x0).first->second;
    });
188
189

// Sets the barrier count of a group.
//   args[0]: the count to store.
//   args[1]: the group id whose entry is written (created if missing).
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_count = args[0];
      const int32_t group = args[1];
      auto ctx = RPCContext::getInstance();
      ctx->barrier_count[group] = new_count;
    });
195

196
// Returns the machine id stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->machine_id;
    });
200
201

// Stores args[0] as the machine id in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_machine_id = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->machine_id = new_machine_id;
    });
206
207

// Returns the number of machines stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_machines;
    });
211
212

// Stores args[0] as the number of machines in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t machine_total = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->num_machines = machine_total;
    });
217
218

// Sends the message referenced by args[0] to the target id in args[1] via
// SendRPCMessage() and returns its status.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCMessageRef message = args[0];
      const int32_t destination = args[1];
      const RPCStatus status = SendRPCMessage(*(message.sptr()), destination);
      *rv = status;
    });
224
225

// Receives one message via RecvRPCMessage() into the message object
// referenced by args[1] and returns the resulting status.
//   args[0]: timeout, passed through to RecvRPCMessage().
//   args[1]: the message reference that receives the data.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      int32_t wait_ms = args[0];
      RPCMessageRef message = args[1];
      // `message` keeps the object alive, so the raw pointer stays valid for
      // the duration of the call.
      auto* destination = message.sptr().get();
      *rv = RecvRPCMessage(destination, wait_ms);
    });
231
232
233
234

//////////////////////////// RPCMessage ////////////////////////////

// Returns a default-constructed RPCMessage.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::shared_ptr<RPCMessage> empty_message(new RPCMessage);
      *rv = empty_message;
    });
239

240
// Builds an RPCMessage from positional arguments and returns it:
//   args[0] service_id   args[1] msg_seq   args[2] client_id
//   args[3] server_id    args[4] data (string)
//   args[5] list of NDArray tensors        args[6] group_id
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::shared_ptr<RPCMessage> message(new RPCMessage);
      message->service_id = args[0];
      message->msg_seq = args[1];
      message->client_id = args[2];
      message->server_id = args[3];
      // Go through a named std::string first: assigning args[4] to the member
      // directly raises errors.
      const std::string payload = args[4];
      message->data = payload;
      message->tensors = ListValueToVector<NDArray>(args[5]);
      message->group_id = args[6];
      *rv = message;
    });
254
255

// Returns the service_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->service_id;
    });
260
261

// Returns the msg_seq field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->msg_seq;
    });
266
267

// Returns the client_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->client_id;
    });
272
273

// Returns the server_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->server_id;
    });
278
279

// Returns the data field of the message referenced by args[0] as a
// DGLByteArray built from the string's buffer pointer and length.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      const auto& payload = message->data;
      DGLByteArray bytes{payload.c_str(), payload.size()};
      *rv = bytes;
    });
285
286

// Returns the tensors of the message referenced by args[0] as a List<Value>,
// one wrapped entry per tensor, in the same order.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      List<Value> wrapped;
      for (auto& tensor : message->tensors) {
        wrapped.push_back(Value(MakeValue(tensor)));
      }
      *rv = wrapped;
    });
295

296
#if defined(__linux__)
297
/**
298
299
 * @brief The signal handler.
 * @param s signal
300
 */
301
void SigHandler(int s) {
302
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
303
  CleanupResources();
304
305
306
  exit(1);
}

307
// Installs SigHandler as the disposition for SIGINT and SIGTERM, with an
// empty signal mask and no flags. The results of sigaction() are not checked.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      struct sigaction action;
      action.sa_handler = SigHandler;
      sigemptyset(&action.sa_mask);
      action.sa_flags = 0;
      // Ctrl+C (SIGINT) first, then SIGTERM.
      const int handled_signals[] = {SIGINT, SIGTERM};
      for (const int signal_number : handled_signals) {
        sigaction(signal_number, &action, nullptr);
      }
    });
317
318
#endif

319
320
321
//////////////////////////// ServerState ////////////////////////////

// Returns the ServerState held by the RPC context, creating it on first use.
//
// The state is accessed through a reference to the context's member. Reading
// it into a copy before the lazy creation would hand the caller the stale
// null pointer on the first call, even though a new state had been stored.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      auto& state = ctx->server_state;
      if (state.get() == nullptr) {
        state = std::make_shared<ServerState>();
      }
      *rv = state;
    });
330

331
332
333
//////////////////////////// KVStore ////////////////////////////

// Filters args[0] (ids, read as int64_t) down to the entries whose matching
// entry in args[1] (partition ids, read as int64_t) equals args[2] (the local
// machine id). Returns the surviving ids, in their original order, as an id
// array. The element count is taken from the byte size of args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray ids = args[0];
      NDArray partitions = args[1];
      int local_machine = args[2];
      const int64_t* id_values = static_cast<int64_t*>(ids->data);
      const int64_t* partition_values =
          static_cast<int64_t*>(partitions->data);
      const int64_t count = ids.GetSize() / sizeof(int64_t);
      std::vector<int64_t> kept;
      for (int64_t pos = 0; pos != count; ++pos) {
        if (partition_values[pos] != local_machine) {
          continue;
        }
        kept.push_back(id_values[pos]);
      }
      NDArray result = dgl::aten::VecToIdArray<int64_t>(kept);
      *rv = result;
    });
350
351

/**
 * Gathers the rows named by `ID` into one CPU tensor: rows whose partition is
 * the local machine are copied from `local_data`, all other rows are requested
 * from remote servers through RPC and copied from the replies.
 *
 * Positional arguments:
 *   0  name              string; converted but not otherwise used here
 *   1  local_machine_id  partition value that marks an entry as local
 *   2  machine_count     number of partitions; sizes the per-partition buckets
 *   3  group_count       maps partition i to server ids drawn by
 *                        RandInt(i * group_count, (i + 1) * group_count)
 *   4  client_id, 5 service_id, 6 msg_seq, 7 pickle_data
 *                        copied into every outgoing RPCMessage
 *   8  ID                ids to pull, read as dgl_id_t
 *   9  part_id           partition of each entry of ID, read as dgl_id_t
 *   10 local_id          row index into local_data for each local entry,
 *                        consumed in order of appearance
 *   11 local_data        tensor holding the local rows
 *
 * Returns a tensor with local_data's dtype and shape, except that dimension 0
 * equals the number of entries in ID; row i holds the data for ID[i].
 *
 * Every index used for a memcpy is checked first: the local row cursor, the
 * partition derived from a reply, and the byte size of the reply tensor.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Input
      std::string name = args[0];  // unused; kept so args[0] is still converted
      int local_machine_id = args[1];
      int machine_count = args[2];
      int group_count = args[3];
      int client_id = args[4];
      int service_id = args[5];
      int64_t msg_seq = args[6];
      std::string pickle_data = args[7];
      NDArray ID = args[8];
      NDArray part_id = args[9];
      NDArray local_id = args[10];
      NDArray local_data = args[11];
      // Data
      dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
      const dgl_id_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
      dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
      dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
      dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
      char* local_data_char = static_cast<char*>(local_data->data);
      std::vector<dgl_id_t> local_ids;
      std::vector<dgl_id_t> local_ids_original;
      std::vector<int64_t> local_data_shape;
      std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
      std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
      // Get row size (in bytes). size_t keeps the byte arithmetic below from
      // overflowing a 32-bit int for wide rows.
      size_t row_size = 1;
      for (int i = 0; i < local_data->ndim; ++i) {
        local_data_shape.push_back(local_data->shape[i]);
        if (i != 0) {
          row_size *= static_cast<size_t>(local_data->shape[i]);
        }
      }
      row_size *= (local_data->dtype.bits / 8);
      const size_t data_size = local_data.GetSize();
      CHECK_GT(local_data_shape.size(), 0);
      const size_t local_row_count = static_cast<size_t>(local_data_shape[0]);
      CHECK_EQ(row_size * local_row_count, data_size);
      // Get local id (used in local machine) and
      // remote id (send to remote machine)
      dgl_id_t idx = 0;
      for (dgl_id_t i = 0; i < ID_size; ++i) {
        dgl_id_t p_id = part_id_data[i];
        if (static_cast<int>(p_id) == local_machine_id) {
          CHECK_LT(idx, local_id_size)
              << "local_id has fewer entries than there are local IDs.";
          dgl_id_t l_id = local_id_data[idx++];
          // l_id is unsigned, so only the upper bound needs checking.
          CHECK_LT(l_id, local_row_count);
          local_ids.push_back(l_id);
          local_ids_original.push_back(i);
        } else {
          CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
          dgl_id_t id = ID_data[i];
          remote_ids[p_id].push_back(id);
          remote_ids_original[p_id].push_back(i);
        }
      }
      // Send remote id: one request per partition that owns at least one row.
      int msg_count = 0;
      for (size_t i = 0; i < remote_ids.size(); ++i) {
        if (remote_ids[i].size() != 0) {
          RPCMessage msg;
          msg.service_id = service_id;
          msg.msg_seq = msg_seq;
          msg.client_id = client_id;
          int lower = i * group_count;
          int upper = (i + 1) * group_count;
          msg.server_id =
              dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
          msg.data = pickle_data;
          NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
          msg.tensors.push_back(tensor);
          msg.group_id = RPCContext::getInstance()->group_id;
          SendRPCMessage(msg, msg.server_id);
          msg_count++;
        }
      }
      local_data_shape[0] = ID_size;
      NDArray res_tensor = NDArray::Empty(
          local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0});
      char* return_data = static_cast<char*>(res_tensor->data);
      // Copy local data
      parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
        for (auto i = b; i < e; ++i) {
          CHECK_GE(
              ID_size * row_size, local_ids_original[i] * row_size + row_size);
          CHECK_GE(data_size, local_ids[i] * row_size + row_size);
          memcpy(
              return_data + local_ids_original[i] * row_size,
              local_data_char + local_ids[i] * row_size, row_size);
        }
      });
      // Recv remote message
      int recv_cnt = 0;
      while (recv_cnt < msg_count) {
        RPCMessage msg;
        auto status = RecvRPCMessage(&msg, 0);
        CHECK_EQ(status, kRPCSuccess);
        ++recv_cnt;
        // The reply comes off the network: validate everything that is used
        // as an index or a copy length before touching memory.
        CHECK_GT(group_count, 0) << "group_count must be positive.";
        const int reply_part = msg.server_id / group_count;
        CHECK_GE(reply_part, 0) << "Invalid server ID in reply.";
        CHECK_LT(reply_part, machine_count) << "Invalid server ID in reply.";
        CHECK(!msg.tensors.empty()) << "Reply carries no tensor.";
        const std::vector<dgl_id_t>& dst_rows = remote_ids_original[reply_part];
        const size_t id_size = remote_ids[reply_part].size();
        CHECK_GE(msg.tensors[0].GetSize(), id_size * row_size)
            << "Reply tensor is smaller than the requested rows.";
        char* data_char = static_cast<char*>(msg.tensors[0]->data);
        for (size_t n = 0; n < id_size; ++n) {
          memcpy(
              return_data + dst_rows[n] * row_size, data_char + n * row_size,
              row_size);
        }
      }
      *rv = res_tensor;
    });
461

462
// Returns the group id stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->group_id;
    });
466
467

// Stores args[0] as the group id in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_group_id = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->group_id = new_group_id;
    });
472
473

// Returns the group_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->group_id;
    });
478
479

// Forwards to RPCContext::RegisterClient(client_id, group_id) and returns its
// result.
//   args[0]: client id     args[1]: group id
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client = args[0];
      const int32_t group = args[1];
      auto ctx = RPCContext::getInstance();
      *rv = ctx->RegisterClient(client, group);
    });
485
486

// Forwards to RPCContext::GetClient(client_id, group_id) and returns its
// result.
//   args[0]: client id     args[1]: group id
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client = args[0];
      const int32_t group = args[1];
      auto ctx = RPCContext::getInstance();
      *rv = ctx->GetClient(client, group);
    });
492

493
494
}  // namespace rpc
}  // namespace dgl
495
496

#endif