rpc.cc 20.7 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
20
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

21
#include "../c_api_common.h"
22
#include "../runtime/resource_manager.h"
23
24
25
26
27
28
29

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
using namespace tensorpipe;

// Borrowed from PyTorch

// Name of the environment variable that guessAddress() reads to pick the
// network interface whose address should be looked up.
const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
// Address guessAddress() falls back to when an address lookup fails.
const char kDefaultUvAddress[] = "127.0.0.1";

const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
60
    }
61
62
63
64
65
66
67
    return result;
  }();
  return uvAddress;
}

/*!
 * \brief Hand a message to the sender stored in the global RPCContext.
 * \param msg message to send
 * \param target_id id forwarded unchanged to the sender's Send()
 * \return always kRPCSuccess; the outcome of Send() is not inspected
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto&& rpc_ctx = RPCContext::getInstance();
  rpc_ctx->sender->Send(msg, target_id);
  return kRPCSuccess;
}

/*!
 * \brief Receive one message through the receiver stored in the global
 *        RPCContext.
 *
 * \param msg output message, filled in by the receiver's Recv()
 * \param timeout timeout passed to Recv(). A value of 0 means "wait
 *        indefinitely": Recv() is then called repeatedly with a fixed
 *        5000 ms timeout until it stops reporting kRPCTimeOut.
 * \return the status of the last Recv() call
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds
  const bool retry_forever = (timeout == 0);
  const int32_t real_timeout = retry_forever ? retry_timeout : timeout;
  RPCStatus status;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      // Build the message from the current call's values. A function-local
      // static would be formatted once and then report the first caller's
      // timeout for every later call.
      DLOG(WARNING) << "Recv RPCMessage timeout in " << real_timeout << " ms."
                    << (retry_forever ? " Retrying ..." : "");
    }
  } while (retry_forever && status == kRPCTimeOut);
  return status;
}

90
91
92
93
94
95
96
97
98
99
100
101
102
103
void InitGlobalTpContext() {
  if (!RPCContext::getInstance()->ctx) {
    RPCContext::getInstance()->ctx = std::make_shared<tensorpipe::Context>();
    auto context = RPCContext::getInstance()->ctx;
    auto transportContext = tensorpipe::transport::uv::create();
    auto shmtransport = tensorpipe::transport::shm::create();
    context->registerTransport(0 /* priority */, "tcp", transportContext);
    // Register basic uv channel
    auto basicChannel = tensorpipe::channel::basic::create();
    context->registerChannel(0 /* low priority */, "basic", basicChannel);

    char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
    if (numUvThreads_str) {
      int numUvThreads = std::atoi(numUvThreads_str);
104
105
      CHECK(numUvThreads > 0)
          << "DGL_SOCKET_NTHREADS should be positive integer if set";
106
107
108
109
110
111
112
113
114
      // Register multiplex uv channel
      std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
      std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
      for (int i = 0; i < numUvThreads; i++) {
        auto context = tensorpipe::transport::uv::create();
        std::string address = guessAddress();
        contexts.push_back(std::move(context));
        listeners.push_back(contexts.back()->listen(address));
      }
115
116
      auto mptChannel = tensorpipe::channel::mpt::create(
          std::move(contexts), std::move(listeners));
117
118
119
120
121
      context->registerChannel(20 /* high priority */, "mpt", mptChannel);
    }
  }
}

122
//////////////////////////// C APIs ////////////////////////////
123
// Calls RPCContext::Reset(); reads no arguments and sets no return value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
    .set_body([](DGLArgs /* args */, DGLRetValue* /* rv */) {
      RPCContext::Reset();
    });
125
126

// Creates the sender stored in the global RPCContext.
// args: [0] message queue size, [1] communicator type ("tensorpipe" or
// "socket"), [2] maximum thread count. The queue size and thread count are
// only forwarded to the socket sender. Any other type is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t queue_size = args[0];
      const std::string net_type = args[1];
      const int thread_limit = args[2];
      auto&& rpc_ctx = RPCContext::getInstance();
      if (net_type == "socket") {
        rpc_ctx->sender.reset(
            new network::SocketSender(queue_size, thread_limit));
      } else if (net_type == "tensorpipe") {
        InitGlobalTpContext();
        rpc_ctx->sender.reset(new TPSender(rpc_ctx->ctx));
      } else {
        LOG(FATAL) << "Unknown communicator type for rpc sender: " << net_type;
      }
      LOG(INFO) << "Sender with NetType~" << rpc_ctx->sender->NetType()
                << " is created.";
    });
145
146

// Creates the receiver stored in the global RPCContext.
// args: [0] message queue size, [1] communicator type ("tensorpipe" or
// "socket"), [2] maximum thread count. The queue size and thread count are
// only forwarded to the socket receiver. Any other type is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t queue_size = args[0];
      const std::string net_type = args[1];
      const int thread_limit = args[2];
      auto&& rpc_ctx = RPCContext::getInstance();
      if (net_type == "socket") {
        rpc_ctx->receiver.reset(
            new network::SocketReceiver(queue_size, thread_limit));
      } else if (net_type == "tensorpipe") {
        InitGlobalTpContext();
        rpc_ctx->receiver.reset(new TPReceiver(rpc_ctx->ctx));
      } else {
        LOG(FATAL) << "Unknown communicator type for rpc receiver: "
                   << net_type;
      }
      LOG(INFO) << "Receiver with NetType~" << rpc_ctx->receiver->NetType()
                << " is created.";
    });
165
166

// Calls Finalize() on the sender held by the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->sender->Finalize();
    });
170
171

// Calls Finalize() on the receiver held by the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->receiver->Finalize();
    });
175

176
// Makes the global receiver wait for senders on "tcp://<ip>:<port>".
// args: [0] ip, [1] port, [2] number of senders, [3] blocking flag.
// A false result from Wait() is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string ip = args[0];
      const int port = args[1];
      const int num_sender = args[2];
      const bool blocking = args[3];
      std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      const bool ok =
          RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking);
      if (!ok) {
        LOG(FATAL) << "Wait sender socket failed.";
      }
    });
189

190
// Asks the global sender to connect to the receiver at "tcp://<ip>:<port>".
// args: [0] ip, [1] port, [2] receiver id.
// Returns whatever the sender's ConnectReceiver() returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string ip = args[0];
      const int port = args[1];
      const int recv_id = args[2];
      std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->sender->ConnectReceiver(addr, recv_id);
    });
199

200
// Forwards args[0] (maximum number of tries) to the global sender's
// ConnectReceiverFinalize() and returns its result.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int max_try_times = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->sender->ConnectReceiverFinalize(max_try_times);
    });
206

207
// Stores args[0] as the rank in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_rank = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->rank = new_rank;
    });
212
213

// Returns the rank stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->rank;
    });
217
218

// Stores args[0] as the number of servers in the global RPCContext and
// returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_servers = count;
      *rv = rpc_ctx->num_servers;
    });
223
224

// Returns the number of servers stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_servers;
    });
228

229
// Stores args[0] as the number of clients in the global RPCContext and
// returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_clients = count;
      *rv = rpc_ctx->num_clients;
    });
234
235

// Returns the number of clients stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_clients;
    });
239

240
// Stores args[0] as the number of servers per machine in the global
// RPCContext and returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_servers_per_machine = count;
      *rv = rpc_ctx->num_servers_per_machine;
    });
245
246

// Returns the number of servers per machine stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_servers_per_machine;
    });
250

251
// Post-increments the message sequence number in the global RPCContext and
// returns the value it had BEFORE the increment.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = (rpc_ctx->msg_seq)++;
    });
255
256

// Returns the message sequence number stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->msg_seq;
    });
260
261

// Overwrites the message sequence number in the global RPCContext with
// args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t new_seq = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->msg_seq = new_seq;
    });
266

267
// Returns the barrier count recorded for group args[0]. A group that has no
// entry yet gets one initialized to 0.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t group_id = args[0];
      auto&& counts = RPCContext::getInstance()->barrier_count;
      // emplace() leaves an existing entry untouched and inserts 0 otherwise;
      // either way the returned iterator points at the group's entry.
      auto entry = counts.emplace(group_id, 0x0).first;
      *rv = entry->second;
    });
276
277

// Sets the barrier count of a group in the global RPCContext.
// args: [0] count, [1] group id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_count = args[0];
      const int32_t group_id = args[1];
      auto&& counts = RPCContext::getInstance()->barrier_count;
      counts[group_id] = new_count;
    });
283

284
// Returns the machine id stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->machine_id;
    });
288
289

// Stores args[0] as the machine id in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_machine_id = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->machine_id = new_machine_id;
    });
294
295

// Returns the number of machines stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_machines;
    });
299
300

// Stores args[0] as the number of machines in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_machines = count;
    });
305
306

// Sends the message referenced by args[0] to target args[1] via
// SendRPCMessage() and returns its status.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCMessageRef msg_ref = args[0];
      const int32_t target_id = args[1];
      const RPCMessage& payload = *(msg_ref.sptr());
      *rv = SendRPCMessage(payload, target_id);
    });
312
313

// Receives into the message referenced by args[1] via RecvRPCMessage(), using
// args[0] as the timeout, and returns its status.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t timeout = args[0];
      RPCMessageRef msg_ref = args[1];
      RPCMessage* out_msg = msg_ref.sptr().get();
      *rv = RecvRPCMessage(out_msg, timeout);
    });
319
320
321
322

//////////////////////////// RPCMessage ////////////////////////////

// Returns a new default-constructed RPCMessage.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::shared_ptr<RPCMessage> empty_msg = std::make_shared<RPCMessage>();
      *rv = empty_msg;
    });
327

328
// Builds an RPCMessage from the arguments and returns it.
// args: [0] service id, [1] message sequence number, [2] client id,
// [3] server id, [4] data string, [5] list of tensors, [6] group id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      std::shared_ptr<RPCMessage> message = std::make_shared<RPCMessage>();
      message->service_id = args[0];
      message->msg_seq = args[1];
      message->client_id = args[2];
      message->server_id = args[3];
      // The string is converted into a local first: assigning args[4]
      // straight to the member raises errors.
      const std::string payload = args[4];
      message->data = payload;
      message->tensors = ListValueToVector<NDArray>(args[5]);
      message->group_id = args[6];
      *rv = message;
    });
342
343

// Returns the service_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->service_id;
    });
348
349

// Returns the msg_seq field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->msg_seq;
    });
354
355

// Returns the client_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->client_id;
    });
360
361

// Returns the server_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->server_id;
    });
366
367

// Returns the data string of the message referenced by args[0] as a byte
// array (pointer + length), so embedded zero bytes are kept.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      const auto& payload = message->data;
      DGLByteArray bytes{payload.c_str(), payload.size()};
      *rv = bytes;
    });
373
374

// Returns the tensors of the message referenced by args[0] as a List<Value>,
// in the same order as stored in the message.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      List<Value> tensor_list;
      for (const auto& tensor : message->tensors) {
        tensor_list.push_back(Value(MakeValue(tensor)));
      }
      *rv = tensor_list;
    });
383

384
385
#if defined(__linux__)
/*!
386
 * \brief The signal handler.
387
388
 * \param s signal
 */
389
void SigHandler(int s) {
390
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
391
  CleanupResources();
392
393
394
  exit(1);
}

395
// Installs SigHandler for SIGINT and SIGTERM. Reads no arguments and sets no
// return value. A failure to install a handler is logged, not fatal.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Zero-initialize so no field of the struct is passed to the kernel
      // with an indeterminate value.
      struct sigaction sigHandler = {};
      sigHandler.sa_handler = SigHandler;
      sigemptyset(&sigHandler.sa_mask);
      sigHandler.sa_flags = 0;
      const int handled_signals[] = {SIGINT, SIGTERM};
      for (const int signum : handled_signals) {
        if (sigaction(signum, &sigHandler, nullptr) != 0) {
          LOG(WARNING) << "Failed to install the handler for signal "
                       << signum;
        }
      }
    });
405
406
#endif

407
408
409
//////////////////////////// ServerState ////////////////////////////

// Returns the ServerState stored in the global RPCContext, creating a
// default-constructed one first if none exists yet.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      if (rpc_ctx->server_state.get() == nullptr) {
        rpc_ctx->server_state = std::make_shared<ServerState>();
      }
      // Read the member only after the lazy creation. A copy taken before it
      // would still be null on the very first call.
      *rv = rpc_ctx->server_state;
    });
418

419
420
421
//////////////////////////// KVStore ////////////////////////////

// Returns, as an id array, those entries of ID whose partition id equals the
// local machine id, in their original order.
// args: [0] ID array, [1] part_id array (partition id of each ID),
// [2] local machine id. Both arrays are read as int64_t.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray ID = args[0];
      NDArray part_id = args[1];
      int local_machine_id = args[2];
      // part_id is indexed with every position of ID below, so it must hold
      // at least as many bytes.
      CHECK_GE(part_id.GetSize(), ID.GetSize())
          << "part_id must provide one entry per ID.";
      const int64_t* ID_data = static_cast<int64_t*>(ID->data);
      const int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
      const int64_t ID_size = ID.GetSize() / sizeof(int64_t);
      std::vector<int64_t> global_id;
      for (int64_t i = 0; i < ID_size; ++i) {
        if (part_id_data[i] == local_machine_id) {
          global_id.push_back(ID_data[i]);
        }
      }
      NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
      *rv = res_tensor;
    });
438
439

// Gathers one row per requested ID into a single result tensor.
//
// Rows whose partition id equals the local machine id are copied directly
// from local_data; the remaining IDs are grouped per partition, sent as pull
// requests, and the rows of the replies are copied into the result at the
// positions of the original request.
//
// args: [0] name (converted but not used here), [1] local machine id,
// [2] machine count, [3] group count (used as the number of servers per
// machine when choosing and decoding server ids), [4] client id,
// [5] service id, [6] message sequence number, [7] data string attached to
// every request, [8] ID, [9] part_id (partition id of each ID), [10] local_id
// (row index in local_data for each locally owned ID, in request order),
// [11] local_data.
// Returns a tensor shaped like local_data with the first dimension replaced
// by the number of IDs.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Input
      std::string name = args[0];
      int local_machine_id = args[1];
      int machine_count = args[2];
      int group_count = args[3];
      int client_id = args[4];
      int service_id = args[5];
      int64_t msg_seq = args[6];
      std::string pickle_data = args[7];
      NDArray ID = args[8];
      NDArray part_id = args[9];
      NDArray local_id = args[10];
      NDArray local_data = args[11];
      // Data
      dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
      const size_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
      // part_id is indexed with every position of ID below.
      CHECK_GE(part_id.GetSize(), ID.GetSize())
          << "part_id must provide one entry per ID.";
      dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
      dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
      dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
      char* local_data_char = static_cast<char*>(local_data->data);
      std::vector<dgl_id_t> local_ids;
      std::vector<dgl_id_t> local_ids_orginal;
      std::vector<int64_t> local_data_shape;
      std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
      std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
      // Get row size (in bytes). Kept in 64 bits so wide rows cannot
      // overflow the product.
      int64_t row_size = 1;
      for (int i = 0; i < local_data->ndim; ++i) {
        local_data_shape.push_back(local_data->shape[i]);
        if (i != 0) {
          row_size *= local_data->shape[i];
        }
      }
      row_size *= (local_data->dtype.bits / 8);
      size_t data_size = local_data.GetSize();
      CHECK_GT(local_data_shape.size(), 0);
      CHECK_EQ(row_size * local_data_shape[0], data_size);
      // Get local id (used in local machine) and
      // remote id (send to remote machine)
      dgl_id_t idx = 0;
      for (dgl_id_t i = 0; i < ID_size; ++i) {
        dgl_id_t p_id = part_id_data[i];
        if (static_cast<int>(p_id) == local_machine_id) {
          CHECK_LT(idx, local_id_size)
              << "local_id has fewer entries than locally owned IDs.";
          dgl_id_t l_id = local_id_data[idx++];
          CHECK_LT(l_id, local_data_shape[0]);
          CHECK_GE(l_id, 0);
          local_ids.push_back(l_id);
          local_ids_orginal.push_back(i);
        } else {
          CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
          dgl_id_t id = ID_data[i];
          remote_ids[p_id].push_back(id);
          remote_ids_original[p_id].push_back(i);
        }
      }
      // Send remote id
      int msg_count = 0;
      for (size_t i = 0; i < remote_ids.size(); ++i) {
        if (remote_ids[i].size() != 0) {
          RPCMessage msg;
          msg.service_id = service_id;
          msg.msg_seq = msg_seq;
          msg.client_id = client_id;
          // Pick one of the servers of partition i at random.
          int lower = i * group_count;
          int upper = (i + 1) * group_count;
          msg.server_id =
              dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
          msg.data = pickle_data;
          NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
          msg.tensors.push_back(tensor);
          msg.group_id = RPCContext::getInstance()->group_id;
          SendRPCMessage(msg, msg.server_id);
          msg_count++;
        }
      }
      local_data_shape[0] = ID_size;
      NDArray res_tensor = NDArray::Empty(
          local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0});
      char* return_data = static_cast<char*>(res_tensor->data);
      // Copy local data
      parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
        for (auto i = b; i < e; ++i) {
          CHECK_GE(
              ID_size * row_size, local_ids_orginal[i] * row_size + row_size);
          CHECK_GE(data_size, local_ids[i] * row_size + row_size);
          CHECK_GE(local_ids[i], 0);
          memcpy(
              return_data + local_ids_orginal[i] * row_size,
              local_data_char + local_ids[i] * row_size, row_size);
        }
      });
      // Recv remote message
      int recv_cnt = 0;
      while (recv_cnt < msg_count) {
        RPCMessage msg;
        auto status = RecvRPCMessage(&msg, 0);
        CHECK_EQ(status, kRPCSuccess);
        ++recv_cnt;
        // The reply comes from the network: validate it before it is used to
        // index local containers or as a memcpy source.
        const int src_part = msg.server_id / group_count;
        CHECK(src_part >= 0 && src_part < machine_count)
            << "Received a pull reply from unexpected server "
            << msg.server_id;
        CHECK(!msg.tensors.empty()) << "Received a pull reply without data.";
        const size_t id_size = remote_ids[src_part].size();
        CHECK_GE(msg.tensors[0].GetSize(), id_size * row_size)
            << "Pull reply is smaller than the requested rows.";
        char* data_char = static_cast<char*>(msg.tensors[0]->data);
        for (size_t n = 0; n < id_size; ++n) {
          memcpy(
              return_data + remote_ids_original[src_part][n] * row_size,
              data_char + n * row_size, row_size);
        }
      }
      *rv = res_tensor;
    });
549

550
// Returns the group id stored in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->group_id;
    });
554
555

// Stores args[0] as the group id in the global RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_group_id = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->group_id = new_group_id;
    });
560
561

// Returns the group_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->group_id;
    });
566
567

// Forwards (client id, group id) = (args[0], args[1]) to
// RPCContext::RegisterClient() and returns its result.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->RegisterClient(client_id, group_id);
    });
573
574

// Forwards (client id, group id) = (args[0], args[1]) to
// RPCContext::GetClient() and returns its result.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->GetClient(client_id, group_id);
    });
580

581
582
}  // namespace rpc
}  // namespace dgl
583
584

#endif