/*!
 *  Copyright (c) 2020 by Contributors
 * \file rpc/rpc.cc
 * \brief Implementation of RPC utilities used by both server and client sides.
 */
// NOTE(review): this file was recovered from a whitespace-collapsed copy in
// which everything between angle brackets had been stripped. Template
// arguments were restored from usage; the three <dgl/...> include targets and
// the SocketSender/SocketReceiver/ServerState type names are best-effort
// reconstructions -- confirm against the build before merging.
#include "./rpc.h"

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/zerocopy_serializer.h>
#include "../c_api_common.h"

using dgl::network::StringPrintf;
using namespace dgl::runtime;

namespace dgl {
namespace rpc {

/*!
 * \brief Serialize an RPC message and send it through the thread-local sender.
 *
 * Wire layout produced by this function:
 *   1. one metadata message: the serialized RPCMessage followed by a trailing
 *      int32 holding the number of payload messages that follow;
 *   2. one payload message per entry of the write stream's buffer list.
 *
 * \param msg The message to send.
 * \param target_id ID of the receiver the message is sent to.
 * \return kRPCSuccess; a failed send aborts through CHECK_EQ.
 */
RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
  std::shared_ptr<std::string> zerocopy_blob(new std::string());
  StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
  zc_write_strm.Write(msg);
  // The receiver waits for exactly this many payload messages, so the count
  // must be the number of buffers actually sent below rather than the number
  // of tensors in the message; otherwise the two sides can get out of step.
  const int32_t ndarray_count =
    static_cast<int32_t>(zc_write_strm.buffer_list().size());
  zerocopy_blob->append(
    reinterpret_cast<const char*>(&ndarray_count),
    sizeof(int32_t));
  network::Message rpc_meta_msg;
  rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
  rpc_meta_msg.size = zerocopy_blob->size();
  // The captured shared_ptr keeps the blob alive for as long as the
  // deallocator (and therefore the message) exists.
  rpc_meta_msg.deallocator = [zerocopy_blob](network::Message*) {};
  CHECK_EQ(RPCContext::ThreadLocal()->sender->Send(
    rpc_meta_msg, target_id), ADD_SUCCESS);
  // send real ndarray data
  for (const auto& ptr : zc_write_strm.buffer_list()) {
    network::Message ndarray_data_msg;
    ndarray_data_msg.data = reinterpret_cast<char*>(ptr.data);
    ndarray_data_msg.size = ptr.size;
    // Capture the tensor so its storage outlives the pending send.
    NDArray tensor = ptr.tensor;
    ndarray_data_msg.deallocator = [tensor](network::Message*) {};
    CHECK_EQ(RPCContext::ThreadLocal()->sender->Send(
      ndarray_data_msg, target_id), ADD_SUCCESS);
  }
  return kRPCSuccess;
}

/*!
 * \brief Receive one RPC message through the thread-local receiver.
 *
 * Reads the metadata message, extracts the trailing int32 payload count,
 * receives that many payload messages from the same sender, and then
 * deserializes everything into \a msg.
 *
 * \param msg Output message to fill in.
 * \param timeout Must be 0; timeouts are not supported yet.
 * \return kRPCSuccess; failures abort through CHECK macros.
 */
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
  // ignore timeout now
  CHECK_EQ(timeout, 0) << "rpc cannot support timeout now.";
  network::Message rpc_meta_msg;
  int send_id;
  CHECK_EQ(RPCContext::ThreadLocal()->receiver->Recv(
    &rpc_meta_msg, &send_id), REMOVE_SUCCESS);
  // The metadata must at least contain the trailing payload count.
  const size_t meta_size = static_cast<size_t>(rpc_meta_msg.size);
  CHECK_GE(meta_size, sizeof(int32_t))
    << "Malformed rpc message: metadata is too short.";
  // The count lives at an arbitrary byte offset, so copy it out instead of
  // dereferencing a (possibly misaligned) int32_t pointer.
  const char* count_ptr = rpc_meta_msg.data + meta_size - sizeof(int32_t);
  int32_t ndarray_count = 0;
  std::memcpy(&ndarray_count, count_ptr, sizeof(int32_t));
  CHECK_GE(ndarray_count, 0)
    << "Malformed rpc message: negative ndarray count.";
  // Recv real ndarray data
  // NOTE(review): the payload deallocators are never invoked here; presumably
  // ownership of the buffers passes to the read stream -- confirm.
  std::vector<void*> buffer_list(ndarray_count);
  for (int32_t i = 0; i < ndarray_count; ++i) {
    network::Message ndarray_data_msg;
    CHECK_EQ(RPCContext::ThreadLocal()->receiver->RecvFrom(
      &ndarray_data_msg, send_id), REMOVE_SUCCESS);
    buffer_list[i] = ndarray_data_msg.data;
  }
  StreamWithBuffer zc_read_strm(rpc_meta_msg.data,
                                rpc_meta_msg.size - sizeof(int32_t),
                                buffer_list);
  zc_read_strm.Read(msg);
  rpc_meta_msg.deallocator(&rpc_meta_msg);
  return kRPCSuccess;
}

//////////////////////////// C APIs ////////////////////////////

// Create the thread-local sender.
// args: [0] message queue size, [1] communicator type (only "socket").
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
  if (type.compare("socket") == 0) {
    RPCContext::ThreadLocal()->sender =
      std::make_shared<network::SocketSender>(msg_queue_size);
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
  }
});

// Create the thread-local receiver.
// args: [0] message queue size, [1] communicator type (only "socket").
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  int64_t msg_queue_size = args[0];
  std::string type = args[1];
  if (type.compare("socket") == 0) {
    RPCContext::ThreadLocal()->receiver =
      std::make_shared<network::SocketReceiver>(msg_queue_size);
  } else {
    LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
  }
});

// Finalize the thread-local sender.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  RPCContext::ThreadLocal()->sender->Finalize();
});

// Finalize the thread-local receiver.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  RPCContext::ThreadLocal()->receiver->Finalize();
});

// Make the receiver wait on an address for senders.
// args: [0] ip, [1] port, [2] number of senders.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int num_sender = args[2];
  std::string addr;
  if (RPCContext::ThreadLocal()->receiver->Type() == "socket") {
    addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
  } else {
    LOG(FATAL) << "Unknown communicator type: "
               << RPCContext::ThreadLocal()->receiver->Type();
  }
  if (RPCContext::ThreadLocal()->receiver->Wait(
        addr.c_str(), num_sender) == false) {
    LOG(FATAL) << "Wait sender socket failed.";
  }
});

// Register a receiver address with the sender.
// args: [0] ip, [1] port, [2] receiver ID.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::string ip = args[0];
  int port = args[1];
  int recv_id = args[2];
  std::string addr;
  if (RPCContext::ThreadLocal()->sender->Type() == "socket") {
    addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
  } else {
    LOG(FATAL) << "Unknown communicator type: "
               << RPCContext::ThreadLocal()->sender->Type();
  }
  RPCContext::ThreadLocal()->sender->AddReceiver(addr.c_str(), recv_id);
});

// Connect the sender; a failed connection is fatal.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  if (RPCContext::ThreadLocal()->sender->Connect() == false) {
    LOG(FATAL) << "Sender connection failed.";
  }
});

// Accessors for the fields of the thread-local RPC context.

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int32_t rank = args[0];
  RPCContext::ThreadLocal()->rank = rank;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->rank;
});

// Stores the value and also returns it.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  *rv = RPCContext::ThreadLocal()->num_servers = num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->num_servers;
});

// Stores the value and also returns it.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int32_t num_servers = args[0];
  *rv = RPCContext::ThreadLocal()->num_servers_per_machine = num_servers;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->num_servers_per_machine;
});

// Post-increment: returns the sequence number held before the increment.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = (RPCContext::ThreadLocal()->msg_seq)++;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int64_t msg_seq = args[0];
  RPCContext::ThreadLocal()->msg_seq = msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int32_t machine_id = args[0];
  RPCContext::ThreadLocal()->machine_id = machine_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = RPCContext::ThreadLocal()->num_machines;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const int32_t num_machines = args[0];
  RPCContext::ThreadLocal()->num_machines = num_machines;
});

// args: [0] message to send, [1] target ID. Returns the RPCStatus.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  RPCMessageRef msg = args[0];
  const int32_t target_id = args[1];
  *rv = SendRPCMessage(*(msg.sptr()), target_id);
});

// args: [0] timeout, [1] message object to fill. Returns the RPCStatus.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  int32_t timeout = args[0];
  RPCMessageRef msg = args[1];
  *rv = RecvRPCMessage(msg.sptr().get(), timeout);
});

//////////////////////////// RPCMessage ////////////////////////////

// Create a default-constructed RPCMessage.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  *rv = rst;
});

// Create an RPCMessage from its fields.
// args: [0] service_id, [1] msg_seq, [2] client_id, [3] server_id,
//       [4] data, [5] list of tensors.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::shared_ptr<RPCMessage> rst(new RPCMessage);
  rst->service_id = args[0];
  rst->msg_seq = args[1];
  rst->client_id = args[2];
  rst->server_id = args[3];
  const std::string data = args[4];  // directly assigning string value raises errors :(
  rst->data = data;
  rst->tensors = ListValueToVector<NDArray>(args[5]);
  *rv = rst;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->service_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->msg_seq;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->client_id;
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  *rv = msg->server_id;
});

// Returns the message data as a byte array that points into the message's
// own string storage.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  DGLByteArray barr{msg->data.c_str(), msg->data.size()};
  *rv = barr;
});

// Returns the message's tensors as a list of values.
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  const RPCMessageRef msg = args[0];
  List<Value> ret;
  for (size_t i = 0; i < msg->tensors.size(); ++i) {
    ret.push_back(Value(MakeValue(msg->tensors[i])));
  }
  *rv = ret;
});

//////////////////////////// ServerState ////////////////////////////

// Returns the thread-local server state, creating it on first use.
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  auto st = RPCContext::ThreadLocal()->server_state;
  if (st.get() == nullptr) {
    // Return the freshly created state as well as storing it; handing back
    // the stale null handle would make the first call yield nothing.
    st = std::make_shared<ServerState>();
    RPCContext::ThreadLocal()->server_state = st;
  }
  *rv = st;
});

}  // namespace rpc
}  // namespace dgl