rpc.cc 19 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
20
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

21
#include "../c_api_common.h"
22
#include "../runtime/resource_manager.h"
23
24
25
26
27
28
29

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
using namespace tensorpipe;

// Borrow from PyTorch

const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
const char kDefaultUvAddress[] = "127.0.0.1";

/*!
 * \brief Resolve the IP address used for uv listeners.
 *
 * The lookup runs once; the result is cached in a function-local static.
 * When the TP_SOCKET_IFNAME environment variable is set, the address of that
 * interface is looked up, otherwise the address of the hostname is used.
 * If the lookup reports an error, a warning is logged and kDefaultUvAddress
 * is returned instead.
 * \return reference to the cached address string.
 */
const std::string& guessAddress() {
  static const std::string uvAddress = []() -> std::string {
    const char* iface = std::getenv(kSocketIfnameEnvVar);
    tensorpipe::Error error;
    std::string result;

    if (iface == nullptr) {
      // No interface requested: fall back to the hostname lookup.
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (!error) {
        return result;
      }
      LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                   << error.what() << "), defaulting to "
                   << kDefaultUvAddress;
      return std::string(kDefaultUvAddress);
    }

    std::tie(error, result) =
        tensorpipe::transport::uv::lookupAddrForIface(iface);
    if (!error) {
      return result;
    }
    LOG(WARNING) << "Failed to look up the IP address for interface "
                 << iface << " (" << error.what() << "), defaulting to "
                 << kDefaultUvAddress;
    return std::string(kDefaultUvAddress);
  }();
  return uvAddress;
}

/*!
 * \brief Send an RPC message through the sender stored in the RPCContext.
 * \param msg message to send
 * \param target_id id passed to the sender as the destination
 * \return always kRPCSuccess
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto&& sender = RPCContext::getInstance()->sender;
  sender->Send(msg, target_id);
  return kRPCSuccess;
}

/*!
 * \brief Receive one RPC message through the receiver stored in the RPCContext.
 * \param msg output message filled in by the receiver
 * \param timeout must be 0; any other value fails the check below
 * \return always kRPCSuccess
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  // Timeouts are not implemented, so only 0 is accepted.
  CHECK_EQ(timeout, 0) << "rpc cannot support timeout now.";
  auto&& receiver = RPCContext::getInstance()->receiver;
  receiver->Recv(msg);
  return kRPCSuccess;
}

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*!
 * \brief Lazily create the tensorpipe context stored in the RPCContext.
 *
 * Does nothing when a context already exists. Otherwise a new context is
 * built with the uv transport registered as "tcp" and the "basic" channel.
 * When the DGL_SOCKET_NTHREADS environment variable is set, it must parse to
 * a positive integer; that many uv contexts/listeners are created and
 * registered as the higher-priority "mpt" channel.
 *
 * The context is stored in the RPCContext only after it has been fully
 * configured, so a failed check cannot leave a half-initialized context
 * behind for later calls to reuse.
 */
void InitGlobalTpContext() {
  auto&& rpc_ctx = RPCContext::getInstance();
  if (rpc_ctx->ctx) {
    return;
  }

  // Validate the configuration before building anything.
  int numUvThreads = 0;
  const char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
  if (numUvThreads_str) {
    numUvThreads = std::atoi(numUvThreads_str);
    CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set";
  }

  auto context = std::make_shared<tensorpipe::Context>();
  auto transportContext = tensorpipe::transport::uv::create();
  context->registerTransport(0 /* priority */, "tcp", transportContext);
  // Register basic uv channel
  auto basicChannel = tensorpipe::channel::basic::create();
  context->registerChannel(0 /* low priority */, "basic", basicChannel);

  if (numUvThreads > 0) {
    // Register multiplex uv channel
    std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
    std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
    contexts.reserve(numUvThreads);
    listeners.reserve(numUvThreads);
    // guessAddress() returns a cached string; bind it once instead of copying
    // it on every iteration.
    const std::string& address = guessAddress();
    for (int i = 0; i < numUvThreads; i++) {
      contexts.push_back(tensorpipe::transport::uv::create());
      listeners.push_back(contexts.back()->listen(address));
    }
    auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                       std::move(listeners));
    context->registerChannel(20 /* high priority */, "mpt", mptChannel);
  }

  rpc_ctx->ctx = std::move(context);
}

109
//////////////////////////// C APIs ////////////////////////////
110
/*! \brief C API: forwards to RPCContext::Reset(). Takes no arguments. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
    .set_body([](DGLArgs /* args */, DGLRetValue* /* rv */) {
      RPCContext::Reset();
    });
112
113

/*!
 * \brief C API: create the sender stored in the RPCContext.
 *
 * args[0]: message queue size, args[1]: communicator type ("tensorpipe" or
 * "socket"), args[2]: max thread count. The queue size and thread count are
 * only used by the socket sender. Any other type is a fatal error.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t msg_queue_size = args[0];
      const std::string type = args[1];
      const int max_thread_count = args[2];
      auto&& rpc_ctx = RPCContext::getInstance();
      if (type == "socket") {
        rpc_ctx->sender.reset(
            new network::SocketSender(msg_queue_size, max_thread_count));
      } else if (type == "tensorpipe") {
        // The tensorpipe sender shares the lazily created context.
        InitGlobalTpContext();
        rpc_ctx->sender.reset(new TPSender(rpc_ctx->ctx));
      } else {
        LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
      }
      LOG(INFO) << "Sender with NetType~" << rpc_ctx->sender->NetType()
                << " is created.";
    });

/*!
 * \brief C API: create the receiver stored in the RPCContext.
 *
 * args[0]: message queue size, args[1]: communicator type ("tensorpipe" or
 * "socket"), args[2]: max thread count. The queue size and thread count are
 * only used by the socket receiver. Any other type is a fatal error.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t msg_queue_size = args[0];
      const std::string type = args[1];
      const int max_thread_count = args[2];
      auto&& rpc_ctx = RPCContext::getInstance();
      if (type == "socket") {
        rpc_ctx->receiver.reset(
            new network::SocketReceiver(msg_queue_size, max_thread_count));
      } else if (type == "tensorpipe") {
        // The tensorpipe receiver shares the lazily created context.
        InitGlobalTpContext();
        rpc_ctx->receiver.reset(new TPReceiver(rpc_ctx->ctx));
      } else {
        LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
      }
      LOG(INFO) << "Receiver with NetType~" << rpc_ctx->receiver->NetType()
                << " is created.";
    });

/*! \brief C API: call Finalize() on the sender stored in the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& sender = RPCContext::getInstance()->sender;
      sender->Finalize();
    });

/*! \brief C API: call Finalize() on the receiver stored in the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& receiver = RPCContext::getInstance()->receiver;
      receiver->Finalize();
    });

161
/*!
 * \brief C API: make the receiver wait for senders on tcp://ip:port.
 *
 * args[0]: ip, args[1]: port, args[2]: number of senders, args[3]: blocking
 * flag. A false result from the receiver's Wait() is a fatal error.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string ip = args[0];
      const int port = args[1];
      const int num_sender = args[2];
      const bool blocking = args[3];
      const std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      auto&& receiver = RPCContext::getInstance()->receiver;
      if (!receiver->Wait(addr, num_sender, blocking)) {
        LOG(FATAL) << "Wait sender socket failed.";
      }
    });

174
/*!
 * \brief C API: ask the sender to connect to the receiver at tcp://ip:port.
 *
 * args[0]: ip, args[1]: port, args[2]: receiver id. The sender's
 * ConnectReceiver() result is returned through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string ip = args[0];
      const int port = args[1];
      const int recv_id = args[2];
      const std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      auto&& sender = RPCContext::getInstance()->sender;
      *rv = sender->ConnectReceiver(addr, recv_id);
    });

184
185
/*!
 * \brief C API: call the sender's ConnectReceiverFinalize().
 *
 * args[0]: maximum number of tries. The sender's result is returned via rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int max_try_times = args[0];
      auto&& sender = RPCContext::getInstance()->sender;
      *rv = sender->ConnectReceiverFinalize(max_try_times);
    });

190
/*! \brief C API: store args[0] as the rank held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_rank = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->rank = new_rank;
    });

/*! \brief C API: return the rank held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->rank;
    });

/*!
 * \brief C API: store args[0] as the number of servers in the RPCContext.
 *
 * The stored value is also returned through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_servers = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_servers = num_servers;
      *rv = rpc_ctx->num_servers;
    });

/*! \brief C API: return the number of servers held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_servers;
    });

212
/*!
 * \brief C API: store args[0] as the number of clients in the RPCContext.
 *
 * The stored value is also returned through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_clients = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_clients = num_clients;
      *rv = rpc_ctx->num_clients;
    });

/*! \brief C API: return the number of clients held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_clients;
    });

223
/*!
 * \brief C API: store args[0] as the per-machine server count in the RPCContext.
 *
 * The stored value is also returned through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t num_servers = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_servers_per_machine = num_servers;
      *rv = rpc_ctx->num_servers_per_machine;
    });

/*! \brief C API: return the per-machine server count held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_servers_per_machine;
    });

234
/*!
 * \brief C API: post-increment the message sequence number in the RPCContext.
 *
 * The value from before the increment is returned through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      const auto previous = rpc_ctx->msg_seq++;
      *rv = previous;
    });

/*! \brief C API: return the message sequence number held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->msg_seq;
    });

/*! \brief C API: store args[0] as the message sequence number in the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t new_seq = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->msg_seq = new_seq;
    });

250
/*!
 * \brief C API: return the barrier count recorded for a group.
 *
 * args[0]: group id. If the group has no entry yet, one is inserted with a
 * count of zero before the value is returned.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t group_id = args[0];
      auto&& counts = RPCContext::getInstance()->barrier_count;
      auto entry = counts.find(group_id);
      if (entry == counts.end()) {
        entry = counts.emplace(group_id, 0x0).first;
      }
      *rv = entry->second;
    });

/*!
 * \brief C API: set the barrier count recorded for a group.
 *
 * args[0]: count, args[1]: group id.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      const int32_t group_id = args[1];
      auto&& counts = RPCContext::getInstance()->barrier_count;
      counts[group_id] = count;
    });

267
/*! \brief C API: return the machine id held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->machine_id;
    });

/*! \brief C API: store args[0] as the machine id in the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_machine_id = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->machine_id = new_machine_id;
    });

/*! \brief C API: return the number of machines held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->num_machines;
    });

/*! \brief C API: store args[0] as the number of machines in the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_num_machines = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->num_machines = new_num_machines;
    });

/*!
 * \brief C API: send an RPC message.
 *
 * args[0]: message reference, args[1]: target id. The status returned by
 * SendRPCMessage() is passed back through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCMessageRef msg = args[0];
      const int32_t target_id = args[1];
      // msg keeps the message object alive for the duration of the call.
      const RPCMessage& payload = *(msg.sptr());
      *rv = SendRPCMessage(payload, target_id);
    });

/*!
 * \brief C API: receive an RPC message into an existing message object.
 *
 * args[0]: timeout, args[1]: message reference that is filled in. The status
 * returned by RecvRPCMessage() is passed back through rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t timeout = args[0];
      RPCMessageRef msg = args[1];
      // msg keeps the message object alive for the duration of the call.
      auto* out = msg.sptr().get();
      *rv = RecvRPCMessage(out, timeout);
    });

//////////////////////////// RPCMessage ////////////////////////////

/*! \brief C API: return a newly created, default-constructed RPCMessage. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto rst = std::make_shared<RPCMessage>();
      *rv = rst;
    });

/*!
 * \brief C API: return a newly created, default-constructed RPCMessage.
 *
 * args[0]: message size. The value is converted to int64_t but is not used
 * to size the message in this function.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t message_size = args[0];
      (void)message_size;  // read from args but otherwise unused

      auto rst = std::make_shared<RPCMessage>();
      *rv = rst;
    });

319

320
/*!
 * \brief C API: build an RPCMessage from its fields and return it.
 *
 * args[0]: service id, args[1]: message sequence, args[2]: client id,
 * args[3]: server id, args[4]: data string, args[5]: list of tensors,
 * args[6]: group id.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto rst = std::make_shared<RPCMessage>();
      rst->service_id = args[0];
      rst->msg_seq = args[1];
      rst->client_id = args[2];
      rst->server_id = args[3];
      // Go through a named std::string: directly assigning the string value
      // raises errors.
      const std::string data = args[4];
      rst->data = data;
      rst->tensors = ListValueToVector<NDArray>(args[5]);
      rst->group_id = args[6];
      *rv = rst;
    });

/*! \brief C API: return the service_id field of the message in args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->service_id;
    });

/*! \brief C API: return the msg_seq field of the message in args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->msg_seq;
    });

/*! \brief C API: return the client_id field of the message in args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->client_id;
    });

/*! \brief C API: return the server_id field of the message in args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->server_id;
    });

/*!
 * \brief C API: return the data field of the message in args[0] as a byte array.
 *
 * The byte array is built from the string's character pointer and size.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      const auto& payload = message->data;
      DGLByteArray barr{payload.c_str(), payload.size()};
      *rv = barr;
    });

/*!
 * \brief C API: return the tensors of the message in args[0] as a list of values.
 *
 * Each tensor is wrapped with MakeValue() and appended in order.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      List<Value> wrapped;
      for (const auto& tensor : message->tensors) {
        wrapped.push_back(Value(MakeValue(tensor)));
      }
      *rv = wrapped;
    });

376
377
#if defined(__linux__)
/*!
378
 * \brief The signal handler.
379
380
 * \param s signal
 */
381
void SigHandler(int s) {
382
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
383
  CleanupResources();
384
385
386
  exit(1);
}

387
/*!
 * \brief C API: install SigHandler for SIGINT and SIGTERM.
 *
 * The handler is registered with an empty signal mask and no flags. The
 * return values of sigaction() are not checked.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Ctrl+C handler
      struct sigaction action;
      action.sa_flags = 0;
      sigemptyset(&action.sa_mask);
      action.sa_handler = SigHandler;
      sigaction(SIGINT, &action, nullptr);
      sigaction(SIGTERM, &action, nullptr);
    });
#endif

399
400
401
//////////////////////////// ServerState ////////////////////////////

/*!
 * \brief C API: return the ServerState held by the RPCContext, creating it
 *        on first use.
 *
 * The state pointer is accessed by reference so that a state created here is
 * the one returned; returning a copy taken before the creation would hand a
 * null pointer to the first caller.
 */
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto& st = RPCContext::getInstance()->server_state;
      if (st.get() == nullptr) {
        st = std::make_shared<ServerState>();
      }
      *rv = st;
    });

410
411
412
//////////////////////////// KVStore ////////////////////////////

/*!
 * \brief C API: select the IDs whose partition id equals the local machine id.
 *
 * args[0]: ID tensor, args[1]: partition id tensor (indexed in parallel with
 * ID), args[2]: local machine id. Both tensors are read as int64_t arrays.
 * The selected IDs are returned, in their original order, as an id array.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray ID = args[0];
      NDArray part_id = args[1];
      const int local_machine_id = args[2];

      const int64_t* ids = static_cast<int64_t*>(ID->data);
      const int64_t* owners = static_cast<int64_t*>(part_id->data);
      // GetSize() is divided by the element size to get the element count.
      const int64_t num_ids = ID.GetSize() / sizeof(int64_t);

      std::vector<int64_t> global_id;
      for (int64_t pos = 0; pos < num_ids; ++pos) {
        if (owners[pos] != local_machine_id) {
          continue;
        }
        global_id.push_back(ids[pos]);
      }
      NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
      *rv = res_tensor;
    });

/*!
 * \brief C API: gather rows for a list of IDs from local data and remote servers.
 *
 * Arguments, by position in args:
 *   0 name (read but not used below), 1 local_machine_id, 2 machine_count,
 *   3 group_count, 4 client_id, 5 service_id, 6 msg_seq,
 *   7 pickle_data (copied into the data field of every request),
 *   8 ID, 9 part_id (partition of each ID, indexed in parallel with ID),
 *   10 local_id (row in local_data for each locally owned ID, consumed in
 *   order of appearance), 11 local_data.
 *
 * IDs owned by the local machine are copied straight from local_data. The
 * remaining IDs are grouped per partition and sent in one request per
 * partition to a randomly chosen server in [p * group_count,
 * (p + 1) * group_count); one reply is then received per request and its
 * first tensor is copied row by row into the result.
 *
 * The result (returned via rv) is a CPU tensor with the shape and dtype of
 * local_data, except that its first dimension is the number of IDs; row i
 * holds the data for ID[i].
 *
 * Sizes of the index tensors and of every reply are checked before they are
 * used to index or copy memory.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Input
      std::string name = args[0];  // not used below
      int local_machine_id = args[1];
      int machine_count = args[2];
      int group_count = args[3];
      int client_id = args[4];
      int service_id = args[5];
      int64_t msg_seq = args[6];
      std::string pickle_data = args[7];
      NDArray ID = args[8];
      NDArray part_id = args[9];
      NDArray local_id = args[10];
      NDArray local_data = args[11];
      CHECK_GE(machine_count, 0) << "machine_count must not be negative.";

      // Data
      const dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
      const dgl_id_t part_id_size = part_id.GetSize() / sizeof(dgl_id_t);
      const dgl_id_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
      // part_id is indexed with every position of ID below.
      CHECK_GE(part_id_size, ID_size)
          << "part_id must provide one entry per ID.";
      dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
      dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
      dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
      char* local_data_char = static_cast<char*>(local_data->data);
      std::vector<dgl_id_t> local_ids;
      std::vector<dgl_id_t> local_ids_orginal;
      std::vector<int64_t> local_data_shape;
      std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
      std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);

      // Get row size (in bytes)
      size_t row_size = 1;
      for (int i = 0; i < local_data->ndim; ++i) {
        local_data_shape.push_back(local_data->shape[i]);
        if (i != 0) {
          row_size *= static_cast<size_t>(local_data->shape[i]);
        }
      }
      row_size *= (local_data->dtype.bits / 8);
      const size_t data_size = local_data.GetSize();
      CHECK_GT(local_data_shape.size(), 0);
      CHECK_EQ(row_size * static_cast<size_t>(local_data_shape[0]), data_size);

      // Get local id (used in local machine) and
      // remote id (send to remote machine)
      dgl_id_t idx = 0;
      for (dgl_id_t i = 0; i < ID_size; ++i) {
        dgl_id_t p_id = part_id_data[i];
        if (p_id == local_machine_id) {
          CHECK_LT(idx, local_id_size)
              << "local_id has fewer entries than locally owned IDs.";
          dgl_id_t l_id = local_id_data[idx++];
          CHECK_LT(l_id, local_data_shape[0]);
          CHECK_GE(l_id, 0);
          local_ids.push_back(l_id);
          local_ids_orginal.push_back(i);
        } else {
          CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
          dgl_id_t id = ID_data[i];
          remote_ids[p_id].push_back(id);
          remote_ids_original[p_id].push_back(i);
        }
      }

      // Send remote id
      int msg_count = 0;
      for (size_t i = 0; i < remote_ids.size(); ++i) {
        if (remote_ids[i].empty()) {
          continue;
        }
        // group_count is used both to pick a server and, on receipt, to map a
        // server id back to its partition.
        CHECK_GT(group_count, 0) << "group_count must be positive.";
        RPCMessage msg;
        msg.service_id = service_id;
        msg.msg_seq = msg_seq;
        msg.client_id = client_id;
        const int lower = static_cast<int>(i) * group_count;
        const int upper = (static_cast<int>(i) + 1) * group_count;
        msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
        msg.data = pickle_data;
        NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
        msg.tensors.push_back(tensor);
        msg.group_id = RPCContext::getInstance()->group_id;
        SendRPCMessage(msg, msg.server_id);
        msg_count++;
      }

      local_data_shape[0] = ID_size;
      NDArray res_tensor = NDArray::Empty(local_data_shape,
                                          local_data->dtype,
                                          DLContext{kDLCPU, 0});
      char* return_data = static_cast<char*>(res_tensor->data);

      // Copy local data
      parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
        for (auto i = b; i < e; ++i) {
          CHECK_GE(ID_size * row_size,
                   local_ids_orginal[i] * row_size + row_size);
          CHECK_GE(data_size, local_ids[i] * row_size + row_size);
          CHECK_GE(local_ids[i], 0);
          memcpy(return_data + local_ids_orginal[i] * row_size,
                 local_data_char + local_ids[i] * row_size, row_size);
        }
      });

      // Recv remote message
      for (int i = 0; i < msg_count; ++i) {
        RPCMessage msg;
        RecvRPCMessage(&msg, 0);
        const int remote_part = msg.server_id / group_count;
        CHECK(remote_part >= 0 && remote_part < machine_count)
            << "Received a reply from an unexpected server: " << msg.server_id;
        CHECK(!msg.tensors.empty()) << "Received a reply without data.";
        const std::vector<dgl_id_t>& dst_rows = remote_ids_original[remote_part];
        const size_t id_size = remote_ids[remote_part].size();
        // The reply must hold one row per requested ID.
        CHECK_GE(msg.tensors[0].GetSize(), id_size * row_size)
            << "Received fewer rows than requested.";
        char* data_char = static_cast<char*>(msg.tensors[0]->data);
        for (size_t n = 0; n < id_size; ++n) {
          memcpy(return_data + dst_rows[n] * row_size,
                 data_char + n * row_size, row_size);
        }
      }
      *rv = res_tensor;
    });

536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
/*! \brief C API: return the group id held by the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->group_id;
    });

/*! \brief C API: store args[0] as the group id in the RPCContext. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_group_id = args[0];
      auto&& rpc_ctx = RPCContext::getInstance();
      rpc_ctx->group_id = new_group_id;
    });

/*! \brief C API: return the group_id field of the message in args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->group_id;
    });

/*!
 * \brief C API: forward to RPCContext::RegisterClient().
 *
 * args[0]: client id, args[1]: group id. The call's result is returned via rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->RegisterClient(client_id, group_id);
    });

/*!
 * \brief C API: forward to RPCContext::GetClient().
 *
 * args[0]: client id, args[1]: group id. The call's result is returned via rv.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      auto&& rpc_ctx = RPCContext::getInstance();
      *rv = rpc_ctx->GetClient(client_id, group_id);
    });

567
568
}  // namespace rpc
}  // namespace dgl
569
570

#endif