"examples/pytorch/ogb/line/reading_data.py" did not exist on "8d16d4bb6b436d2f0b18b2300be1b77acd614709"
rpc.cc 17.2 KB
Newer Older
/**
 *  Copyright (c) 2020 by Contributors
 * @file rpc/rpc.cc
 * @brief Implementation of RPC utilities used by both server and client sides.
 */
6
#if defined(__linux__)
7
#include "./rpc.h"
8

9
#include <dgl/array.h>
10
#include <dgl/packed_func_ext.h>
11
#include <dgl/random.h>
12
13
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
14
#include <dgl/zerocopy_serializer.h>
15
16
17
18
19
#include <unistd.h>

#include <csignal>
#include <future>

20
#include "../c_api_common.h"
21
#include "../runtime/resource_manager.h"
22
23
24
25
26
27
28

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

29
30
31
32
33
34
35
// Borrowed from PyTorch.
// NOTE(review): neither constant is referenced in the visible part of this
// file; presumably an environment-variable name and a default address that
// are consumed elsewhere -- confirm before touching them.

const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME";
const char kDefaultUvAddress[] = "127.0.0.1";

/**
 * @brief Send an RPC message through the sender held by the RPC context.
 * @param msg Message handed to the sender.
 * @param target_id ID forwarded to the sender together with the message.
 * @return Always kRPCSuccess; nothing returned by the sender is inspected.
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  auto& sender = RPCContext::getInstance()->sender;
  sender->Send(msg, target_id);
  return kRPCSuccess;
}

/**
 * @brief Receive one RPC message through the receiver held by the RPC context.
 *
 * @param msg Destination message object passed to the receiver's Recv().
 * @param timeout Timeout in milliseconds. A value of 0 means "keep retrying":
 *        the receiver is polled in 5-second slices until it reports anything
 *        other than kRPCTimeOut.
 * @return The status reported by the receiver. kRPCTimeOut can only be
 *         returned when @p timeout is non-zero.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  static constexpr int32_t kRetryTimeoutMs = 5 * 1000;  // milliseconds
  const bool retry_on_timeout = (timeout == 0);
  const int32_t real_timeout = retry_on_timeout ? kRetryTimeoutMs : timeout;
  RPCStatus status;
  do {
    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);
    if (status == kRPCTimeOut) {
      // The text is streamed on every timeout instead of being cached in a
      // function-local static: a static would be initialised only once and
      // would keep reporting the first caller's timeout values forever.
      DLOG(WARNING) << "Recv RPCMessage timeout in " << real_timeout << " ms."
                    << (retry_on_timeout ? " Retrying ..." : "");
    }
  } while (retry_on_timeout && status == kRPCTimeOut);
  return status;
}

//////////////////////////// C APIs ////////////////////////////
59
/** @brief Reset the RPC context by calling RPCContext::Reset(); no arguments. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCContext::Reset();
    });
61
62

/**
 * @brief Install a fresh network::SocketSender as the context's sender.
 *
 * args[0] is the message queue size and args[1] the maximum thread count;
 * both are forwarded unchanged to the SocketSender constructor.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t queue_size = args[0];
      const int thread_limit = args[1];
      auto& sender = RPCContext::getInstance()->sender;
      sender.reset(new network::SocketSender(queue_size, thread_limit));
    });
69
70

/**
 * @brief Install a fresh network::SocketReceiver as the context's receiver.
 *
 * args[0] is the message queue size and args[1] the maximum thread count;
 * both are forwarded unchanged to the SocketReceiver constructor.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t queue_size = args[0];
      const int thread_limit = args[1];
      auto& receiver = RPCContext::getInstance()->receiver;
      receiver.reset(new network::SocketReceiver(queue_size, thread_limit));
    });
77
78

/** @brief Call Finalize() on the sender held by the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto& sender = RPCContext::getInstance()->sender;
      sender->Finalize();
    });
82
83

/** @brief Call Finalize() on the receiver held by the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto& receiver = RPCContext::getInstance()->receiver;
      receiver->Finalize();
    });
87

88
/**
 * @brief Make the receiver wait for senders on a "tcp://ip:port" address.
 *
 * args[0]: IP string, args[1]: port, args[2]: number of senders passed to
 * the receiver's Wait(). Logs a fatal error when Wait() reports failure.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string ip = args[0];
      const int port = args[1];
      const int num_sender = args[2];
      std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      auto& receiver = RPCContext::getInstance()->receiver;
      if (!receiver->Wait(addr, num_sender)) {
        LOG(FATAL) << "Wait sender socket failed.";
      }
    });
100

101
/**
 * @brief Ask the sender to connect to a receiver at "tcp://ip:port".
 *
 * args[0]: IP string, args[1]: port, args[2]: receiver ID. The value
 * returned by the sender's ConnectReceiver() is handed back to the caller.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string ip = args[0];
      const int port = args[1];
      const int recv_id = args[2];
      std::string addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
      auto& sender = RPCContext::getInstance()->sender;
      *rv = sender->ConnectReceiver(addr, recv_id);
    });
110

111
/**
 * @brief Forward to the sender's ConnectReceiverFinalize().
 *
 * args[0] is the maximum number of tries; the sender's result is returned.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int max_try_times = args[0];
      auto& sender = RPCContext::getInstance()->sender;
      *rv = sender->ConnectReceiverFinalize(max_try_times);
    });
117

118
/** @brief Store args[0] as the rank in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_rank = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->rank = new_rank;
    });
123
124

/** @brief Return the rank currently stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->rank;
    });
128
129

/**
 * @brief Store args[0] as the number of servers in the RPC context and
 *        return the result of that assignment.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t server_count = args[0];
      auto ctx = RPCContext::getInstance();
      *rv = (ctx->num_servers = server_count);
    });
134
135

/** @brief Return the number of servers stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_servers;
    });
139

140
/**
 * @brief Store args[0] as the number of clients in the RPC context and
 *        return the result of that assignment.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_count = args[0];
      auto ctx = RPCContext::getInstance();
      *rv = (ctx->num_clients = client_count);
    });
145
146

/** @brief Return the number of clients stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_clients;
    });
150

151
/**
 * @brief Store args[0] as the per-machine server count in the RPC context
 *        and return the result of that assignment.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t servers_per_machine = args[0];
      auto ctx = RPCContext::getInstance();
      *rv = (ctx->num_servers_per_machine = servers_per_machine);
    });
156
157

/** @brief Return the per-machine server count stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_servers_per_machine;
    });
161

162
/**
 * @brief Post-increment the context's message sequence number.
 *
 * The value returned is the sequence number *before* the increment.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto& seq = RPCContext::getInstance()->msg_seq;
      *rv = seq++;
    });
166
167

/** @brief Return the message sequence number stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->msg_seq;
    });
171
172

/** @brief Overwrite the context's message sequence number with args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t new_seq = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->msg_seq = new_seq;
    });
177

178
/**
 * @brief Return the barrier count recorded for group args[0].
 *
 * A group that has no entry yet gets one with count 0, so the lookup has the
 * side effect of inserting into the context's barrier_count map.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t group_id = args[0];
      auto& counts = RPCContext::getInstance()->barrier_count;
      // emplace() leaves an existing entry untouched and otherwise inserts 0;
      // either way `.first` points at the entry for this group.
      const auto slot = counts.emplace(group_id, 0x0).first;
      *rv = slot->second;
    });
187
188

/**
 * @brief Set the barrier count of a group.
 *
 * args[0] is the count and args[1] the group ID (note the order).
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t count = args[0];
      const int32_t group_id = args[1];
      auto& counts = RPCContext::getInstance()->barrier_count;
      counts[group_id] = count;
    });
194

195
/** @brief Return the machine ID stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->machine_id;
    });
199
200

/** @brief Store args[0] as the machine ID in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_machine_id = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->machine_id = new_machine_id;
    });
205
206

/** @brief Return the number of machines stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->num_machines;
    });
210
211

/** @brief Store args[0] as the number of machines in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t machine_total = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->num_machines = machine_total;
    });
216
217

/**
 * @brief C API wrapper around SendRPCMessage().
 *
 * args[0]: message reference, args[1]: target ID. Returns the status
 * produced by SendRPCMessage().
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      RPCMessageRef msg_ref = args[0];
      const int32_t target_id = args[1];
      const RPCMessage& payload = *msg_ref.sptr();
      *rv = SendRPCMessage(payload, target_id);
    });
223
224

/**
 * @brief C API wrapper around RecvRPCMessage().
 *
 * args[0]: timeout, args[1]: message reference whose underlying object is
 * passed as the destination. Returns the status from RecvRPCMessage().
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t timeout = args[0];
      RPCMessageRef msg_ref = args[1];
      RPCMessage* destination = msg_ref.sptr().get();
      *rv = RecvRPCMessage(destination, timeout);
    });
230
231
232
233

//////////////////////////// RPCMessage ////////////////////////////

/** @brief Allocate a new RPCMessage with no fields assigned and return it. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto message = std::make_shared<RPCMessage>();
      *rv = message;
    });
238

239
/**
 * @brief Build an RPCMessage from its individual fields and return it.
 *
 * Argument layout: [0] service_id, [1] msg_seq, [2] client_id,
 * [3] server_id, [4] data (string), [5] list of tensors, [6] group_id.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto message = std::make_shared<RPCMessage>();
      message->service_id = args[0];
      message->msg_seq = args[1];
      message->client_id = args[2];
      message->server_id = args[3];
      // Go through a named std::string: assigning args[4] straight to the
      // member raises errors.
      const std::string payload = args[4];
      message->data = payload;
      message->tensors = ListValueToVector<NDArray>(args[5]);
      message->group_id = args[6];
      *rv = message;
    });
253
254

/** @brief Return the service_id field of the message passed as args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->service_id;
    });
259
260

/** @brief Return the msg_seq field of the message passed as args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->msg_seq;
    });
265
266

/** @brief Return the client_id field of the message passed as args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->client_id;
    });
271
272

/** @brief Return the server_id field of the message passed as args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->server_id;
    });
277
278

/**
 * @brief Return the data string of the message passed as args[0] as a
 *        DGLByteArray built from the string's buffer and length.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      const std::string& payload = message->data;
      DGLByteArray bytes{payload.c_str(), payload.size()};
      *rv = bytes;
    });
284
285

/**
 * @brief Return the tensors of the message passed as args[0] as a
 *        List<Value>, one Value per tensor in the original order.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      List<Value> wrapped;
      for (auto& tensor : message->tensors) {
        wrapped.push_back(Value(MakeValue(tensor)));
      }
      *rv = wrapped;
    });
294

295
#if defined(__linux__)
296
/**
297
298
 * @brief The signal handler.
 * @param s signal
299
 */
300
void SigHandler(int s) {
301
  LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
302
  CleanupResources();
303
304
305
  exit(1);
}

306
/**
 * @brief Install SigHandler for SIGINT and SIGTERM.
 *
 * The handler is registered with an empty signal mask and no flags; the
 * return values of sigaction() are not checked.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      struct sigaction action = {};
      action.sa_handler = SigHandler;
      action.sa_flags = 0;
      sigemptyset(&action.sa_mask);
      for (const int signum : {SIGINT, SIGTERM}) {
        sigaction(signum, &action, nullptr);
      }
    });
316
317
#endif

318
319
320
//////////////////////////// ServerState ////////////////////////////

/**
 * @brief Return the ServerState held by the RPC context, creating it on
 *        first use.
 *
 * The state is accessed by reference so that a freshly created ServerState
 * is both stored in the context and returned. Working on a copy of the
 * pointer would store the new object but still hand the stale null pointer
 * back to the caller on the very first call.
 */
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto& state = RPCContext::getInstance()->server_state;
      if (state.get() == nullptr) {
        state = std::make_shared<ServerState>();
      }
      *rv = state;
    });
329

330
331
332
//////////////////////////// KVStore ////////////////////////////

/**
 * @brief Select the IDs whose partition equals the local machine.
 *
 * args[0]: ID array, args[1]: partition ID per entry of ID,
 * args[2]: local machine ID. Both arrays are read as int64_t buffers
 * (assumes int64 dtype -- the code casts without checking).
 *
 * @return An ID array with every ID[i] for which part_id[i] equals the local
 *         machine ID, in the original order.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray ID = args[0];
      NDArray part_id = args[1];
      const int local_machine_id = args[2];
      const int64_t* ID_data = static_cast<int64_t*>(ID->data);
      const int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
      const int64_t ID_size = ID.GetSize() / sizeof(int64_t);
      const int64_t part_id_size = part_id.GetSize() / sizeof(int64_t);
      // part_id is indexed with the same index as ID below, so it must hold
      // at least as many entries; otherwise the loop reads out of bounds.
      CHECK_GE(part_id_size, ID_size)
          << "part_id must provide one partition ID per entry of ID.";
      std::vector<int64_t> global_id;
      for (int64_t i = 0; i < ID_size; ++i) {
        if (part_id_data[i] == local_machine_id) {
          global_id.push_back(ID_data[i]);
        }
      }
      NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
      *rv = res_tensor;
    });
349
350

/**
 * @brief Gather rows for a list of IDs from local data and remote servers.
 *
 * Argument layout:
 *   [0] name (string; converted but not otherwise used in this function)
 *   [1] local_machine_id   [2] machine_count   [3] group_count
 *   [4] client_id          [5] service_id      [6] msg_seq
 *   [7] pickle_data: opaque string copied into every outgoing message
 *   [8] ID: IDs to pull
 *   [9] part_id: partition ID for every entry of ID
 *   [10] local_id: row indices into local_data, consumed in order for the
 *        entries of ID whose partition is the local machine
 *   [11] local_data: tensor whose first dimension indexes rows
 * ID, part_id and local_id are read as dgl_id_t buffers (the code casts
 * without checking the dtype).
 *
 * Rows owned by the local machine are copied from local_data. For every
 * other partition with at least one requested ID, one RPC message carrying
 * those IDs is sent to a randomly chosen server ID in
 * [partition * group_count, (partition + 1) * group_count), and the first
 * tensor of each reply is scattered back into the result.
 *
 * @return A CPU tensor with the dtype of local_data and shape
 *         [len(ID), local_data.shape[1:]...], rows in the order of ID.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Input
      std::string name = args[0];
      const int local_machine_id = args[1];
      const int machine_count = args[2];
      const int group_count = args[3];
      const int client_id = args[4];
      const int service_id = args[5];
      const int64_t msg_seq = args[6];
      const std::string pickle_data = args[7];
      NDArray ID = args[8];
      NDArray part_id = args[9];
      NDArray local_id = args[10];
      NDArray local_data = args[11];
      // Data
      const dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
      const dgl_id_t part_id_size = part_id.GetSize() / sizeof(dgl_id_t);
      const dgl_id_t local_id_size = local_id.GetSize() / sizeof(dgl_id_t);
      // part_id is indexed with the same index as ID below.
      CHECK_GE(part_id_size, ID_size)
          << "part_id must provide one partition ID per entry of ID.";
      dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
      dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
      dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
      char* local_data_char = static_cast<char*>(local_data->data);
      std::vector<dgl_id_t> local_ids;
      std::vector<dgl_id_t> local_ids_orginal;
      std::vector<int64_t> local_data_shape;
      std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);
      std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);
      // Get row size (in bytes). size_t rather than int so that wide rows
      // cannot overflow the byte count.
      size_t row_size = 1;
      for (int i = 0; i < local_data->ndim; ++i) {
        local_data_shape.push_back(local_data->shape[i]);
        if (i != 0) {
          row_size *= local_data->shape[i];
        }
      }
      row_size *= (local_data->dtype.bits / 8);
      const size_t data_size = local_data.GetSize();
      CHECK_GT(local_data_shape.size(), 0);
      CHECK_EQ(row_size * static_cast<size_t>(local_data_shape[0]), data_size);
      // Get local id (used in local machine) and
      // remote id (send to remote machine)
      dgl_id_t idx = 0;
      for (dgl_id_t i = 0; i < ID_size; ++i) {
        const dgl_id_t p_id = part_id_data[i];
        if (static_cast<int>(p_id) == local_machine_id) {
          // local_id is consumed one entry per locally owned ID.
          CHECK_LT(idx, local_id_size)
              << "local_id holds fewer entries than there are local IDs.";
          const dgl_id_t l_id = local_id_data[idx++];
          CHECK_LT(l_id, local_data_shape[0]);
          CHECK_GE(l_id, 0);
          local_ids.push_back(l_id);
          local_ids_orginal.push_back(i);
        } else {
          CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
          const dgl_id_t id = ID_data[i];
          remote_ids[p_id].push_back(id);
          remote_ids_original[p_id].push_back(i);
        }
      }
      // Send remote id
      int msg_count = 0;
      for (size_t i = 0; i < remote_ids.size(); ++i) {
        if (remote_ids[i].size() != 0) {
          // group_count sizes the server-ID range here and is the divisor
          // used when replies are mapped back to partitions below.
          CHECK_GT(group_count, 0) << "group_count must be positive.";
          RPCMessage msg;
          msg.service_id = service_id;
          msg.msg_seq = msg_seq;
          msg.client_id = client_id;
          const int lower = i * group_count;
          const int upper = (i + 1) * group_count;
          msg.server_id =
              dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
          msg.data = pickle_data;
          NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
          msg.tensors.push_back(tensor);
          msg.group_id = RPCContext::getInstance()->group_id;
          SendRPCMessage(msg, msg.server_id);
          msg_count++;
        }
      }
      local_data_shape[0] = ID_size;
      NDArray res_tensor = NDArray::Empty(
          local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0});
      char* return_data = static_cast<char*>(res_tensor->data);
      // Copy local data
      parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {
        for (auto i = b; i < e; ++i) {
          CHECK_GE(
              ID_size * row_size, local_ids_orginal[i] * row_size + row_size);
          CHECK_GE(data_size, local_ids[i] * row_size + row_size);
          CHECK_GE(local_ids[i], 0);
          memcpy(
              return_data + local_ids_orginal[i] * row_size,
              local_data_char + local_ids[i] * row_size, row_size);
        }
      });
      // Recv remote message
      int recv_cnt = 0;
      while (recv_cnt < msg_count) {
        RPCMessage msg;
        auto status = RecvRPCMessage(&msg, 0);
        CHECK_EQ(status, kRPCSuccess);
        ++recv_cnt;
        // The reply comes from the network: validate everything that is used
        // as an index or a copy length before touching memory.
        const int src_part = msg.server_id / group_count;
        CHECK(src_part >= 0 && src_part < machine_count)
            << "Received a reply from an unexpected server ID "
            << msg.server_id << ".";
        CHECK(!msg.tensors.empty()) << "Received a reply without tensors.";
        const std::vector<dgl_id_t>& positions = remote_ids_original[src_part];
        CHECK_GE(msg.tensors[0].GetSize(), positions.size() * row_size)
            << "Received tensor is smaller than the requested rows.";
        const char* data_char = static_cast<char*>(msg.tensors[0]->data);
        for (size_t n = 0; n < positions.size(); ++n) {
          memcpy(
              return_data + positions[n] * row_size,
              data_char + n * row_size, row_size);
        }
      }
      *rv = res_tensor;
    });
460

461
/** @brief Return the group ID stored in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      auto ctx = RPCContext::getInstance();
      *rv = ctx->group_id;
    });
465
466

/** @brief Store args[0] as the group ID in the RPC context. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetGroupID")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t new_group_id = args[0];
      auto ctx = RPCContext::getInstance();
      ctx->group_id = new_group_id;
    });
471
472

/** @brief Return the group_id field of the message passed as args[0]. */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetGroupId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const RPCMessageRef message = args[0];
      *rv = message->group_id;
    });
477
478

/**
 * @brief Forward to RPCContext::RegisterClient().
 *
 * args[0]: client ID, args[1]: group ID. The context's result is returned.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRegisterClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      auto ctx = RPCContext::getInstance();
      *rv = ctx->RegisterClient(client_id, group_id);
    });
484
485

/**
 * @brief Forward to RPCContext::GetClient().
 *
 * args[0]: client ID, args[1]: group ID. The context's result is returned.
 */
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetClient")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t client_id = args[0];
      const int32_t group_id = args[1];
      auto ctx = RPCContext::getInstance();
      *rv = ctx->GetClient(client_id, group_id);
    });
491

492
493
}  // namespace rpc
}  // namespace dgl
494
495

#endif