rpc.cc 16.8 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
20
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

21
#include "../c_api_common.h"
22
#include "../runtime/resource_manager.h"
23
24
25
26
27
28
29

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
using namespace tensorpipe;

// Borrow from PyTorch

// Name of the environment variable that guessAddress() reads to obtain the
// network interface whose IP address should be used.
const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
// Address guessAddress() falls back to when the address lookup fails.
const char kDefaultUvAddress[] = "127.0.0.1";

const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
60
    }
61
62
63
64
65
66
67
    return result;
  }();
  return uvAddress;
}

/*!
 * \brief Send an RPC message through the sender stored in the global RPC
 *        context.
 * \param msg message to send
 * \param target_id id forwarded to the sender's Send() as the target
 * \return always kRPCSuccess
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  const auto& sender = RPCContext::getInstance()->sender;
  sender->Send(msg, target_id);
  return kRPCSuccess;
}

/*!
 * \brief Receive an RPC message through the receiver stored in the global
 *        RPC context.
 * \param msg output location handed to the receiver's Recv()
 * \param timeout must be 0; timeouts are not supported and any other value
 *        fails the check below
 * \return always kRPCSuccess
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  // Timeouts are not implemented yet, so only 0 is accepted.
  CHECK_EQ(timeout, 0) << "rpc cannot support timeout now.";
  const auto& receiver = RPCContext::getInstance()->receiver;
  receiver->Recv(msg);
  return kRPCSuccess;
}

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*!
 * \brief Lazily create the tensorpipe context stored in the global RPC
 *        context.
 *
 * Does nothing when a context already exists. Otherwise a new context is
 * built with a uv transport registered as "tcp" and a "basic" channel. If the
 * environment variable DGL_SOCKET_NTHREADS is set, it must parse to a positive
 * integer N; N uv contexts, each listening on guessAddress(), are then
 * combined into an "mpt" channel registered with higher priority.
 *
 * The context is stored in the RPC context only after it has been fully
 * configured, so a failed DGL_SOCKET_NTHREADS check does not leave a
 * partially configured context behind.
 */
void InitGlobalTpContext() {
  auto rpc_ctx = RPCContext::getInstance();
  if (rpc_ctx->ctx) {
    return;
  }

  auto context = std::make_shared<tensorpipe::Context>();
  context->registerTransport(0 /* priority */, "tcp",
                             tensorpipe::transport::uv::create());
  // Register basic uv channel
  context->registerChannel(0 /* low priority */, "basic",
                           tensorpipe::channel::basic::create());

  const char* num_uv_threads_str = std::getenv("DGL_SOCKET_NTHREADS");
  if (num_uv_threads_str != nullptr) {
    const int num_uv_threads = std::atoi(num_uv_threads_str);
    CHECK(num_uv_threads > 0)
        << "DGL_SOCKET_NTHREADS should be positive integer if set";
    // Register multiplex uv channel
    std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
    std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
    contexts.reserve(num_uv_threads);
    listeners.reserve(num_uv_threads);
    // guessAddress() caches its result, so one lookup serves all listeners.
    const std::string& address = guessAddress();
    for (int i = 0; i < num_uv_threads; ++i) {
      contexts.push_back(tensorpipe::transport::uv::create());
      listeners.push_back(contexts.back()->listen(address));
    }
    auto mpt_channel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                        std::move(listeners));
    context->registerChannel(20 /* high priority */, "mpt", mpt_channel);
  }

  // Publish only once every transport and channel has been registered.
  rpc_ctx->ctx = context;
}

109
//////////////////////////// C APIs ////////////////////////////
110
// Calls RPCContext::Reset(); takes no arguments and returns nothing.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCContext::Reset();
});
112
113

// Creates a TPSender on the global tensorpipe context and stores it as the
// RPC sender. args[0] (message queue size) and args[1] (type) are read but
// not otherwise used here.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
  InitGlobalTpContext();
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->sender = std::make_shared<TPSender>(rpc_ctx->ctx);
});

// Creates a TPReceiver on the global tensorpipe context and stores it as the
// RPC receiver. args[0] (message queue size) and args[1] (type) are read but
// not otherwise used here.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
  InitGlobalTpContext();
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->receiver = std::make_shared<TPReceiver>(rpc_ctx->ctx);
});

// Calls Finalize() on the sender held by the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& sender = RPCContext::getInstance()->sender;
  sender->Finalize();
});

// Calls Finalize() on the receiver held by the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& receiver = RPCContext::getInstance()->receiver;
  receiver->Finalize();
});

// Builds "tcp://<ip>:<port>" from args[0] and args[1] and calls the
// receiver's Wait() with it and the sender count in args[2]. A false result
// is reported with LOG(FATAL).
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const std::string ip = args[0];
  const int port = args[1];
  const int num_sender = args[2];
  std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  const auto& receiver = RPCContext::getInstance()->receiver;
  if (!receiver->Wait(addr, num_sender)) {
    LOG(FATAL) << "Wait sender socket failed.";
  }
});

// Builds "tcp://<ip>:<port>" from args[0] and args[1] and passes it to the
// sender's AddReceiver() together with the receiver id in args[2].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const std::string ip = args[0];
  const int port = args[1];
  const int recv_id = args[2];
  std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  const auto& sender = RPCContext::getInstance()->sender;
  sender->AddReceiver(addr, recv_id);
});

// Calls the sender's Connect(); a false result is reported with LOG(FATAL).
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& sender = RPCContext::getInstance()->sender;
  if (!sender->Connect()) {
    LOG(FATAL) << "Sender connection failed.";
  }
});

// Stores args[0] as `rank` in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_rank = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->rank = new_rank;
});

// Returns `rank` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->rank;
});

// Stores args[0] as `num_servers` in the global RPC context and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->num_servers = count;
  *rv = rpc_ctx->num_servers;
});

// Returns `num_servers` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->num_servers;
});

192
// Stores args[0] as `num_clients` in the global RPC context and returns the
// stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->num_clients = count;
  *rv = rpc_ctx->num_clients;
});

// Returns `num_clients` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->num_clients;
});

203
// Stores args[0] as `num_servers_per_machine` in the global RPC context and
// returns the stored value.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->num_servers_per_machine = count;
  *rv = rpc_ctx->num_servers_per_machine;
});

// Returns `num_servers_per_machine` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->num_servers_per_machine;
});

214
// Post-increments `msg_seq` in the global RPC context; the value returned is
// the one held before the increment.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->msg_seq++;
});

// Returns `msg_seq` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->msg_seq;
});

// Stores args[0] as `msg_seq` in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t new_seq = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->msg_seq = new_seq;
});

230
// Returns `barrier_count` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->barrier_count;
});

// Stores args[0] as `barrier_count` in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_count = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->barrier_count = new_count;
});

241
// Returns `machine_id` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->machine_id;
});

// Stores args[0] as `machine_id` in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_machine_id = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->machine_id = new_machine_id;
});

// Returns `num_machines` from the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto rpc_ctx = RPCContext::getInstance();
  *rv = rpc_ctx->num_machines;
});

// Stores args[0] as `num_machines` in the global RPC context.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t new_num_machines = args[0];
  auto rpc_ctx = RPCContext::getInstance();
  rpc_ctx->num_machines = new_num_machines;
});

// Sends the RPCMessage in args[0] to the target id in args[1] via
// SendRPCMessage() and returns its status.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef msg = args[0];
  const int32_t target_id = args[1];
  const RPCMessage& outgoing = *(msg.sptr());
  *rv = SendRPCMessage(outgoing, target_id);
});

// Receives into the RPCMessage in args[1] via RecvRPCMessage(), passing the
// timeout in args[0], and returns its status.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t timeout = args[0];
  RPCMessageRef msg = args[1];
  // `msg` keeps the message alive while the raw pointer is in use.
  RPCMessage* incoming = msg.sptr().get();
  *rv = RecvRPCMessage(incoming, timeout);
});

//////////////////////////// RPCMessage ////////////////////////////

// Returns a newly allocated, default-constructed RPCMessage.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> empty_msg = std::make_shared<RPCMessage>();
  *rv = empty_msg;
});

// Returns a newly allocated, default-constructed RPCMessage. args[0] (the
// message size) is read but not otherwise used here.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t message_size = args[0];
  std::shared_ptr<RPCMessage> empty_msg = std::make_shared<RPCMessage>();
  *rv = empty_msg;
});

293

294
// Builds an RPCMessage from the arguments and returns it:
//   args[0] service_id, args[1] msg_seq, args[2] client_id,
//   args[3] server_id, args[4] data (string), args[5] list of NDArray tensors.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> message = std::make_shared<RPCMessage>();
  message->service_id = args[0];
  message->msg_seq = args[1];
  message->client_id = args[2];
  message->server_id = args[3];
  // Go through a local string: assigning args[4] directly raises errors.
  const std::string payload = args[4];
  message->data = payload;
  message->tensors = ListValueToVector<NDArray>(args[5]);
  *rv = message;
});

// Returns the `service_id` field of the RPCMessage in args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->service_id;
});

// Returns the `msg_seq` field of the RPCMessage in args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->msg_seq;
});

// Returns the `client_id` field of the RPCMessage in args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->client_id;
});

// Returns the `server_id` field of the RPCMessage in args[0].
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  *rv = message->server_id;
});

// Returns the `data` field of the RPCMessage in args[0] as a DGLByteArray
// built from the string's c_str() pointer and size().
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  const auto& payload = message->data;
  DGLByteArray bytes{payload.c_str(), payload.size()};
  *rv = bytes;
});

// Returns the tensors of the RPCMessage in args[0] as a List<Value>, one
// Value per tensor, in the same order.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef message = args[0];
  List<Value> wrapped;
  for (auto&& tensor : message->tensors) {
    wrapped.push_back(Value(MakeValue(tensor)));
  }
  *rv = wrapped;
});

349
350
#if defined(__linux__)
/*!
351
 * \brief The signal handler.
352
353
 * \param s signal
 */
354
void SigHandler(int s) {
355
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
356
  CleanupResources();
357
358
359
  exit(1);
}

360
// Installs SigHandler for SIGINT and SIGTERM using sigaction(), with an
// empty signal mask and no flags.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  struct sigaction action;
  action.sa_handler = SigHandler;
  sigemptyset(&action.sa_mask);
  action.sa_flags = 0;
  // Same handler for Ctrl+C (SIGINT) and termination requests (SIGTERM).
  for (const int signum : {SIGINT, SIGTERM}) {
    sigaction(signum, &action, nullptr);
  }
});
#endif

372
373
374
//////////////////////////// ServerState ////////////////////////////

// Returns the ServerState held by the global RPC context, creating it on
// first use.
//
// The state is accessed through a reference so that the freshly created
// object is what gets returned. Copying the pointer before the null check
// would hand back the stale null pointer on the first call.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& state = RPCContext::getInstance()->server_state;
  if (state.get() == nullptr) {
    state = std::make_shared<ServerState>();
  }
  *rv = state;
});

383
384
385
//////////////////////////// KVStore ////////////////////////////

// Selects the entries of ID (args[0]) whose matching entry in part_id
// (args[1]) equals the local machine id (args[2]) and returns them as an id
// array. Both input buffers are read as int64_t; the element count is taken
// from ID's byte size.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  NDArray ID = args[0];
  NDArray part_id = args[1];
  const int local_machine_id = args[2];

  const int64_t* ids = static_cast<int64_t*>(ID->data);
  const int64_t* partitions = static_cast<int64_t*>(part_id->data);
  const int64_t num_ids = ID.GetSize() / sizeof(int64_t);

  std::vector<int64_t> local_global_ids;
  for (int64_t pos = 0; pos != num_ids; ++pos) {
    if (partitions[pos] != local_machine_id) {
      continue;
    }
    local_global_ids.push_back(ids[pos]);
  }
  *rv = dgl::aten::VecToIdArray<int64_t>(local_global_ids);
});

// Gathers rows for the requested IDs into one CPU tensor.
//
// Arguments:
//   args[0]  name (read but not otherwise used here)
//   args[1]  local machine id
//   args[2]  machine count
//   args[3]  group count; server ids [i * group_count, (i + 1) * group_count)
//            are treated as belonging to partition i
//   args[4]  client id, args[5] service id, args[6] message sequence number
//   args[7]  pickled payload copied into every outgoing request
//   args[8]  ID: requested ids
//   args[9]  part_id: partition id of each requested id
//   args[10] local_id: row index into local_data for each locally owned id,
//            in request order
//   args[11] local_data: the locally stored rows
//
// Rows whose partition equals the local machine id are copied directly from
// local_data. The other ids are grouped per partition, sent to a randomly
// chosen server of that partition, and the rows in the first tensor of each
// response are copied back to the positions of the original request.
//
// Returns a tensor with ID-count rows and the same trailing shape and dtype
// as local_data.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Input
  std::string name = args[0];
  int local_machine_id = args[1];
  int machine_count = args[2];
  int group_count = args[3];
  int client_id = args[4];
  int service_id = args[5];
  int64_t msg_seq = args[6];
  std::string pickle_data = args[7];
  NDArray ID = args[8];
  NDArray part_id = args[9];
  NDArray local_id = args[10];
  NDArray local_data = args[11];
  CHECK_GE(machine_count, 0) << "machine_count cannot be negative.";
  // group_count is used as a divisor and as the width of the server id range.
  CHECK_GT(group_count, 0) << "group_count must be positive.";
  // Data
  dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
  // part_id is indexed with every position of ID below.
  CHECK_GE(part_id.GetSize(), ID.GetSize())
      << "part_id must contain one entry per requested ID.";
  const size_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
  dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
  dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
  dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
  char* local_data_char = static_cast<char*>(local_data->data);
  std::vector<dgl_id_t> local_ids;
  std::vector<dgl_id_t> local_ids_original;
  std::vector<int64_t> local_data_shape;
  std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
  std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
  // Get row size (in bytes); 64-bit so wide rows cannot overflow.
  int64_t row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= (local_data->dtype.bits / 8);
  size_t data_size = local_data.GetSize();
  CHECK_GT(local_data_shape.size(), 0);
  CHECK_EQ(static_cast<size_t>(row_size * local_data_shape[0]), data_size);
  // Get local id (used in local machine) and
  // remote id (send to remote machine)
  dgl_id_t idx = 0;
  for (dgl_id_t i = 0; i < ID_size; ++i) {
    dgl_id_t p_id = part_id_data[i];
    if (p_id == local_machine_id) {
      CHECK_LT(static_cast<size_t>(idx), local_id_size)
          << "local_id has fewer entries than locally owned IDs.";
      dgl_id_t l_id = local_id_data[idx++];
      CHECK_LT(l_id, local_data_shape[0]);
      CHECK_GE(l_id, 0);
      local_ids.push_back(l_id);
      local_ids_original.push_back(i);
    } else {
      CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
      dgl_id_t id = ID_data[i];
      remote_ids[p_id].push_back(id);
      remote_ids_original[p_id].push_back(i);
    }
  }
  // Send remote id
  int msg_count = 0;
  for (int i = 0; i < machine_count; ++i) {
    if (remote_ids[i].size() != 0) {
      RPCMessage msg;
      msg.service_id = service_id;
      msg.msg_seq = msg_seq;
      msg.client_id = client_id;
      int lower = i * group_count;
      int upper = (i + 1) * group_count;
      msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
      msg.data = pickle_data;
      NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
      msg.tensors.push_back(tensor);
      SendRPCMessage(msg, msg.server_id);
      msg_count++;
    }
  }
  local_data_shape[0] = ID_size;
  NDArray res_tensor = NDArray::Empty(local_data_shape,
                                      local_data->dtype,
                                      DLContext{kDLCPU, 0});
  char* return_data = static_cast<char*>(res_tensor->data);
  // Copy local data
  parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      CHECK_GE(ID_size * row_size,
               local_ids_original[i] * row_size + row_size);
      CHECK_GE(data_size, local_ids[i] * row_size + row_size);
      CHECK_GE(local_ids[i], 0);
      memcpy(return_data + local_ids_original[i] * row_size,
             local_data_char + local_ids[i] * row_size, row_size);
    }
  });
  // Recv remote message
  for (int i = 0; i < msg_count; ++i) {
    RPCMessage msg;
    RecvRPCMessage(&msg, 0);
    // Map the responding server back to the partition it serves.
    const int remote_part = msg.server_id / group_count;
    CHECK(remote_part >= 0 && remote_part < machine_count)
        << "Received a response from an unexpected server.";
    CHECK(!msg.tensors.empty()) << "Received a response without tensor data.";
    const std::vector<dgl_id_t>& positions = remote_ids_original[remote_part];
    // The copy below reads positions.size() rows from the response tensor.
    CHECK_GE(msg.tensors[0].GetSize(),
             static_cast<size_t>(positions.size() * row_size))
        << "Received tensor is smaller than the requested data.";
    char* data_char = static_cast<char*>(msg.tensors[0]->data);
    for (size_t n = 0; n < positions.size(); ++n) {
      memcpy(return_data + positions[n] * row_size,
             data_char + n * row_size, row_size);
    }
  }
  *rv = res_tensor;
});

508
509
}  // namespace rpc
}  // namespace dgl
510
511

#endif