/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
#if defined(__linux__)
#include "./rpc.h"

#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/zerocopy_serializer.h>
#include <tensorpipe/tensorpipe.h>
#include <unistd.h>

#include <csignal>
#include <future>

#include "../c_api_common.h"
#include "../runtime/resource_manager.h"

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

using namespace tensorpipe;

// Borrowed from PyTorch.

const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
const char kDefaultUvAddress[] = "127.0.0.1";

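/*!
 * \brief Guess the IP address for TensorPipe to listen on: prefer the
 *        interface named by TP_SOCKET_IFNAME, then the address resolved from
 *        the local hostname, falling back to 127.0.0.1 on failure.
 */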
const std::string& guessAddress() {
  static const std::string uvAddress = []() {
    tensorpipe::Error error;
    std::string result;
    char* ifnameEnv = std::getenv(kSocketIfnameEnvVar);
    if (ifnameEnv != nullptr) {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for interface "
                     << ifnameEnv << " (" << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    } else {
      std::tie(error, result) =
          tensorpipe::transport::uv::lookupAddrForHostname();
      if (error) {
        LOG(WARNING) << "Failed to look up the IP address for the hostname ("
                     << error.what() << "), defaulting to "
                     << kDefaultUvAddress;
        return std::string(kDefaultUvAddress);
      }
    }
    return result;
  }();
  return uvAddress;
}

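/*!
 * \brief Send an RPC message to the target receiver via the global TPSender.
 */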
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  RPCContext::getInstance()->sender->Send(msg, target_id);
  return kRPCSuccess;
}

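/*!
 * \brief Receive an RPC message from any sender via the global TPReceiver.
 *        Only timeout == 0 (no timeout) is supported for now.
 */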
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  // A timeout is not supported yet.
  CHECK_EQ(timeout, 0) << "RPC does not support a timeout yet.";
  RPCContext::getInstance()->receiver->Recv(msg);
  return kRPCSuccess;
}

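/*!
 * \brief Create the global TensorPipe context if it does not exist yet, and
 *        register the "tcp" (uv) transport, the "basic" channel and, when
 *        DGL_SOCKET_NTHREADS is set, a multiplexed "mpt" channel.
 */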
void InitGlobalTpContext() {
  if (!RPCContext::getInstance()->ctx) {
    RPCContext::getInstance()->ctx = std::make_shared<tensorpipe::Context>();
    auto context = RPCContext::getInstance()->ctx;
    auto transportContext = tensorpipe::transport::uv::create();
    auto shmtransport = tensorpipe::transport::shm::create();
    context->registerTransport(0 /* priority */, "tcp", transportContext);
    // Register basic uv channel
    auto basicChannel = tensorpipe::channel::basic::create();
    context->registerChannel(0 /* low priority */, "basic", basicChannel);

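    // DGL_SOCKET_NTHREADS controls how many uv contexts/listeners back the
    // optional multiplexed ("mpt") channel registered below.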
    char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS");
    if (numUvThreads_str) {
      int numUvThreads = std::atoi(numUvThreads_str);
      CHECK(numUvThreads > 0)
          << "DGL_SOCKET_NTHREADS should be a positive integer if set";
      // Register multiplex uv channel
      std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
      std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
      for (int i = 0; i < numUvThreads; i++) {
        auto context = tensorpipe::transport::uv::create();
        std::string address = guessAddress();
        contexts.push_back(std::move(context));
        listeners.push_back(contexts.back()->listen(address));
      }
      auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts),
                                                         std::move(listeners));
      context->registerChannel(20 /* high priority */, "mpt", mptChannel);
    }
  }
}

//////////////////////////// C APIs ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
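  // NOTE: msg_queue_size and type are currently unused by the
  // TensorPipe-based sender.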
  InitGlobalTpContext();
  RPCContext::getInstance()->sender =
    std::make_shared<TPSender>(RPCContext::getInstance()->ctx);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
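  // NOTE: msg_queue_size and type are currently unused by the
  // TensorPipe-based receiver.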
  InitGlobalTpContext();
  RPCContext::getInstance()->receiver =
    std::make_shared<TPReceiver>(RPCContext::getInstance()->ctx);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCContext::getInstance()->sender->Finalize();
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCContext::getInstance()->receiver->Finalize();
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int num_sender = args[2];
  bool blocking = args[3];
  std::string addr;
  addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  if (!RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking)) {
    LOG(FATAL) << "Failed to wait for sender connections.";
  }
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int recv_id = args[2];
  std::string addr;
  addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
  *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t rank = args[0];
  RPCContext::getInstance()->rank = rank;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->rank;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  *rv = RPCContext::getInstance()->num_servers = num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_clients = args[0];
  *rv = RPCContext::getInstance()->num_clients = num_clients;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_clients;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  *rv = RPCContext::getInstance()->num_servers_per_machine = num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_servers_per_machine;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = (RPCContext::getInstance()->msg_seq)++;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int64_t msg_seq = args[0];
  RPCContext::getInstance()->msg_seq = msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->barrier_count;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t count = args[0];
  RPCContext::getInstance()->barrier_count = count;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t machine_id = args[0];
  RPCContext::getInstance()->machine_id = machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::getInstance()->num_machines;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const int32_t num_machines = args[0];
  RPCContext::getInstance()->num_machines = num_machines;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef msg = args[0];
  const int32_t target_id = args[1];
  *rv = SendRPCMessage(*(msg.sptr()), target_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int32_t timeout = args[0];
  RPCMessageRef msg = args[1];
  *rv = RecvRPCMessage(msg.sptr().get(), timeout);
});

//////////////////////////// RPCMessage ////////////////////////////

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  int64_t message_size = args[0];
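  // NOTE: message_size is currently unused; an empty RPCMessage is returned.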

  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  rst->service_id = args[0];
  rst->msg_seq = args[1];
  rst->client_id = args[2];
  rst->server_id = args[3];
  const std::string data =
    args[4];  // directly assigning string value raises errors :(
  rst->data = data;
  rst->tensors = ListValueToVector<NDArray>(args[5]);
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->service_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->client_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->server_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  DGLByteArray barr{msg->data.c_str(), msg->data.size()};
  *rv = barr;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  List<Value> ret;
  for (size_t i = 0; i < msg->tensors.size(); ++i) {
    ret.push_back(Value(MakeValue(msg->tensors[i])));
  }
  *rv = ret;
});

#if defined(__linux__)
/*!
 * \brief The signal handler.
 * \param s signal
 */
void SigHandler(int s) {
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
  CleanupResources();
  exit(1);
}

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  // Install SigHandler so that SIGINT (Ctrl+C) and SIGTERM clean up
  // resources before exiting.
  struct sigaction sigHandler;
  sigHandler.sa_handler = SigHandler;
  sigemptyset(&sigHandler.sa_mask);
  sigHandler.sa_flags = 0;
  sigaction(SIGINT, &sigHandler, nullptr);
  sigaction(SIGTERM, &sigHandler, nullptr);
});
#endif

//////////////////////////// ServerState ////////////////////////////

DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([](DGLArgs args, DGLRetValue* rv) {
  auto st = RPCContext::getInstance()->server_state;
  if (st.get() == nullptr) {
    st = std::make_shared<ServerState>();
    RPCContext::getInstance()->server_state = st;
  }
  *rv = st;
});

//////////////////////////// KVStore ////////////////////////////

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
.set_body([](DGLArgs args, DGLRetValue* rv) {
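  // Return the subset of ID whose partition ID equals local_machine_id, as
  // an int64 IdArray.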
  NDArray ID = args[0];
  NDArray part_id = args[1];
  int local_machine_id = args[2];
  int64_t* ID_data = static_cast<int64_t*>(ID->data);
  int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
  int64_t ID_size = ID.GetSize() / sizeof(int64_t);
  std::vector<int64_t> global_id;
  for (int64_t i = 0; i < ID_size; ++i) {
    if (part_id_data[i] == local_machine_id) {
      global_id.push_back(ID_data[i]);
    }
  }
  NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
  *rv = res_tensor;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
.set_body([](DGLArgs args, DGLRetValue* rv) {
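  // Fast pull: split the requested IDs into local and remote parts, send one
  // pull request per remote machine (to a randomly chosen server within that
  // machine's group), copy locally available rows in parallel, then merge the
  // remote responses into the result tensor.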
  // Input
  std::string name = args[0];
  int local_machine_id = args[1];
  int machine_count = args[2];
  int group_count = args[3];
  int client_id = args[4];
  int service_id = args[5];
  int64_t msg_seq = args[6];
  std::string pickle_data = args[7];
  NDArray ID = args[8];
  NDArray part_id = args[9];
  NDArray local_id = args[10];
  NDArray local_data = args[11];
  // Data
  dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
  dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
  dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
  dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
  char* local_data_char = static_cast<char*>(local_data->data);
  std::vector<dgl_id_t> local_ids;
  std::vector<dgl_id_t> local_ids_original;
  std::vector<int64_t> local_data_shape;
  std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
  std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
  // Get row size (in bytes)
  int row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= (local_data->dtype.bits / 8);
  size_t data_size = local_data.GetSize();
  CHECK_GT(local_data_shape.size(), 0);
  CHECK_EQ(row_size * local_data_shape[0], data_size);
  // Split the requested IDs into local IDs (served from the local machine)
  // and remote IDs (sent to the owning remote machines).
  dgl_id_t idx = 0;
  for (dgl_id_t i = 0; i < ID_size; ++i) {
    dgl_id_t p_id = part_id_data[i];
    if (p_id == local_machine_id) {
      dgl_id_t l_id = local_id_data[idx++];
      CHECK_LT(l_id, local_data_shape[0]);
      CHECK_GE(l_id, 0);
      local_ids.push_back(l_id);
      local_ids_original.push_back(i);
    } else {
      CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
      dgl_id_t id = ID_data[i];
      remote_ids[p_id].push_back(id);
      remote_ids_original[p_id].push_back(i);
    }
  }
  // Send remote id
  int msg_count = 0;
  for (int i = 0; i < remote_ids.size(); ++i) {
    if (remote_ids[i].size() != 0) {
      RPCMessage msg;
      msg.service_id = service_id;
      msg.msg_seq = msg_seq;
      msg.client_id = client_id;
      int lower = i * group_count;
      int upper = (i + 1) * group_count;
      msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
      msg.data = pickle_data;
      NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
      msg.tensors.push_back(tensor);
      SendRPCMessage(msg, msg.server_id);
      msg_count++;
    }
  }
  local_data_shape[0] = ID_size;
  NDArray res_tensor = NDArray::Empty(local_data_shape,
                                      local_data->dtype,
                                      DLContext{kDLCPU, 0});
  char* return_data = static_cast<char*>(res_tensor->data);
  // Copy local data
  parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      CHECK_GE(ID_size * row_size,
               local_ids_original[i] * row_size + row_size);
      CHECK_GE(data_size, local_ids[i] * row_size + row_size);
      CHECK_GE(local_ids[i], 0);
      memcpy(return_data + local_ids_original[i] * row_size,
             local_data_char + local_ids[i] * row_size, row_size);
    }
  });
  // Receive remote responses and scatter the returned rows into the result.
  for (int i = 0; i < msg_count; ++i) {
    RPCMessage msg;
    RecvRPCMessage(&msg, 0);
    // Map the responding server back to its machine (partition).
    int recv_part_id = msg.server_id / group_count;
    char* data_char = static_cast<char*>(msg.tensors[0]->data);
    dgl_id_t id_size = remote_ids[recv_part_id].size();
    for (size_t n = 0; n < id_size; ++n) {
      memcpy(return_data + remote_ids_original[recv_part_id][n] * row_size,
             data_char + n * row_size, row_size);
    }
  }
  *rv = res_tensor;
});

}  // namespace rpc
}  // namespace dgl

#endif