/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
#if defined(__linux__)
#include "./rpc.h"

#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/zerocopy_serializer.h>
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

#include "../c_api_common.h"
#include "../runtime/resource_manager.h"

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

using namespace tensorpipe;

// Borrowed from PyTorch

const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
const char kDefaultUvAddress[] = "127.0.0.1";

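/*!
 * \brief Guess the IP address to bind listeners to.
 *
 * If the TP_SOCKET_IFNAME environment variable is set, look up the address of
 * that network interface; otherwise resolve the hostname. On failure, fall
 * back to 127.0.0.1 (kDefaultUvAddress).
 */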
const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    }
    return result;
  }();
  return uvAddress;
}

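/*!
 * \brief Send an RPC message to the target via the global sender.
 * \param msg The message to send.
 * \param target_id ID of the receiver.
 * \return kRPCSuccess once the message has been handed to the sender.
 */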
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  RPCContext::getInstance()->sender->Send(msg, target_id);
  return kRPCSuccess;
}

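/*!
 * \brief Receive an RPC message from the global receiver.
 *
 * When \a timeout is 0 the call keeps retrying in 5-second intervals until a
 * message arrives; otherwise it returns kRPCTimeOut once the given timeout
 * (in milliseconds) elapses.
 */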
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds
  RPCStatus status;
  const int32_t real_timeout = timeout == 0 ? retry_timeout : timeout;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      static const std::string log_str = [real_timeout, timeout]() {
        std::ostringstream oss;
        oss << "Recv RPCMessage timeout in " << real_timeout << " ms."
            << (timeout == 0 ? " Retrying ..." : "");
        return oss.str();
      }();
      DLOG(WARNING) << log_str;
    }
  } while (timeout == 0 && status == kRPCTimeOut);
  return status;
}

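/*!
 * \brief Lazily create the process-wide tensorpipe context.
 *
 * Registers the libuv "tcp" transport and the "basic" channel. If the
 * DGL_SOCKET_NTHREADS environment variable is set to a positive integer, an
 * additional multiplexed "mpt" channel backed by that many uv contexts and
 * listeners is registered at a higher priority.
 */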
void InitGlobalTpContext() {
  if (!RPCContext::getInstance()->ctx) {
    RPCContext::getInstance()->ctx = std::make_shared<tensorpipe::Context>();
    auto context = RPCContext::getInstance()->ctx;
    auto transportContext = tensorpipe::transport::uv::create();
    auto shmtransport = tensorpipe::transport::shm::create();
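    // NOTE: an shm transport is created here but is not registered with the
    // context below; only the "tcp" transport is registered.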
    context->registerTransport(0 /* priority */, "tcp", transportContext);
    // Register basic uv channel
    auto basicChannel = tensorpipe::channel::basic::create();
    context->registerChannel(0 /* low priority */, "basic", basicChannel);

    char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
    if (numUvThreads_str) {
      int numUvThreads = std::atoi(numUvThreads_str);
      CHECK(numUvThreads > 0)
          << "DGL_SOCKET_NTHREADS should be a positive integer if set.";
      // Register multiplex uv channel
      std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
      std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
      for (int i = 0; i < numUvThreads; i++) {
        auto context = tensorpipe::transport::uv::create();
        std::string address = guessAddress();
        contexts.push_back(std::move(context));
        listeners.push_back(contexts.back()->listen(address));
      }
      auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                         std::move(listeners));
      context->registerChannel(20 /* high priority */, "mpt", mptChannel);
    }
  }
}

//////////////////////////// C APIs ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });

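// Create the global RPC sender. `type` selects the backend: "tensorpipe"
// (shared tensorpipe context) or "socket" (SocketSender with a message queue
// of size `msg_queue_size` and `max_thread_count` threads).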
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
  int max_thread_count = args[2];
  if (type == "tensorpipe") {
    InitGlobalTpContext();
    RPCContext::getInstance()->sender.reset(
        new TPSender(RPCContext::getInstance()->ctx));
  } else if (type == "socket") {
    RPCContext::getInstance()->sender.reset(
        new network::SocketSender(msg_queue_size, max_thread_count));
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
  }
  LOG(INFO) << "Sender with NetType~"
            << RPCContext::getInstance()->sender->NetType() << " is created.";
142
143
144
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
  int max_thread_count = args[2];
  if (type == "tensorpipe") {
    InitGlobalTpContext();
    RPCContext::getInstance()->receiver.reset(
        new TPReceiver(RPCContext::getInstance()->ctx));
  } else if (type == "socket") {
    RPCContext::getInstance()->receiver.reset(
        new network::SocketReceiver(msg_queue_size, max_thread_count));
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
  }
  LOG(INFO) << "Receiver with NetType~"
            << RPCContext::getInstance()->receiver->NetType() << " is created.";
161
162
163
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCContext::getInstance()->sender->Finalize();
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCContext::getInstance()->receiver->Finalize();
});

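// Wait until `num_sender` senders have connected to the receiver listening at
// tcp://<ip>:<port>.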
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int num_sender = args[2];
  bool blocking = args[3];
  std::string addr;
  addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  if (!RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking)) {
    LOG(FATAL) << "Failed to wait for sender connections on " << addr << ".";
  }
  }
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int recv_id = args[2];
  std::string addr;
  addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int max_try_times = args[0];
  *rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(
      max_try_times);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t rank = args[0];
  RPCContext::getInstance()->rank = rank;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->rank;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  *rv = RPCContext::getInstance()->num_servers = num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_clients = args[0];
  *rv = RPCContext::getInstance()->num_clients = num_clients;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_clients;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  *rv = RPCContext::getInstance()->num_servers_per_machine = num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_servers_per_machine;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = (RPCContext::getInstance()->msg_seq)++;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t msg_seq = args[0];
  RPCContext::getInstance()->msg_seq = msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t group_id = args[0];
  auto&& cnt = RPCContext::getInstance()->barrier_count;
  if (cnt.find(group_id) == cnt.end()) {
    cnt.emplace(group_id, 0x0);
  }
  *rv = cnt[group_id];
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  const int32_t group_id = args[1];
  RPCContext::getInstance()->barrier_count[group_id] = count;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t machine_id = args[0];
  RPCContext::getInstance()->machine_id = machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_machines;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_machines = args[0];
  RPCContext::getInstance()->num_machines = num_machines;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef msg = args[0];
  const int32_t target_id = args[1];
  *rv = SendRPCMessage(*(msg.sptr()), target_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int32_t timeout = args[0];
  RPCMessageRef msg = args[1];
  *rv = RecvRPCMessage(msg.sptr().get(), timeout);
});

//////////////////////////// RPCMessage ////////////////////////////

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t message_size = args[0];
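  // NOTE: `message_size` is read from the arguments but is not used when
  // constructing the empty message here.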
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  rst->service_id = args[0];
  rst->msg_seq = args[1];
  rst->client_id = args[2];
  rst->server_id = args[3];
  const std::string data =
    args[4];  // directly assigning string value raises errors :(
  rst->data = data;
  rst->tensors = ListValueToVector<NDArray>(args[5]);
  rst->group_id = args[6];
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->service_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->client_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->server_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  DGLByteArray barr{msg->data.c_str(), msg->data.size()};
  *rv = barr;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  List<Value> ret;
  for (size_t i = 0; i < msg->tensors.size(); ++i) {
    ret.push_back(Value(MakeValue(msg->tensors[i])));
  }
  *rv = ret;
});

#if defined(__linux__)
/*!
 * \brief The signal handler.
 * \param s signal
 */
void SigHandler(int s) {
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
  CleanupResources();
  exit(1);
}

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Ctrl+C handler
  struct sigaction sigHandler;
  sigHandler.sa_handler = SigHandler;
  sigemptyset(&sigHandler.sa_mask);
  sigHandler.sa_flags = 0;
  sigaction(SIGINT, &sigHandler, nullptr);
  sigaction(SIGTERM, &sigHandler, nullptr);
});
#endif

//////////////////////////// ServerState ////////////////////////////

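// Return the singleton ServerState, creating it lazily on first access.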
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& st = RPCContext::getInstance()->server_state;
  if (st.get() == nullptr) {
    st = std::make_shared<ServerState>();
  }
  *rv = st;
});

//////////////////////////// KVStore ////////////////////////////

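// Given global IDs and their partition IDs, return the subset of IDs whose
// partition matches the local machine.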
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  NDArray ID = args[0];
  NDArray part_id = args[1];
  int local_machine_id = args[2];
  int64_t* ID_data = static_cast<int64_t*>(ID->data);
  int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
  int64_t ID_size = ID.GetSize() / sizeof(int64_t);
  std::vector<int64_t> global_id;
  for (int64_t i = 0; i < ID_size; ++i) {
    if (part_id_data[i] == local_machine_id) {
      global_id.push_back(ID_data[i]);
    }
  }
  NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
  *rv = res_tensor;
});

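// Fast pull for the KVStore: split the requested IDs into rows stored locally
// and rows owned by remote machines, send one request per remote machine,
// copy the local rows in parallel, then scatter the remote responses back
// into the result tensor in the original order.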
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Input
  std::string name = args[0];
  int local_machine_id = args[1];
  int machine_count = args[2];
  int group_count = args[3];
  int client_id = args[4];
  int service_id = args[5];
  int64_t msg_seq = args[6];
  std::string pickle_data = args[7];
  NDArray ID = args[8];
  NDArray part_id = args[9];
  NDArray local_id = args[10];
  NDArray local_data = args[11];
  // Data
  dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
  dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
  dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
  dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
  char* local_data_char = static_cast<char*>(local_data->data);
  std::vector<dgl_id_t> local_ids;
  std::vector<dgl_id_t> local_ids_original;
  std::vector<int64_t> local_data_shape;
  std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
  std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
  // Get row size (in bytes)
  int row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= (local_data->dtype.bits / 8);
  size_t data_size = local_data.GetSize();
  CHECK_GT(local_data_shape.size(), 0);
  CHECK_EQ(row_size * local_data_shape[0], data_size);
  // Get local id (used in local machine) and
  // remote id (send to remote machine)
  dgl_id_t idx = 0;
  for (dgl_id_t i = 0; i < ID_size; ++i) {
    dgl_id_t p_id = part_id_data[i];
    if (p_id == local_machine_id) {
      dgl_id_t l_id = local_id_data[idx++];
      CHECK_LT(l_id, local_data_shape[0]);
      CHECK_GE(l_id, 0);
      local_ids.push_back(l_id);
      local_ids_original.push_back(i);
    } else {
      CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
      dgl_id_t id = ID_data[i];
      remote_ids[p_id].push_back(id);
      remote_ids_original[p_id].push_back(i);
    }
  }
  // Send remote id
  int msg_count = 0;
  for (int i = 0; i < remote_ids.size(); ++i) {
    if (remote_ids[i].size() != 0) {
      RPCMessage msg;
      msg.service_id = service_id;
      msg.msg_seq = msg_seq;
      msg.client_id = client_id;
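      // Pick a random server process on the owning machine so requests are
      // spread across that machine's servers.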
      int lower = i * group_count;
      int upper = (i + 1) * group_count;
      msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
      msg.data = pickle_data;
      NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
      msg.tensors.push_back(tensor);
      msg.group_id = RPCContext::getInstance()->group_id;
      SendRPCMessage(msg, msg.server_id);
      msg_count++;
    }
  }
  local_data_shape[0] = ID_size;
  NDArray res_tensor = NDArray::Empty(local_data_shape,
                                      local_data->dtype,
                                      DLContext{kDLCPU, 0});
  char* return_data = static_cast<char*>(res_tensor->data);
  // Copy local data
  parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      CHECK_GE(ID_size * row_size,
               local_ids_original[i] * row_size + row_size);
      CHECK_GE(data_size, local_ids[i] * row_size + row_size);
      CHECK_GE(local_ids[i], 0);
      memcpy(return_data + local_ids_original[i] * row_size,
             local_data_char + local_ids[i] * row_size, row_size);
    }
  });
  // Recv remote message
  int recv_cnt = 0;
  while (recv_cnt < msg_count) {
    RPCMessage msg;
    auto status = RecvRPCMessage(&msg, 0);
    CHECK_EQ(status, kRPCSuccess);
    ++recv_cnt;
    int part_id = msg.server_id / group_count;
    char* data_char = static_cast<char*>(msg.tensors[0]->data);
    dgl_id_t id_size = remote_ids[part_id].size();
    for (size_t n = 0; n < id_size; ++n) {
      memcpy(return_data + remote_ids_original[part_id][n] * row_size,
             data_char + n * row_size, row_size);
    }
  }
  *rv = res_tensor;
});

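// The following C APIs manage client group IDs and the client registry kept
// in RPCContext.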
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->group_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t group_id = args[0];
  RPCContext::getInstance()->group_id = group_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->group_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client_id = args[0];
  const int32_t group_id = args[1];
  *rv = RPCContext::getInstance()->RegisterClient(client_id, group_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client_id = args[0];
  const int32_t group_id = args[1];
  *rv = RPCContext::getInstance()->GetClient(client_id, group_id);
});

}  // namespace rpc
}  // namespace dgl

#endif