/*! * Copyright (c) 2020 by Contributors * \file rpc/rpc.cc * \brief Implementation of RPC utilities used by both server and client sides. */ #if defined(__linux__) #include "./rpc.h" #include #include #include #include #include #include #include #include #include #include #include "../c_api_common.h" #include "../runtime/resource_manager.h" using dgl::network::StringPrintf; using namespace dgl::runtime; namespace dgl { namespace rpc { using namespace tensorpipe; // Borrow from PyTorch const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME"; const char kDefaultUvAddress[] = "127.0.0.1"; const std::string& guessAddress() { static const std::string uvAddress = []() { tensorpipe::Error error; std::string result; char* ifnameEnv = std::getenv(kSocketIfnameEnvVar); if (ifnameEnv != nullptr) { std::tie(error, result) = tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv); if (error) { LOG(WARNING) << "Failed to look up the IP address for interface " << ifnameEnv << " (" << error.what() << "), defaulting to " << kDefaultUvAddress; return std::string(kDefaultUvAddress); } } else { std::tie(error, result) = tensorpipe::transport::uv::lookupAddrForHostname(); if (error) { LOG(WARNING) << "Failed to look up the IP address for the hostname (" << error.what() << "), defaulting to " << kDefaultUvAddress; return std::string(kDefaultUvAddress); } } return result; }(); return uvAddress; } RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) { RPCContext::getInstance()->sender->Send(msg, target_id); return kRPCSuccess; } RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) { // ignore timeout now CHECK_EQ(timeout, 0) << "rpc cannot support timeout now."; RPCContext::getInstance()->receiver->Recv(msg); return kRPCSuccess; } void InitGlobalTpContext() { if (!RPCContext::getInstance()->ctx) { RPCContext::getInstance()->ctx = std::make_shared(); auto context = RPCContext::getInstance()->ctx; auto transportContext = tensorpipe::transport::uv::create(); auto shmtransport = tensorpipe::transport::shm::create(); context->registerTransport(0 /* priority */, "tcp", transportContext); // Register basic uv channel auto basicChannel = tensorpipe::channel::basic::create(); context->registerChannel(0 /* low priority */, "basic", basicChannel); char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS"); if (numUvThreads_str) { int numUvThreads = std::atoi(numUvThreads_str); CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set"; // Register multiplex uv channel std::vector> contexts; std::vector> listeners; for (int i = 0; i < numUvThreads; i++) { auto context = tensorpipe::transport::uv::create(); std::string address = guessAddress(); contexts.push_back(std::move(context)); listeners.push_back(contexts.back()->listen(address)); } auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts), std::move(listeners)); context->registerChannel(20 /* high priority */, "mpt", mptChannel); } } } //////////////////////////// C APIs //////////////////////////// DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset") .set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender") .set_body([](DGLArgs args, DGLRetValue* rv) { int64_t msg_queue_size = args[0]; std::string type = args[1]; InitGlobalTpContext(); RPCContext::getInstance()->sender = std::make_shared(RPCContext::getInstance()->ctx); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver") .set_body([](DGLArgs args, DGLRetValue* rv) { int64_t msg_queue_size = args[0]; std::string type = args[1]; InitGlobalTpContext(); RPCContext::getInstance()->receiver = std::make_shared(RPCContext::getInstance()->ctx); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender") .set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::getInstance()->sender->Finalize(); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver") .set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::getInstance()->receiver->Finalize(); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait") .set_body([](DGLArgs args, DGLRetValue* rv) { std::string ip = args[0]; int port = args[1]; int num_sender = args[2]; bool blocking = args[3]; std::string addr; addr = StringPrintf("tcp://%s:%d", ip.c_str(), port); if (RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking) == false) { LOG(FATAL) << "Wait sender socket failed."; } }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver") .set_body([](DGLArgs args, DGLRetValue* rv) { std::string ip = args[0]; int port = args[1]; int recv_id = args[2]; std::string addr; addr = StringPrintf("tcp://%s:%d", ip.c_str(), port); *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t rank = args[0]; RPCContext::getInstance()->rank = rank; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->rank; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_servers = args[0]; *rv = RPCContext::getInstance()->num_servers = num_servers; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->num_servers; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_clients = args[0]; *rv = RPCContext::getInstance()->num_clients = num_clients; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->num_clients; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_servers = args[0]; *rv = RPCContext::getInstance()->num_servers_per_machine = num_servers; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->num_servers_per_machine; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = (RPCContext::getInstance()->msg_seq)++; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->msg_seq; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq") .set_body([](DGLArgs args, DGLRetValue* rv) { const int64_t msg_seq = args[0]; RPCContext::getInstance()->msg_seq = msg_seq; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->barrier_count; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t count = args[0]; RPCContext::getInstance()->barrier_count = count; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->machine_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t machine_id = args[0]; RPCContext::getInstance()->machine_id = machine_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines") .set_body([](DGLArgs args, DGLRetValue* rv) { *rv = RPCContext::getInstance()->num_machines; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines") .set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_machines = args[0]; RPCContext::getInstance()->num_machines = num_machines; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage") .set_body([](DGLArgs args, DGLRetValue* rv) { RPCMessageRef msg = args[0]; const int32_t target_id = args[1]; *rv = SendRPCMessage(*(msg.sptr()), target_id); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage") .set_body([](DGLArgs args, DGLRetValue* rv) { int32_t timeout = args[0]; RPCMessageRef msg = args[1]; *rv = RecvRPCMessage(msg.sptr().get(), timeout); }); //////////////////////////// RPCMessage //////////////////////////// DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage") .set_body([](DGLArgs args, DGLRetValue* rv) { std::shared_ptr rst(new RPCMessage); *rv = rst; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize") .set_body([](DGLArgs args, DGLRetValue* rv) { int64_t message_size = args[0]; std::shared_ptr rst(new RPCMessage); *rv = rst; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage") .set_body([](DGLArgs args, DGLRetValue* rv) { std::shared_ptr rst(new RPCMessage); rst->service_id = args[0]; rst->msg_seq = args[1]; rst->client_id = args[2]; rst->server_id = args[3]; const std::string data = args[4]; // directly assigning string value raises errors :( rst->data = data; rst->tensors = ListValueToVector(args[5]); *rv = rst; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId") .set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->service_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq") .set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->msg_seq; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId") .set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->client_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId") .set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->server_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData") .set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; DGLByteArray barr{msg->data.c_str(), msg->data.size()}; *rv = barr; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors") .set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; List ret; for (size_t i = 0; i < msg->tensors.size(); ++i) { ret.push_back(Value(MakeValue(msg->tensors[i]))); } *rv = ret; }); #if defined(__linux__) /*! * \brief The signal handler. * \param s signal */ void SigHandler(int s) { LOG(INFO) << "\nUser pressed Ctrl+C, Exiting"; CleanupResources(); exit(1); } DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal") .set_body([](DGLArgs args, DGLRetValue* rv) { // Ctrl+C handler struct sigaction sigHandler; sigHandler.sa_handler = SigHandler; sigemptyset(&sigHandler.sa_mask); sigHandler.sa_flags = 0; sigaction(SIGINT, &sigHandler, nullptr); sigaction(SIGTERM, &sigHandler, nullptr); }); #endif //////////////////////////// ServerState //////////////////////////// DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState") .set_body([](DGLArgs args, DGLRetValue* rv) { auto st = RPCContext::getInstance()->server_state; if (st.get() == nullptr) { RPCContext::getInstance()->server_state = std::make_shared(); } *rv = st; }); //////////////////////////// KVStore //////////////////////////// DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition") .set_body([](DGLArgs args, DGLRetValue* rv) { NDArray ID = args[0]; NDArray part_id = args[1]; int local_machine_id = args[2]; int64_t* ID_data = static_cast(ID->data); int64_t* part_id_data = static_cast(part_id->data); int64_t ID_size = ID.GetSize() / sizeof(int64_t); std::vector global_id; for (int64_t i = 0; i < ID_size; ++i) { if (part_id_data[i] == local_machine_id) { global_id.push_back(ID_data[i]); } } NDArray res_tensor = dgl::aten::VecToIdArray(global_id); *rv = res_tensor; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") .set_body([](DGLArgs args, DGLRetValue* rv) { // Input std::string name = args[0]; int local_machine_id = args[1]; int machine_count = args[2]; int group_count = args[3]; int client_id = args[4]; int service_id = args[5]; int64_t msg_seq = args[6]; std::string pickle_data = args[7]; NDArray ID = args[8]; NDArray part_id = args[9]; NDArray local_id = args[10]; NDArray local_data = args[11]; // Data dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t); dgl_id_t* ID_data = static_cast(ID->data); dgl_id_t* part_id_data = static_cast(part_id->data); dgl_id_t* local_id_data = static_cast(local_id->data); char* local_data_char = static_cast(local_data->data); std::vector local_ids; std::vector local_ids_orginal; std::vector local_data_shape; std::vector> remote_ids(machine_count); std::vector> remote_ids_original(machine_count); // Get row size (in bytes) int row_size = 1; for (int i = 0; i < local_data->ndim; ++i) { local_data_shape.push_back(local_data->shape[i]); if (i != 0) { row_size *= local_data->shape[i]; } } row_size *= (local_data->dtype.bits / 8); size_t data_size = local_data.GetSize(); CHECK_GT(local_data_shape.size(), 0); CHECK_EQ(row_size * local_data_shape[0], data_size); // Get local id (used in local machine) and // remote id (send to remote machine) dgl_id_t idx = 0; for (dgl_id_t i = 0; i < ID_size; ++i) { dgl_id_t p_id = part_id_data[i]; if (p_id == local_machine_id) { dgl_id_t l_id = local_id_data[idx++]; CHECK_LT(l_id, local_data_shape[0]); CHECK_GE(l_id, 0); local_ids.push_back(l_id); local_ids_orginal.push_back(i); } else { CHECK_LT(p_id, machine_count) << "Invalid partition ID."; dgl_id_t id = ID_data[i]; remote_ids[p_id].push_back(id); remote_ids_original[p_id].push_back(i); } } // Send remote id int msg_count = 0; for (int i = 0; i < remote_ids.size(); ++i) { if (remote_ids[i].size() != 0) { RPCMessage msg; msg.service_id = service_id; msg.msg_seq = msg_seq; msg.client_id = client_id; int lower = i * group_count; int upper = (i + 1) * group_count; msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper); msg.data = pickle_data; NDArray tensor = dgl::aten::VecToIdArray(remote_ids[i]); msg.tensors.push_back(tensor); SendRPCMessage(msg, msg.server_id); msg_count++; } } local_data_shape[0] = ID_size; NDArray res_tensor = NDArray::Empty(local_data_shape, local_data->dtype, DLContext{kDLCPU, 0}); char* return_data = static_cast(res_tensor->data); // Copy local data parallel_for(0, local_ids.size(), [&](size_t b, size_t e) { for (auto i = b; i < e; ++i) { CHECK_GE(ID_size * row_size, local_ids_orginal[i] * row_size + row_size); CHECK_GE(data_size, local_ids[i] * row_size + row_size); CHECK_GE(local_ids[i], 0); memcpy(return_data + local_ids_orginal[i] * row_size, local_data_char + local_ids[i] * row_size, row_size); } }); // Recv remote message for (int i = 0; i < msg_count; ++i) { RPCMessage msg; RecvRPCMessage(&msg, 0); int part_id = msg.server_id / group_count; char* data_char = static_cast(msg.tensors[0]->data); dgl_id_t id_size = remote_ids[part_id].size(); for (size_t n = 0; n < id_size; ++n) { memcpy(return_data + remote_ids_original[part_id][n] * row_size, data_char + n * row_size, row_size); } } *rv = res_tensor; }); } // namespace rpc } // namespace dgl #endif