"graphbolt/src/vscode:/vscode.git/clone" did not exist on "3795a006b91c94291b911f0daa261c0598d7ffd8"
rpc.cc 17.9 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
20
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

21
#include "../c_api_common.h"
22
#include "../runtime/resource_manager.h"
23
24
25
26
27
28
29

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
using namespace tensorpipe;

// Borrowed from PyTorch.

// Name of the environment variable that guessAddress() reads to pick the
// network interface whose address is looked up.
const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
// Address guessAddress() falls back to when the lookup reports an error.
const char kDefaultUvAddress[] = "127.0.0.1";

const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
60
    }
61
62
63
64
65
66
67
    return result;
  }();
  return uvAddress;
}

/*!
 * \brief Send an RPC message through the sender stored in RPCContext.
 * \param msg The message to send.
 * \param target_id Identifier passed to the sender's Send() as the target.
 * \return Always kRPCSuccess; the outcome of Send() is not inspected here.
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto&& sender = RPCContext::getInstance()->sender;
  sender->Send(msg, target_id);
  return kRPCSuccess;
}

/*!
 * \brief Receive an RPC message through the receiver stored in RPCContext.
 * \param msg Output location handed to the receiver's Recv().
 * \param timeout Must be 0; any other value fails the check below.
 * \return Always kRPCSuccess.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  // Timeouts are not implemented, so only the value 0 is accepted.
  CHECK_EQ(timeout, 0) << "rpc cannot support timeout now.";
  auto&& receiver = RPCContext::getInstance()->receiver;
  receiver->Recv(msg);
  return kRPCSuccess;
}

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*!
 * \brief Create the tensorpipe context stored in RPCContext, if not yet set.
 *
 * Registers the uv transport under the name "tcp" and the "basic" channel.
 * If the environment variable DGL_SOCKET_NTHREADS is set, it must parse to a
 * positive integer N; an "mpt" channel backed by N uv contexts, each listening
 * on guessAddress(), is then registered with a higher priority.
 *
 * Calling this again after the context has been created is a no-op.
 */
void InitGlobalTpContext() {
  if (RPCContext::getInstance()->ctx) {
    return;
  }
  RPCContext::getInstance()->ctx = std::make_shared<tensorpipe::Context>();
  auto context = RPCContext::getInstance()->ctx;

  // Only the uv transport is registered. The shm transport that used to be
  // created here was never registered and has been dropped.
  auto transportContext = tensorpipe::transport::uv::create();
  context->registerTransport(0 /* priority */, "tcp", transportContext);
  // Register basic uv channel
  auto basicChannel = tensorpipe::channel::basic::create();
  context->registerChannel(0 /* low priority */, "basic", basicChannel);

  const char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
  if (numUvThreads_str == nullptr) {
    return;
  }
  const int numUvThreads = std::atoi(numUvThreads_str);
  CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set";

  // Register multiplex uv channel
  std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
  std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
  contexts.reserve(numUvThreads);
  listeners.reserve(numUvThreads);
  // guessAddress() returns a reference to a cached string; bind it once
  // instead of copying it on every iteration.
  const std::string& address = guessAddress();
  for (int i = 0; i < numUvThreads; i++) {
    // Named differently from `context` above to avoid shadowing it.
    auto uvContext = tensorpipe::transport::uv::create();
    contexts.push_back(std::move(uvContext));
    listeners.push_back(contexts.back()->listen(address));
  }
  auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                     std::move(listeners));
  context->registerChannel(20 /* high priority */, "mpt", mptChannel);
}

109
//////////////////////////// C APIs ////////////////////////////
110
// Forwards to RPCContext::Reset(); takes no arguments and sets no return value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCContext::Reset();
});
112
113

// Creates a TPSender on the shared tensorpipe context and stores it in
// RPCContext. args[0] (queue size) and args[1] (type string) are read but not
// used by this implementation.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t queue_size = args[0];
  std::string sender_type = args[1];
  // The tensorpipe context has to exist before a sender can be built on it.
  InitGlobalTpContext();
  auto tp_sender = std::make_shared<TPSender>(RPCContext::getInstance()->ctx);
  RPCContext::getInstance()->sender = tp_sender;
});

// Creates a TPReceiver on the shared tensorpipe context and stores it in
// RPCContext. args[0] (queue size) and args[1] (type string) are read but not
// used by this implementation.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t queue_size = args[0];
  std::string receiver_type = args[1];
  // The tensorpipe context has to exist before a receiver can be built on it.
  InitGlobalTpContext();
  auto tp_receiver =
    std::make_shared<TPReceiver>(RPCContext::getInstance()->ctx);
  RPCContext::getInstance()->receiver = tp_receiver;
});

// Calls Finalize() on the sender stored in RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& sender = RPCContext::getInstance()->sender;
  sender->Finalize();
});

// Calls Finalize() on the receiver stored in RPCContext.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& receiver = RPCContext::getInstance()->receiver;
  receiver->Finalize();
});

// Makes the receiver wait for senders on tcp://<ip>:<port>.
// args: [0] ip string, [1] port, [2] number of senders, [3] blocking flag.
// Logs a fatal error when the receiver's Wait() reports failure.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int num_sender = args[2];
  bool blocking = args[3];
  const std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  auto&& receiver = RPCContext::getInstance()->receiver;
  if (!receiver->Wait(addr, num_sender, blocking)) {
    LOG(FATAL) << "Wait sender socket failed.";
  }
});

154
// Connects the sender to the receiver at tcp://<ip>:<port>.
// args: [0] ip string, [1] port, [2] receiver id.
// Returns whatever the sender's ConnectReceiver() returns.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int recv_id = args[2];
  const std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  auto&& sender = RPCContext::getInstance()->sender;
  *rv = sender->ConnectReceiver(addr, recv_id);
});

// Stores args[0] as RPCContext::rank.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_rank = args[0];
  RPCContext::getInstance()->rank = new_rank;
});

// Returns RPCContext::rank.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& rank = RPCContext::getInstance()->rank;
  *rv = rank;
});

// Stores args[0] as RPCContext::num_servers and returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  auto&& stored = RPCContext::getInstance()->num_servers;
  stored = count;
  *rv = stored;
});

// Returns RPCContext::num_servers.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& count = RPCContext::getInstance()->num_servers;
  *rv = count;
});

186
// Stores args[0] as RPCContext::num_clients and returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  auto&& stored = RPCContext::getInstance()->num_clients;
  stored = count;
  *rv = stored;
});

// Returns RPCContext::num_clients.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& count = RPCContext::getInstance()->num_clients;
  *rv = count;
});

197
// Stores args[0] as RPCContext::num_servers_per_machine and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  auto&& stored = RPCContext::getInstance()->num_servers_per_machine;
  stored = count;
  *rv = stored;
});

// Returns RPCContext::num_servers_per_machine.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& count = RPCContext::getInstance()->num_servers_per_machine;
  *rv = count;
});

208
// Post-increments RPCContext::msg_seq: the value returned is the sequence
// number from before the increment.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& seq = RPCContext::getInstance()->msg_seq;
  *rv = seq++;
});

// Returns RPCContext::msg_seq without modifying it.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& seq = RPCContext::getInstance()->msg_seq;
  *rv = seq;
});

// Stores args[0] as RPCContext::msg_seq.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t new_seq = args[0];
  RPCContext::getInstance()->msg_seq = new_seq;
});

224
// Returns the barrier count recorded for group args[0]. A group that has no
// entry yet gets one with value 0 inserted before the value is returned.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t group_id = args[0];
  auto&& counts = RPCContext::getInstance()->barrier_count;
  auto entry = counts.find(group_id);
  if (entry == counts.end()) {
    entry = counts.emplace(group_id, 0x0).first;
  }
  *rv = entry->second;
});

// Sets the barrier count of a group. args: [0] count, [1] group id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_count = args[0];
  const int32_t group = args[1];
  auto&& counts = RPCContext::getInstance()->barrier_count;
  counts[group] = new_count;
});

241
// Returns RPCContext::machine_id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& id = RPCContext::getInstance()->machine_id;
  *rv = id;
});

// Stores args[0] as RPCContext::machine_id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_id = args[0];
  RPCContext::getInstance()->machine_id = new_id;
});

// Returns RPCContext::num_machines.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& count = RPCContext::getInstance()->num_machines;
  *rv = count;
});

// Stores args[0] as RPCContext::num_machines.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  RPCContext::getInstance()->num_machines = count;
});

// Sends the message referenced by args[0] to target args[1] and returns the
// status reported by SendRPCMessage().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef msg_ref = args[0];
  const int32_t target = args[1];
  auto payload = msg_ref.sptr();
  *rv = SendRPCMessage(*payload, target);
});

// Receives a message into the object referenced by args[1], using timeout
// args[0], and returns the status reported by RecvRPCMessage().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int32_t wait_timeout = args[0];
  RPCMessageRef msg_ref = args[1];
  auto destination = msg_ref.sptr();
  *rv = RecvRPCMessage(destination.get(), wait_timeout);
});

//////////////////////////// RPCMessage ////////////////////////////

// Returns a newly created RPCMessage with no fields assigned here.
// std::make_shared replaces the naked `new`: it value-initializes the object
// and allocates object and control block together.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst = std::make_shared<RPCMessage>();
  *rv = rst;
});

// Returns a newly created RPCMessage. args[0] (a message size) is read but
// not used by this implementation.
// std::make_shared replaces the naked `new`: it value-initializes the object
// and allocates object and control block together.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t message_size = args[0];
  static_cast<void>(message_size);  // currently ignored

  std::shared_ptr<RPCMessage> rst = std::make_shared<RPCMessage>();
  *rv = rst;
});

293

294
// Builds an RPCMessage from its fields.
// args: [0] service id, [1] message sequence, [2] client id, [3] server id,
//       [4] data string, [5] list of NDArray tensors, [6] group id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst = std::make_shared<RPCMessage>();
  rst->service_id = args[0];
  rst->msg_seq = args[1];
  rst->client_id = args[2];
  rst->server_id = args[3];
  // Directly assigning the argument to rst->data raises errors, so it is
  // converted into a local string first. The local is then swapped in rather
  // than copied, so the payload is not duplicated a second time.
  std::string data = args[4];
  rst->data.swap(data);
  rst->tensors = ListValueToVector<NDArray>(args[5]);
  rst->group_id = args[6];
  *rv = rst;
});

// Returns the service_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->service_id;
});

// Returns the msg_seq field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->msg_seq;
});

// Returns the client_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->client_id;
});

// Returns the server_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->server_id;
});

// Returns the data field of the message referenced by args[0] as a byte
// array. The DGLByteArray built here holds a pointer into the message's
// string together with its length.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  DGLByteArray bytes{message->data.c_str(), message->data.size()};
  *rv = bytes;
});

// Returns the tensors of the message referenced by args[0] as a List<Value>,
// in the same order as they are stored in the message.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  List<Value> wrapped;
  for (const auto& tensor : message->tensors) {
    wrapped.push_back(Value(MakeValue(tensor)));
  }
  *rv = wrapped;
});

350
351
#if defined(__linux__)
/*!
352
 * \brief The signal handler.
353
354
 * \param s signal
 */
355
void SigHandler(int s) {
356
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
357
  CleanupResources();
358
359
360
  exit(1);
}

361
// Installs SigHandler for both SIGINT and SIGTERM, with an empty signal mask
// and no flags. The return values of sigaction() are not checked.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  struct sigaction action;
  action.sa_flags = 0;
  sigemptyset(&action.sa_mask);
  action.sa_handler = SigHandler;
  // Ctrl+C handler
  sigaction(SIGINT, &action, nullptr);
  // The same handler also serves termination requests.
  sigaction(SIGTERM, &action, nullptr);
});
#endif

373
374
375
//////////////////////////// ServerState ////////////////////////////

// Returns the ServerState stored in RPCContext, creating it on first use.
//
// The state is created lazily. The freshly created object is assigned both to
// RPCContext and to the local that is returned. Previously the local kept the
// null pointer read before creation, so the first call returned null.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto st = RPCContext::getInstance()->server_state;
  if (st.get() == nullptr) {
    st = std::make_shared<ServerState>();
    RPCContext::getInstance()->server_state = st;
  }
  *rv = st;
});

384
385
386
//////////////////////////// KVStore ////////////////////////////

// Selects the IDs whose partition equals the local machine.
// args: [0] ID array, [1] partition-id array (read at the same positions as
//       the ID array), [2] local machine id.
// Both arrays are read as int64_t. Returns an id array holding, in input
// order, every ID[i] for which part_id[i] == local machine id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  NDArray ids = args[0];
  NDArray partitions = args[1];
  int local_machine = args[2];
  const int64_t* id_ptr = static_cast<int64_t*>(ids->data);
  const int64_t* partition_ptr = static_cast<int64_t*>(partitions->data);
  // Element count derived from the byte size of the ID array.
  const int64_t count = ids.GetSize() / sizeof(int64_t);
  std::vector<int64_t> selected;
  for (int64_t pos = 0; pos < count; ++pos) {
    if (partition_ptr[pos] != local_machine) {
      continue;
    }
    selected.push_back(id_ptr[pos]);
  }
  NDArray result = dgl::aten::VecToIdArray<int64_t>(selected);
  *rv = result;
});

// Gathers rows for a list of IDs, copying local rows directly and fetching
// remote rows over RPC.
//
// args: [0] name (read, not used below), [1] local machine id,
//       [2] machine count, [3] group count (servers per machine),
//       [4] client id, [5] service id, [6] message sequence,
//       [7] pickled request payload, [8] ID array, [9] partition id per ID,
//       [10] local row index for each locally owned ID (in ID order),
//       [11] local data tensor (first dimension indexes rows).
//
// Returns a CPU tensor with the dtype and trailing shape of the local data
// and one row per ID, where row i holds the data for ID[i].
//
// ID, part_id and local_id are read as dgl_id_t.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Input
  const std::string name = args[0];
  const int local_machine_id = args[1];
  const int machine_count = args[2];
  const int group_count = args[3];
  const int client_id = args[4];
  const int service_id = args[5];
  const int64_t msg_seq = args[6];
  const std::string pickle_data = args[7];
  NDArray ID = args[8];
  NDArray part_id = args[9];
  NDArray local_id = args[10];
  NDArray local_data = args[11];
  // Data
  const dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
  const dgl_id_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
  // part_id is indexed with every position of ID below.
  CHECK_GE(part_id.GetSize() / sizeof(dgl_id_t), ID_size)
    << "part_id must provide one entry per ID.";
  const dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
  const dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
  const dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
  const char* local_data_char = static_cast<char*>(local_data->data);
  std::vector<dgl_id_t> local_ids;
  std::vector<dgl_id_t> local_ids_original;
  std::vector<int64_t> local_data_shape;
  std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
  std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
  // Get row size (in bytes). Computed in 64 bits so that wide rows cannot
  // overflow an int.
  int64_t row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= (local_data->dtype.bits / 8);
  const size_t data_size = local_data.GetSize();
  CHECK(!local_data_shape.empty());
  CHECK_EQ(static_cast<size_t>(row_size * local_data_shape[0]), data_size);
  // Get local id (used in local machine) and
  // remote id (send to remote machine)
  const dgl_id_t num_local_rows = static_cast<dgl_id_t>(local_data_shape[0]);
  const dgl_id_t local_part = static_cast<dgl_id_t>(local_machine_id);
  dgl_id_t idx = 0;
  for (dgl_id_t i = 0; i < ID_size; ++i) {
    const dgl_id_t p_id = part_id_data[i];
    if (p_id == local_part) {
      // local_id is consumed sequentially; guard against reading past it.
      CHECK_LT(idx, local_id_size)
        << "local_id has fewer entries than locally owned IDs.";
      const dgl_id_t l_id = local_id_data[idx++];
      // dgl_id_t values are compared as unsigned here, so only the upper
      // bound needs checking.
      CHECK_LT(l_id, num_local_rows);
      local_ids.push_back(l_id);
      local_ids_original.push_back(i);
    } else {
      CHECK_LT(p_id, remote_ids.size()) << "Invalid partition ID.";
      remote_ids[p_id].push_back(ID_data[i]);
      remote_ids_original[p_id].push_back(i);
    }
  }
  // Send remote id
  int msg_count = 0;
  for (size_t i = 0; i < remote_ids.size(); ++i) {
    if (remote_ids[i].empty()) {
      continue;
    }
    RPCMessage msg;
    msg.service_id = service_id;
    msg.msg_seq = msg_seq;
    msg.client_id = client_id;
    // Pick one server at random from the id range assigned to machine i.
    const int lower = static_cast<int>(i) * group_count;
    const int upper = lower + group_count;
    msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
    msg.data = pickle_data;
    NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
    msg.tensors.push_back(tensor);
    msg.group_id = RPCContext::getInstance()->group_id;
    SendRPCMessage(msg, msg.server_id);
    msg_count++;
  }
  local_data_shape[0] = ID_size;
  NDArray res_tensor = NDArray::Empty(local_data_shape,
                                      local_data->dtype,
                                      DLContext{kDLCPU, 0});
  char* return_data = static_cast<char*>(res_tensor->data);
  // Copy local data
  parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      CHECK_GE(ID_size * row_size,
               local_ids_original[i] * row_size + row_size);
      CHECK_GE(data_size, local_ids[i] * row_size + row_size);
      memcpy(return_data + local_ids_original[i] * row_size,
             local_data_char + local_ids[i] * row_size, row_size);
    }
  });
  // Recv remote message
  for (int i = 0; i < msg_count; ++i) {
    RPCMessage msg;
    RecvRPCMessage(&msg, 0);
    // Map the responding server back to its machine (partition).
    const int part_idx = msg.server_id / group_count;
    CHECK(part_idx >= 0 && part_idx < machine_count)
      << "Received a response from an unexpected server.";
    CHECK(!msg.tensors.empty()) << "Received a response without tensors.";
    const std::vector<dgl_id_t>& positions = remote_ids_original[part_idx];
    // The payload must contain one row per requested ID.
    CHECK_GE(msg.tensors[0].GetSize(),
             positions.size() * static_cast<size_t>(row_size))
      << "Received tensor is smaller than the requested rows.";
    const char* data_char = static_cast<char*>(msg.tensors[0]->data);
    for (size_t n = 0; n < positions.size(); ++n) {
      memcpy(return_data + positions[n] * row_size,
             data_char + n * row_size, row_size);
    }
  }
  *rv = res_tensor;
});

510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
// Returns RPCContext::group_id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto&& group = RPCContext::getInstance()->group_id;
  *rv = group;
});

// Stores args[0] as RPCContext::group_id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_group = args[0];
  RPCContext::getInstance()->group_id = new_group;
});

// Returns the group_id field of the message referenced by args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->group_id;
});

// Forwards to RPCContext::RegisterClient(client_id, group_id) and returns its
// result. args: [0] client id, [1] group id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client = args[0];
  const int32_t group = args[1];
  *rv = RPCContext::getInstance()->RegisterClient(client, group);
});

// Forwards to RPCContext::GetClient(client_id, group_id) and returns its
// result. args: [0] client id, [1] group id.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t client = args[0];
  const int32_t group = args[1];
  *rv = RPCContext::getInstance()->GetClient(client, group);
});

541
542
}  // namespace rpc
}  // namespace dgl
543
544

#endif