/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
#if defined(__linux__)
#include "./rpc.h"

#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/zerocopy_serializer.h>
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <cstdlib>
#include <cstring>
#include <future>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "../c_api_common.h"
#include "../runtime/resource_manager.h"

// StringPrintf is used below to format "tcp://<ip>:<port>" endpoint strings.
using dgl::network::StringPrintf;
// Runtime names are used unqualified throughout the registered C APIs below.
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

using namespace tensorpipe;

// The address-guessing helpers below are borrowed from PyTorch.

// Environment variable read by guessAddress() to select a network interface.
const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
// Address guessAddress() falls back to when an address lookup fails.
const char kDefaultUvAddress[] = "127.0.0.1";

const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
60
    }
61
62
63
64
65
66
67
    return result;
  }();
  return uvAddress;
}

/*!
 * \brief Send an RPC message through the sender held by the RPC context.
 * \param msg The message to send.
 * \param target_id ID of the receiver the message is addressed to.
 * \return Always kRPCSuccess.
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto&& ctx = RPCContext::getInstance();
  ctx->sender->Send(msg, target_id);
  return kRPCSuccess;
}

/*!
 * \brief Receive one RPC message through the receiver held by the RPC context.
 * \param msg Output message filled in by the receiver.
 * \param timeout Must be 0; timeouts are not supported and any other value
 *        fails the check below.
 * \return Always kRPCSuccess.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  // Timeouts are not implemented yet, so only 0 is accepted.
  CHECK_EQ(timeout, 0) << "rpc cannot support timeout now.";
  auto&& ctx = RPCContext::getInstance();
  ctx->receiver->Recv(msg);
  return kRPCSuccess;
}

/*!
 * \brief Lazily create the tensorpipe context stored in the RPC context.
 *
 * Does nothing if the RPC context already holds a tensorpipe context.
 * Otherwise a new context is created with a uv transport registered as "tcp"
 * and the "basic" channel. If the environment variable DGL_SOCKET_NTHREADS is
 * set, it must parse to a positive integer N, and an "mpt" channel backed by
 * N uv contexts/listeners is additionally registered with a higher priority.
 */
void InitGlobalTpContext() {
  auto&& rpc_ctx = RPCContext::getInstance();
  if (rpc_ctx->ctx) {
    return;  // already initialized
  }
  rpc_ctx->ctx = std::make_shared<tensorpipe::Context>();
  auto context = rpc_ctx->ctx;

  // NOTE: the original also created a shm transport here but never registered
  // or used it; that unused object has been removed.
  auto transportContext = tensorpipe::transport::uv::create();
  context->registerTransport(0 /* priority */, "tcp", transportContext);
  // Register basic uv channel
  auto basicChannel = tensorpipe::channel::basic::create();
  context->registerChannel(0 /* low priority */, "basic", basicChannel);

  const char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
  if (numUvThreads_str == nullptr) {
    return;
  }
  const int numUvThreads = std::atoi(numUvThreads_str);
  CHECK(numUvThreads > 0)
      << "DGL_SOCKET_NTHREADS should be positive integer if set";

  // Register multiplex uv channel: one uv context + listener per thread.
  std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
  std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
  contexts.reserve(numUvThreads);
  listeners.reserve(numUvThreads);
  // guessAddress() returns a cached string, so resolve it once for all lanes.
  const std::string& address = guessAddress();
  for (int i = 0; i < numUvThreads; i++) {
    contexts.push_back(tensorpipe::transport::uv::create());
    listeners.push_back(contexts.back()->listen(address));
  }
  auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                     std::move(listeners));
  context->registerChannel(20 /* high priority */, "mpt", mptChannel);
}

//////////////////////////// C APIs ////////////////////////////
110
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
111
.set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });
112
113

// Creates the RPC sender and stores it in the RPC context.
// args: [0] message queue size, [1] communicator type ("tensorpipe" or
// "socket"), [2] max thread count. The queue size and thread count are only
// passed to the socket sender. Any other type is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t queue_size = args[0];
  std::string net_type = args[1];
  int thread_limit = args[2];
  auto&& ctx = RPCContext::getInstance();
  if (net_type == "socket") {
    ctx->sender.reset(new network::SocketSender(queue_size, thread_limit));
  } else if (net_type == "tensorpipe") {
    // The tensorpipe sender shares the lazily created global context.
    InitGlobalTpContext();
    ctx->sender.reset(new TPSender(ctx->ctx));
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc sender: " << net_type;
  }
  LOG(INFO) << "Sender with NetType~" << ctx->sender->NetType()
            << " is created.";
});

// Creates the RPC receiver and stores it in the RPC context.
// args: [0] message queue size, [1] communicator type ("tensorpipe" or
// "socket"), [2] max thread count. The queue size and thread count are only
// passed to the socket receiver. Any other type is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t queue_size = args[0];
  std::string net_type = args[1];
  int thread_limit = args[2];
  auto&& ctx = RPCContext::getInstance();
  if (net_type == "socket") {
    ctx->receiver.reset(
        new network::SocketReceiver(queue_size, thread_limit));
  } else if (net_type == "tensorpipe") {
    // The tensorpipe receiver shares the lazily created global context.
    InitGlobalTpContext();
    ctx->receiver.reset(new TPReceiver(ctx->ctx));
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc receiver: " << net_type;
  }
  LOG(INFO) << "Receiver with NetType~" << ctx->receiver->NetType()
            << " is created.";
});

// Calls Finalize() on the sender held by the RPC context; takes no arguments.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  ctx->sender->Finalize();
});

// Calls Finalize() on the receiver held by the RPC context; takes no
// arguments.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  ctx->receiver->Finalize();
});

// Makes the receiver wait for senders on "tcp://<ip>:<port>".
// args: [0] ip, [1] port, [2] number of senders, [3] blocking flag.
// A false result from Wait() is a fatal error.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string host = args[0];
  int port_number = args[1];
  int sender_count = args[2];
  bool is_blocking = args[3];
  std::string endpoint =
      StringPrintf("tcp://%s:%d", host.c_str(), port_number);
  auto&& ctx = RPCContext::getInstance();
  if (!ctx->receiver->Wait(endpoint, sender_count, is_blocking)) {
    LOG(FATAL) << "Wait sender socket failed.";
  }
});

// Connects the sender to the receiver listening on "tcp://<ip>:<port>".
// args: [0] ip, [1] port, [2] receiver ID.
// Returns whatever the sender's ConnectReceiver() returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string host = args[0];
  int port_number = args[1];
  int receiver_id = args[2];
  std::string endpoint =
      StringPrintf("tcp://%s:%d", host.c_str(), port_number);
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->sender->ConnectReceiver(endpoint, receiver_id);
});

// Calls ConnectReceiverFinalize() on the sender held by the RPC context;
// takes no arguments.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  ctx->sender->ConnectReceiverFinalize();
});

// Stores args[0] as the rank in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_rank = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->rank = new_rank;
});

// Returns the rank stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->rank;
});

// Stores args[0] as the number of servers in the RPC context and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t server_count = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->num_servers = server_count;
  *rv = ctx->num_servers;
});

// Returns the number of servers stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->num_servers;
});

// Stores args[0] as the number of clients in the RPC context and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client_count = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->num_clients = client_count;
  *rv = ctx->num_clients;
});

// Returns the number of clients stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->num_clients;
});

// Stores args[0] as the number of servers per machine in the RPC context and
// returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t servers_per_machine = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->num_servers_per_machine = servers_per_machine;
  *rv = ctx->num_servers_per_machine;
});

// Returns the number of servers per machine stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->num_servers_per_machine;
});

// Post-increments the message sequence number in the RPC context and returns
// the value it had before the increment.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->msg_seq++;
});

// Returns the message sequence number stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->msg_seq;
});

// Stores args[0] as the message sequence number in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t new_seq = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->msg_seq = new_seq;
});

// Returns the barrier count recorded for group args[0]. A group that has no
// entry yet gets one initialized to 0, which is then returned.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t group_id = args[0];
  auto&& counts = RPCContext::getInstance()->barrier_count;
  // emplace() leaves an existing entry untouched and yields its iterator.
  const auto entry = counts.emplace(group_id, 0x0).first;
  *rv = entry->second;
});

// Sets the barrier count of a group. args: [0] count, [1] group ID.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_count = args[0];
  const int32_t group = args[1];
  auto&& ctx = RPCContext::getInstance();
  ctx->barrier_count[group] = new_count;
});

// Returns the machine ID stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->machine_id;
});

// Stores args[0] as the machine ID in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_machine_id = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->machine_id = new_machine_id;
});

// Returns the number of machines stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->num_machines;
});

// Stores args[0] as the number of machines in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t machine_total = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->num_machines = machine_total;
});

// Sends an RPC message. args: [0] message reference, [1] target ID.
// Returns the status produced by SendRPCMessage().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef message = args[0];
  const int32_t destination = args[1];
  *rv = SendRPCMessage(*message.sptr(), destination);
});

// Receives an RPC message into an existing message object.
// args: [0] timeout, [1] message reference to fill in.
// Returns the status produced by RecvRPCMessage().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int32_t wait_time = args[0];
  RPCMessageRef message = args[1];
  *rv = RecvRPCMessage(message.sptr().get(), wait_time);
});

//////////////////////////// RPCMessage ////////////////////////////

// Creates and returns a new, empty RPCMessage; takes no arguments.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // make_shared value-initializes the message, so scalar header fields that
  // carry no default initializer start at zero rather than being indeterminate
  // when read back by the getter APIs. It also avoids a raw `new`.
  std::shared_ptr<RPCMessage> rst = std::make_shared<RPCMessage>();
  *rv = rst;
});

// Creates and returns a new, empty RPCMessage. args: [0] message size.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // NOTE(review): the size argument is still read but is not used to size
  // the message; confirm whether pre-allocation was intended.
  int64_t message_size = args[0];
  (void)message_size;

  // make_shared value-initializes the message, so scalar header fields that
  // carry no default initializer start at zero rather than being indeterminate.
  std::shared_ptr<RPCMessage> rst = std::make_shared<RPCMessage>();
  *rv = rst;
});

// Builds an RPCMessage from its parts and returns it.
// args: [0] service ID, [1] message sequence number, [2] client ID,
// [3] server ID, [4] data string, [5] list of tensors, [6] group ID.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> message = std::make_shared<RPCMessage>();
  message->service_id = args[0];
  message->msg_seq = args[1];
  message->client_id = args[2];
  message->server_id = args[3];
  // The payload goes through a named std::string first; the original author
  // notes that assigning the argument directly raises errors.
  const std::string payload = args[4];
  message->data = payload;
  message->tensors = ListValueToVector<NDArray>(args[5]);
  message->group_id = args[6];
  *rv = message;
});

// Returns the service_id field of the message passed as args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->service_id;
});

// Returns the msg_seq field of the message passed as args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->msg_seq;
});

// Returns the client_id field of the message passed as args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->client_id;
});

// Returns the server_id field of the message passed as args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->server_id;
});

// Returns the data payload of the message passed as args[0] as a byte array
// built from the payload's c_str() pointer and size().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  DGLByteArray payload_bytes{message->data.c_str(), message->data.size()};
  *rv = payload_bytes;
});

// Returns the tensors of the message passed as args[0] as a List<Value>,
// preserving their order.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  List<Value> wrapped;
  for (auto&& tensor : message->tensors) {
    wrapped.push_back(Value(MakeValue(tensor)));
  }
  *rv = wrapped;
});

#if defined(__linux__)
/*!
377
 * \brief The signal handler.
378
379
 * \param s signal
 */
380
void SigHandler(int s) {
381
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
382
  CleanupResources();
383
384
385
  exit(1);
}

386
// Installs SigHandler for SIGINT and SIGTERM; takes no arguments.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Ctrl+C handler
  struct sigaction action;
  action.sa_handler = SigHandler;
  action.sa_flags = 0;
  sigemptyset(&action.sa_mask);
  const int handled_signals[] = {SIGINT, SIGTERM};
  for (const int signum : handled_signals) {
    sigaction(signum, &action, nullptr);
  }
});
#endif

//////////////////////////// ServerState ////////////////////////////

// Returns the server state held by the RPC context, creating it on first use.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Bind by reference: the original copied the (null) pointer, stored the
  // newly created state in the context, and then returned the stale null
  // copy, so the first call returned an empty state.
  auto& st = RPCContext::getInstance()->server_state;
  if (st.get() == nullptr) {
    st = std::make_shared<ServerState>();
  }
  *rv = st;
});

//////////////////////////// KVStore ////////////////////////////

// Returns, in order, the entries of ID whose partition ID equals the local
// machine ID. args: [0] ID array, [1] partition ID per entry of ID,
// [2] local machine ID. Both arrays are read as int64_t elements.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  NDArray ID = args[0];
  NDArray part_id = args[1];
  int local_machine_id = args[2];
  // part_id is indexed with every index of ID below, so it must be at least
  // as large; otherwise the loop would read past its end.
  CHECK_GE(part_id.GetSize(), ID.GetSize())
      << "part_id must provide one partition ID per entry of ID.";
  const int64_t* ID_data = static_cast<int64_t*>(ID->data);
  const int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
  const int64_t ID_size = ID.GetSize() / sizeof(int64_t);
  std::vector<int64_t> global_id;
  for (int64_t i = 0; i < ID_size; ++i) {
    if (part_id_data[i] == local_machine_id) {
      global_id.push_back(ID_data[i]);
    }
  }
  NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
  *rv = res_tensor;
});

// Pulls rows for the requested IDs, from local memory and from remote servers.
//
// args: [0] name (not used below), [1] local machine ID, [2] machine count,
// [3] group count, [4] client ID, [5] service ID, [6] message sequence number,
// [7] pickled request data forwarded to remote servers, [8] requested IDs,
// [9] partition ID of each requested ID, [10] local row index for each ID
// owned by the local machine (consumed in order), [11] local data tensor.
//
// Rows owned by the local machine are copied straight out of the local data
// tensor. IDs owned by other machines are grouped per machine and sent in one
// RPC message each, to a server ID drawn by RandInt(i * group_count,
// (i + 1) * group_count); the replies are then received and scattered into
// the result. Returns a CPU tensor with one row per requested ID, in request
// order, with the row shape and dtype of the local data tensor.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Input
  std::string name = args[0];
  (void)name;  // not used below
  int local_machine_id = args[1];
  int machine_count = args[2];
  int group_count = args[3];
  int client_id = args[4];
  int service_id = args[5];
  int64_t msg_seq = args[6];
  std::string pickle_data = args[7];
  NDArray ID = args[8];
  NDArray part_id = args[9];
  NDArray local_id = args[10];
  NDArray local_data = args[11];
  CHECK_GE(machine_count, 0) << "machine_count must not be negative.";
  // part_id is indexed with every index of ID below.
  CHECK_GE(part_id.GetSize(), ID.GetSize())
      << "part_id must provide one partition ID per requested ID.";
  // Data
  const dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
  const dgl_id_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
  dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
  dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
  dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
  char* local_data_char = static_cast<char*>(local_data->data);
  std::vector<dgl_id_t> local_ids;
  std::vector<dgl_id_t> local_ids_orginal;
  std::vector<int64_t> local_data_shape;
  std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
  std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
  // Get row size (in bytes): product of all but the first dimension times the
  // element size.
  size_t row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= (local_data->dtype.bits / 8);
  const size_t data_size = local_data.GetSize();
  CHECK_GT(local_data_shape.size(), 0);
  CHECK_EQ(row_size * local_data_shape[0], data_size);
  // Get local id (used in local machine) and
  // remote id (send to remote machine)
  dgl_id_t idx = 0;
  for (dgl_id_t i = 0; i < ID_size; ++i) {
    dgl_id_t p_id = part_id_data[i];
    if (p_id == local_machine_id) {
      // local_id holds one entry per locally owned ID; never read past it.
      CHECK_LT(idx, local_id_size)
          << "local_id has fewer entries than locally owned IDs.";
      dgl_id_t l_id = local_id_data[idx++];
      CHECK_LT(l_id, local_data_shape[0]);
      CHECK_GE(l_id, 0);
      local_ids.push_back(l_id);
      local_ids_orginal.push_back(i);
    } else {
      CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
      dgl_id_t id = ID_data[i];
      remote_ids[p_id].push_back(id);
      remote_ids_original[p_id].push_back(i);
    }
  }
  // Send remote id: one message per machine that owns requested IDs.
  int msg_count = 0;
  for (int i = 0; i < machine_count; ++i) {
    if (remote_ids[i].empty()) {
      continue;
    }
    // Needed for the server ID range here and for the division that maps a
    // reply back to its machine below.
    CHECK_GT(group_count, 0) << "group_count must be positive.";
    RPCMessage msg;
    msg.service_id = service_id;
    msg.msg_seq = msg_seq;
    msg.client_id = client_id;
    int lower = i * group_count;
    int upper = (i + 1) * group_count;
    msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
    msg.data = pickle_data;
    NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
    msg.tensors.push_back(tensor);
    msg.group_id = RPCContext::getInstance()->group_id;
    SendRPCMessage(msg, msg.server_id);
    msg_count++;
  }
  // The result has one row per requested ID.
  local_data_shape[0] = ID_size;
  NDArray res_tensor = NDArray::Empty(local_data_shape,
                                      local_data->dtype,
                                      DLContext{kDLCPU, 0});
  char* return_data = static_cast<char*>(res_tensor->data);
  // Copy local data
  parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      CHECK_GE(ID_size * row_size,
               local_ids_orginal[i] * row_size + row_size);
      CHECK_GE(data_size, local_ids[i] * row_size + row_size);
      CHECK_GE(local_ids[i], 0);
      memcpy(return_data + local_ids_orginal[i] * row_size,
             local_data_char + local_ids[i] * row_size, row_size);
    }
  });
  // Recv remote message
  for (int i = 0; i < msg_count; ++i) {
    RPCMessage msg;
    RecvRPCMessage(&msg, 0);
    // Map the replying server back to the machine it was chosen for.
    const int src_part = msg.server_id / group_count;
    // Validate the reply before using it to index buffers or copy memory.
    CHECK(src_part >= 0 && src_part < machine_count)
        << "Received a reply from unexpected server " << msg.server_id;
    CHECK(!msg.tensors.empty()) << "Received a reply without any tensor.";
    const std::vector<dgl_id_t>& dst_rows = remote_ids_original[src_part];
    const size_t expected_bytes = dst_rows.size() * row_size;
    CHECK_GE(msg.tensors[0].GetSize(), expected_bytes)
        << "Received tensor is smaller than the requested rows.";
    const char* data_char = static_cast<const char*>(msg.tensors[0]->data);
    for (size_t n = 0; n < dst_rows.size(); ++n) {
      memcpy(return_data + dst_rows[n] * row_size,
             data_char + n * row_size, row_size);
    }
  }
  *rv = res_tensor;
});

// Returns the group ID stored in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->group_id;
});

// Stores args[0] as the group ID in the RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_group_id = args[0];
  auto&& ctx = RPCContext::getInstance();
  ctx->group_id = new_group_id;
});

// Returns the group_id field of the message passed as args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->group_id;
});

// Forwards to RPCContext::RegisterClient() and returns its result.
// args: [0] client ID, [1] group ID.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client = args[0];
  const int32_t group = args[1];
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->RegisterClient(client, group);
});

// Forwards to RPCContext::GetClient() and returns its result.
// args: [0] client ID, [1] group ID.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client = args[0];
  const int32_t group = args[1];
  auto&& ctx = RPCContext::getInstance();
  *rv = ctx->GetClient(client, group);
});

}  // namespace rpc
}  // namespace dgl

#endif  // defined(__linux__)