rpc.cc 19.3 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
20
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

21
#include "../c_api_common.h"
22
#include "../runtime/resource_manager.h"
23
24
25
26
27
28
29

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
using namespace tensorpipe;

// Borrow from PyTorch

const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
const char kDefaultUvAddress[] = "127.0.0.1";

const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
60
    }
61
62
63
64
65
66
67
    return result;
  }();
  return uvAddress;
}

/*!
 * \brief Send an RPC message through the process-wide sender.
 * \param msg The message to send.
 * \param target_id ID of the receiver the message is addressed to.
 * \return kRPCSuccess unconditionally; any result of Send() is not inspected.
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto& sender = RPCContext::getInstance()->sender;
  sender->Send(msg, target_id);
  return kRPCSuccess;
}

/*!
 * \brief Receive an RPC message from the process-wide receiver.
 *
 * \param msg Output message, filled in by the receiver.
 * \param timeout Timeout in milliseconds. A value of 0 means "wait forever":
 *        the receiver is polled in slices of retry_timeout ms until it
 *        returns something other than kRPCTimeOut.
 * \return The status of the last Recv() call.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds
  const int32_t real_timeout = timeout == 0 ? retry_timeout : timeout;
  RPCStatus status;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      // Format the message from this call's arguments. Caching it in a
      // function-local static would freeze the timeout values of whichever
      // call timed out first and report them for every later call.
      DLOG(WARNING) << "Recv RPCMessage timeout in " << real_timeout << " ms."
                    << (timeout == 0 ? " Retrying ..." : "");
    }
  } while (timeout == 0 && status == kRPCTimeOut);
  return status;
}

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/*!
 * \brief Lazily create the process-wide tensorpipe context.
 *
 * Does nothing if RPCContext already owns a context. Otherwise a new context
 * is created, stored in RPCContext::ctx and configured with:
 *  - a libuv "tcp" transport (priority 0),
 *  - a "basic" channel (priority 0),
 *  - when DGL_SOCKET_NTHREADS is set, an "mpt" multiplexing channel
 *    (priority 20) built from DGL_SOCKET_NTHREADS libuv transport contexts,
 *    each listening on guessAddress().
 */
void InitGlobalTpContext() {
  if (RPCContext::getInstance()->ctx) {
    return;
  }
  RPCContext::getInstance()->ctx = std::make_shared<tensorpipe::Context>();
  auto context = RPCContext::getInstance()->ctx;
  auto transportContext = tensorpipe::transport::uv::create();
  context->registerTransport(0 /* priority */, "tcp", transportContext);
  // Register basic uv channel
  auto basicChannel = tensorpipe::channel::basic::create();
  context->registerChannel(0 /* low priority */, "basic", basicChannel);

  const char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
  if (numUvThreads_str == nullptr) {
    return;
  }
  const int numUvThreads = std::atoi(numUvThreads_str);
  CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set";
  // Register multiplex uv channel: one uv transport context and one listener
  // per thread, all bound to the same cached address.
  const std::string& address = guessAddress();
  std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
  std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
  contexts.reserve(numUvThreads);
  listeners.reserve(numUvThreads);
  for (int i = 0; i < numUvThreads; ++i) {
    // Deliberately not named `context`: that would shadow the tensorpipe
    // context the mpt channel is registered on below.
    auto uvContext = tensorpipe::transport::uv::create();
    listeners.push_back(uvContext->listen(address));
    contexts.push_back(std::move(uvContext));
  }
  auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                     std::move(listeners));
  context->registerChannel(20 /* high priority */, "mpt", mptChannel);
}

121
//////////////////////////// C APIs ////////////////////////////
122
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Reset the process-wide RPC context; delegates to RPCContext::Reset().
  RPCContext::Reset();
});
124
125

// Create the process-wide RPC sender.
//   args[0]: message queue size (only passed to the "socket" backend)
//   args[1]: communicator type, "tensorpipe" or "socket"
//   args[2]: max thread count (only passed to the "socket" backend)
// Any other communicator type is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t msg_queue_size = args[0];
  const std::string type = args[1];
  const int max_thread_count = args[2];
  auto& sender = RPCContext::getInstance()->sender;
  if (type == "socket") {
    sender.reset(new network::SocketSender(msg_queue_size, max_thread_count));
  } else if (type == "tensorpipe") {
    InitGlobalTpContext();
    sender.reset(new TPSender(RPCContext::getInstance()->ctx));
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
  }
  LOG(INFO) << "Sender with NetType~" << sender->NetType() << " is created.";
});

// Create the process-wide RPC receiver.
//   args[0]: message queue size (only passed to the "socket" backend)
//   args[1]: communicator type, "tensorpipe" or "socket"
//   args[2]: max thread count (only passed to the "socket" backend)
// Any other communicator type is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t msg_queue_size = args[0];
  const std::string type = args[1];
  const int max_thread_count = args[2];
  auto& receiver = RPCContext::getInstance()->receiver;
  if (type == "socket") {
    receiver.reset(
        new network::SocketReceiver(msg_queue_size, max_thread_count));
  } else if (type == "tensorpipe") {
    InitGlobalTpContext();
    receiver.reset(new TPReceiver(RPCContext::getInstance()->ctx));
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
  }
  LOG(INFO) << "Receiver with NetType~" << receiver->NetType()
            << " is created.";
});

// Finalize the process-wide sender. It must already have been created.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sender = RPCContext::getInstance()->sender;
  sender->Finalize();
});

// Finalize the process-wide receiver. It must already have been created.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& receiver = RPCContext::getInstance()->receiver;
  receiver->Finalize();
});

173
// Wait on tcp://<ip>:<port> for num_sender senders to connect.
//   args[0]: ip, args[1]: port, args[2]: number of senders, args[3]: blocking
// A failed wait is fatal.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const std::string ip = args[0];
  const int port = args[1];
  const int num_sender = args[2];
  const bool blocking = args[3];
  const std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  const bool waited =
      RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking);
  if (!waited) {
    LOG(FATAL) << "Wait sender socket failed.";
  }
});

186
// Ask the sender to connect to the receiver at tcp://<ip>:<port>.
//   args[0]: ip, args[1]: port, args[2]: receiver id
// Returns whatever Sender::ConnectReceiver returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const std::string ip = args[0];
  const int port = args[1];
  const int recv_id = args[2];
  const std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  auto& sender = RPCContext::getInstance()->sender;
  *rv = sender->ConnectReceiver(addr, recv_id);
});

// Complete pending receiver connections, allowing up to args[0] tries.
// Returns whatever Sender::ConnectReceiverFinalize returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int max_try_times = args[0];
  auto& sender = RPCContext::getInstance()->sender;
  *rv = sender->ConnectReceiverFinalize(max_try_times);
});

202
// Store this process's rank in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_rank = args[0];
  RPCContext::getInstance()->rank = new_rank;
});

// Return the rank stored in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->rank;
});

// Cluster-size counters kept in the global RPC context. Each setter stores
// args[0] and also returns the stored value.

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  RPCContext::getInstance()->num_servers = num_servers;
  *rv = RPCContext::getInstance()->num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_clients = args[0];
  RPCContext::getInstance()->num_clients = num_clients;
  *rv = RPCContext::getInstance()->num_clients;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_clients;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  RPCContext::getInstance()->num_servers_per_machine = num_servers;
  *rv = RPCContext::getInstance()->num_servers_per_machine;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_servers_per_machine;
});

246
// Message sequence counter in the global RPC context.

// Post-increment: returns the value before incrementing.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->msg_seq++;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t new_seq = args[0];
  RPCContext::getInstance()->msg_seq = new_seq;
});

262
// Return the barrier count for group args[0]. A group seen for the first
// time is inserted with a count of 0.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t group_id = args[0];
  auto& counts = RPCContext::getInstance()->barrier_count;
  auto it = counts.find(group_id);
  if (it == counts.end()) {
    it = counts.emplace(group_id, 0x0).first;
  }
  *rv = it->second;
});

// Set the barrier count of group args[1] to args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  const int32_t group_id = args[1];
  auto& counts = RPCContext::getInstance()->barrier_count;
  counts[group_id] = count;
});

279
// Machine identity / cluster size kept in the global RPC context.

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_machine_id = args[0];
  RPCContext::getInstance()->machine_id = new_machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_machines;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_num_machines = args[0];
  RPCContext::getInstance()->num_machines = new_num_machines;
});

// Send the message args[0] to target args[1]; returns the RPCStatus.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef msg = args[0];
  const int32_t target_id = args[1];
  const RPCMessage& payload = *msg.sptr();
  *rv = SendRPCMessage(payload, target_id);
});

// Receive into message args[1] with timeout args[0] (ms, 0 = wait forever);
// returns the RPCStatus.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int32_t timeout = args[0];
  RPCMessageRef msg = args[1];
  RPCMessage* out = msg.sptr().get();
  *rv = RecvRPCMessage(out, timeout);
});

//////////////////////////// RPCMessage ////////////////////////////

// Return a default-constructed RPCMessage.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto message = std::make_shared<RPCMessage>();
  *rv = message;
});

// Build an RPCMessage from positional arguments:
//   0 service_id, 1 msg_seq, 2 client_id, 3 server_id,
//   4 data (string), 5 tensors (list of NDArray), 6 group_id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto message = std::make_shared<RPCMessage>();
  message->service_id = args[0];
  message->msg_seq = args[1];
  message->client_id = args[2];
  message->server_id = args[3];
  // Go through a named std::string: assigning args[4] straight to the
  // member does not compile.
  const std::string payload = args[4];
  message->data = payload;
  message->tensors = ListValueToVector<NDArray>(args[5]);
  message->group_id = args[6];
  *rv = message;
});

// Field accessors for RPCMessage; args[0] is the message.

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->service_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->client_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->server_id;
});

// Returns the payload as a byte array that points into the message's own
// data string.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  const std::string& payload = message->data;
  DGLByteArray bytes{payload.c_str(), payload.size()};
  *rv = bytes;
});

// Return the tensors of message args[0] as a List<Value>, in order.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  List<Value> tensor_list;
  for (const auto& tensor : message->tensors) {
    tensor_list.push_back(Value(MakeValue(tensor)));
  }
  *rv = tensor_list;
});

379
380
#if defined(__linux__)
/*!
381
 * \brief The signal handler.
382
383
 * \param s signal
 */
384
void SigHandler(int s) {
385
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
386
  CleanupResources();
387
388
389
  exit(1);
}

390
// Install SigHandler for SIGINT (Ctrl+C) and SIGTERM.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Ctrl+C handler.
  // Value-initialize so that fields not set below (implementation-specific
  // members such as glibc's sa_restorer) are zero, not indeterminate.
  struct sigaction sigHandler = {};
  sigHandler.sa_handler = SigHandler;
  sigemptyset(&sigHandler.sa_mask);
  sigHandler.sa_flags = 0;
  // sigaction() reports failure via a non-zero return; don't ignore it.
  if (sigaction(SIGINT, &sigHandler, nullptr) != 0) {
    LOG(WARNING) << "Failed to install handler for SIGINT";
  }
  if (sigaction(SIGTERM, &sigHandler, nullptr) != 0) {
    LOG(WARNING) << "Failed to install handler for SIGTERM";
  }
});
#endif

402
403
404
//////////////////////////// ServerState ////////////////////////////

// Return the process-wide ServerState, creating an empty one on first use.
// Returns the state actually stored in RPCContext. Copying the pointer
// before creation would hand back a null state on the first call, even
// though a state had just been created.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& st = RPCContext::getInstance()->server_state;
  if (st.get() == nullptr) {
    st = std::make_shared<ServerState>();
  }
  *rv = st;
});

413
414
415
//////////////////////////// KVStore ////////////////////////////

// Select the IDs that belong to the local machine.
//   args[0]: ID       — IDs, read as int64 (assumed int64 tensor; confirm)
//   args[1]: part_id  — partition of each ID, read as int64, same positions
//   args[2]: local_machine_id
// Returns an id array holding ID[i] for every i with
// part_id[i] == local_machine_id, in input order.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  NDArray ID = args[0];
  NDArray part_id = args[1];
  int local_machine_id = args[2];
  // part_id is indexed with ID's positions; a shorter part_id would be read
  // out of bounds.
  CHECK_GE(part_id.GetSize(), ID.GetSize())
      << "part_id must have one entry per ID.";
  const int64_t* ID_data = static_cast<int64_t*>(ID->data);
  const int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
  const int64_t ID_size = ID.GetSize() / sizeof(int64_t);
  std::vector<int64_t> global_id;
  for (int64_t i = 0; i < ID_size; ++i) {
    if (part_id_data[i] == local_machine_id) {
      global_id.push_back(ID_data[i]);
    }
  }
  NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
  *rv = res_tensor;
});

// Fast path for KVStore pull.
//
// Splits the requested IDs by partition:
//  - IDs whose partition equals local_machine_id are served locally: the
//    next entry of local_id gives the row of local_data to copy;
//  - for every other partition p with at least one ID, one RPCMessage
//    carrying p's IDs is sent to a server chosen with
//    RandInt(p * group_count, (p + 1) * group_count). The replies are then
//    scattered back. The partition of a reply is recovered as
//    server_id / group_count.
// Returns a CPU tensor with local_data's dtype and shape
// (len(ID), *local_data.shape[1:]); row i of the result answers ID[i].
//
// Positional arguments:
//   0 name (not used here), 1 local_machine_id, 2 machine_count,
//   3 group_count, 4 client_id, 5 service_id, 6 msg_seq, 7 pickle_data,
//   8 ID, 9 part_id, 10 local_id, 11 local_data.
// NOTE(review): ID / part_id / local_id are read as dgl_id_t (64-bit)
// arrays; callers are assumed to pass int64 tensors — confirm.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Input
  const int local_machine_id = args[1];
  const int machine_count = args[2];
  const int group_count = args[3];
  const int client_id = args[4];
  const int service_id = args[5];
  const int64_t msg_seq = args[6];
  const std::string pickle_data = args[7];
  NDArray ID = args[8];
  NDArray part_id = args[9];
  NDArray local_id = args[10];
  NDArray local_data = args[11];
  // Data
  const dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
  const dgl_id_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
  // part_id is indexed with ID's positions; guard against reading past it.
  CHECK_GE(part_id.GetSize(), ID.GetSize())
      << "part_id must have one entry per requested ID.";
  const dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
  const dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
  const dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
  const char* local_data_char = static_cast<char*>(local_data->data);
  std::vector<dgl_id_t> local_ids;
  std::vector<dgl_id_t> local_ids_original;
  std::vector<int64_t> local_data_shape;
  std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
  std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
  // Get row size (in bytes). 64-bit so wide rows cannot overflow.
  int64_t row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= (local_data->dtype.bits / 8);
  const size_t data_size = local_data.GetSize();
  CHECK_GT(local_data_shape.size(), 0);
  CHECK_EQ(static_cast<size_t>(row_size * local_data_shape[0]), data_size);
  // Get local id (used in local machine) and
  // remote id (send to remote machine)
  dgl_id_t idx = 0;
  for (dgl_id_t i = 0; i < ID_size; ++i) {
    const dgl_id_t p_id = part_id_data[i];
    if (static_cast<int>(p_id) == local_machine_id) {
      CHECK_LT(idx, local_id_size)
          << "local_id has fewer entries than locally owned IDs.";
      const dgl_id_t l_id = local_id_data[idx++];
      // dgl_id_t is unsigned, so only the upper bound needs checking.
      CHECK_LT(l_id, static_cast<dgl_id_t>(local_data_shape[0]));
      local_ids.push_back(l_id);
      local_ids_original.push_back(i);
    } else {
      CHECK_LT(p_id, static_cast<dgl_id_t>(machine_count))
          << "Invalid partition ID.";
      remote_ids[p_id].push_back(ID_data[i]);
      remote_ids_original[p_id].push_back(i);
    }
  }
  // Send remote id: one request per partition that has IDs to fetch.
  int msg_count = 0;
  for (size_t i = 0; i < remote_ids.size(); ++i) {
    if (remote_ids[i].empty()) {
      continue;
    }
    RPCMessage msg;
    msg.service_id = service_id;
    msg.msg_seq = msg_seq;
    msg.client_id = client_id;
    const int lower = i * group_count;
    const int upper = (i + 1) * group_count;
    msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
    msg.data = pickle_data;
    NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
    msg.tensors.push_back(tensor);
    msg.group_id = RPCContext::getInstance()->group_id;
    SendRPCMessage(msg, msg.server_id);
    msg_count++;
  }
  local_data_shape[0] = ID_size;
  NDArray res_tensor = NDArray::Empty(local_data_shape,
                                      local_data->dtype,
                                      DGLContext{kDGLCPU, 0});
  char* return_data = static_cast<char*>(res_tensor->data);
  // Copy local data
  parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      CHECK_GE(ID_size * row_size,
               local_ids_original[i] * row_size + row_size);
      CHECK_GE(data_size, local_ids[i] * row_size + row_size);
      memcpy(return_data + local_ids_original[i] * row_size,
             local_data_char + local_ids[i] * row_size, row_size);
    }
  });
  // Recv remote message: one reply per request sent above.
  for (int recv_cnt = 0; recv_cnt < msg_count; ++recv_cnt) {
    RPCMessage msg;
    auto status = RecvRPCMessage(&msg, 0);
    CHECK_EQ(status, kRPCSuccess);
    // Map the replying server back to the partition it was asked about.
    const int remote_part = msg.server_id / group_count;
    CHECK_GE(remote_part, 0) << "Invalid partition ID.";
    CHECK_LT(remote_part, machine_count) << "Invalid partition ID.";
    const std::vector<dgl_id_t>& positions = remote_ids_original[remote_part];
    // Validate the reply before copying out of it.
    CHECK(!msg.tensors.empty()) << "Pull response carries no tensor.";
    NDArray reply = msg.tensors[0];
    CHECK_GE(reply.GetSize(), positions.size() * static_cast<size_t>(row_size))
        << "Pull response is smaller than the number of requested rows.";
    const char* data_char = static_cast<const char*>(reply->data);
    for (size_t n = 0; n < positions.size(); ++n) {
      memcpy(return_data + positions[n] * row_size,
             data_char + n * row_size, row_size);
    }
  }
  *rv = res_tensor;
});

542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
// Group-id related C APIs.

// Return this process's group id from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->group_id;
});

// Store this process's group id (args[0]) in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_group_id = args[0];
  RPCContext::getInstance()->group_id = new_group_id;
});

// Return the group id carried by message args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->group_id;
});

// Register client args[0] of group args[1]; returns RegisterClient's result.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client_id = args[0];
  const int32_t group_id = args[1];
  auto* ctx = RPCContext::getInstance();
  *rv = ctx->RegisterClient(client_id, group_id);
});

// Look up client args[0] of group args[1]; returns GetClient's result.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client_id = args[0];
  const int32_t group_id = args[1];
  auto* ctx = RPCContext::getInstance();
  *rv = ctx->GetClient(client_id, group_id);
});

573
574
}  // namespace rpc
}  // namespace dgl
575
576

#endif