Unverified commit df97f2e8, authored by Rhett Ying, committed by GitHub
Browse files

[DistDGL] remove tensorpipe-based dist cpp test (#5849)

parent 0f322331
......@@ -320,16 +320,6 @@ if(BUILD_CPP_TEST)
target_link_libraries(runUnitTests gtest gtest_main)
target_link_libraries(runUnitTests dgl)
add_test(UnitTests runUnitTests)
# Build the two dist/rpc test executables unless MSVC is set.
if(NOT MSVC)
  message(STATUS "Building dist/rpc tests")
  # Each executable is built from the same-named .cc file under
  # tests/dist/cpp and linked against the dgl target.
  foreach(rpc_role rpc_client rpc_server)
    file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/dist/cpp/${rpc_role}.cc)
    add_executable(${rpc_role} ${TEST_SRC_FILES})
    target_link_libraries(${rpc_role} dgl)
  endforeach()
endif(NOT MSVC)
endif(BUILD_CPP_TEST)
if(BUILD_SPARSE)
......
#ifndef DIST_TEST_RPC_BASE_H_
#define DIST_TEST_RPC_BASE_H_
#include <iostream>
#include <string>
#include "../../../src/rpc/rpc_msg.h"
#include "../../../src/rpc/tensorpipe/tp_communicator.h"
#include "../../cpp/common.h"
namespace {
// Sender threads started by each client process; the server expects this many
// senders per machine.
constexpr int kNumSender = 30;
// Receiver threads started by each server process; the client connects to this
// many consecutive ports on every address.
constexpr int kNumReceiver = 10;
// Messages each sender sends to each receiver it is connected to.
constexpr int kNumMessage = 1024;
// Base port: receiver `i` uses port kPort + i.
constexpr int kPort = 50090;
// Number of int elements in every tensor attached to a message.
constexpr int kSizeTensor = 10 * 1024;
// Number of tensors attached to every message.
constexpr int kNumTensor = 30;
// Creates a tensorpipe context with the uv transport registered as "tcp" and
// the basic channel registered as "basic", both at priority 0.
std::shared_ptr<tensorpipe::Context> InitTPContext() {
  auto tp_ctx = std::make_shared<tensorpipe::Context>();
  tp_ctx->registerTransport(
      0 /* priority */, "tcp", tensorpipe::transport::uv::create());
  tp_ctx->registerChannel(
      0 /* low priority */, "basic", tensorpipe::channel::basic::create());
  return tp_ctx;
}
} // namespace
#endif
#include <chrono>
#include <cstdlib>
#include <fstream>
#include <stdexcept>
#include <thread>
#include <vector>
#include "rpc_base.h"
/**
 * @brief Client side of the dist/rpc stress test.
 *
 * Run() starts kNumSender threads. Each thread connects to kNumReceiver
 * consecutive ports (starting at kPort) on every address read from the ip
 * config file, then sends kNumMessage messages to each connected receiver.
 */
class RPCClient {
 public:
  /**
   * @brief Stores the config path and loads the receiver addresses from it.
   * @param ip_config Path to a text file listing one address per line.
   */
  explicit RPCClient(const std::string &ip_config) : ip_config_(ip_config) {
    ParseIPs();
  }

  /** @brief Starts kNumSender sender threads and joins all of them. */
  void Run() {
    std::vector<std::thread> threads;
    threads.reserve(kNumSender);
    for (int i = 0; i < kNumSender; ++i) {
      threads.emplace_back(&RPCClient::StartClient, this);
    }
    for (auto &t : threads) {
      t.join();
    }
  }

 private:
  /**
   * @brief Reads ip_config_ line by line into ips_, echoing each kept line to
   * stdout.
   *
   * Logs a FATAL message if the file cannot be opened. A trailing '\r' (left
   * behind by std::getline on CRLF files) is stripped and blank lines are
   * skipped: either would otherwise become a receiver address that
   * StartClient keeps retrying.
   */
  void ParseIPs() {
    std::ifstream ifs(ip_config_);
    if (!ifs) {
      LOG(FATAL) << "Failed to open ip_config: " + ip_config_;
    }
    for (std::string line; std::getline(ifs, line);) {
      if (!line.empty() && line.back() == '\r') {
        line.pop_back();
      }
      if (line.empty()) {
        continue;
      }
      std::cout << line << std::endl;
      ips_.push_back(line);
    }
  }

  /**
   * @brief Body of one sender thread.
   *
   * Connects to ports kPort .. kPort + kNumReceiver - 1 on every address in
   * ips_, retrying once per second until each connection succeeds, and
   * assigns receiver ids in connection order. It then sends kNumMessage
   * messages to every receiver id and finalizes the sender.
   */
  void StartClient() {
    dgl::rpc::TPSender sender(InitTPContext());
    int recv_id = 0;
    for (const auto &ip : ips_) {
      for (int i = 0; i < kNumReceiver; ++i) {
        const std::string ip_addr =
            std::string{"tcp://"} + ip + ":" + std::to_string(kPort + i);
        while (!sender.ConnectReceiver(ip_addr, recv_id)) {
          std::this_thread::sleep_for(std::chrono::seconds(1));
        }
        ++recv_id;
      }
    }
    // After the loops above recv_id equals the number of connected receivers.
    for (int i = 0; i < kNumMessage; ++i) {
      for (int n = 0; n < recv_id; ++n) {
        dgl::rpc::RPCMessage msg;
        msg.data = "123456789";
        // Every message carries kNumTensor references to one tensor of
        // kSizeTensor ints, all set to 1.
        const auto tensor =
            dgl::runtime::NDArray::FromVector(std::vector<int>(kSizeTensor, 1));
        for (int j = 0; j < kNumTensor; ++j) {
          msg.tensors.push_back(tensor);
        }
        sender.Send(msg, n);
      }
    }
    sender.Finalize();
  }

  // Path of the address list file given to the constructor.
  const std::string ip_config_;
  // Receiver addresses loaded by ParseIPs(), one per non-blank line.
  std::vector<std::string> ips_;
};
int main(int argc, char **argv) {
if (argc != 2) {
LOG(FATAL)
<< "Invalid call. Please call like this: ./rpc_client ip_config.txt";
return -1;
}
RPCClient client(argv[1]);
client.Run();
return 0;
}
#include <chrono>
#include <stdexcept>
#include <thread>
#include <vector>
#include "rpc_base.h"
/**
 * @brief Server side of the dist/rpc stress test.
 *
 * run() starts kNumReceiver threads. Thread `id` listens on port kPort + id
 * of the configured ip, waits for kNumSender * num_machines senders, then
 * receives and checks kNumSender * kNumMessage * num_machines messages.
 */
class RPCServer {
 public:
  /**
   * @param ip Address the receivers listen on.
   * @param num_machines Multiplier for the expected sender and message counts.
   */
  explicit RPCServer(const std::string &ip, int num_machines)
      : ip_(ip), num_machines_(num_machines) {}

  /** @brief Starts one thread per receiver and joins all of them. */
  void run() {
    std::vector<std::thread> workers;
    for (int receiver_id = 0; receiver_id < kNumReceiver; ++receiver_id) {
      workers.emplace_back(&RPCServer::StartServer, this, receiver_id);
    }
    for (std::thread &worker : workers) {
      worker.join();
    }
  }

 private:
  /**
   * @brief Body of one receiver thread.
   * @param id Receiver index; selects the port kPort + id.
   */
  void StartServer(int id) {
    dgl::rpc::TPReceiver receiver(InitTPContext());
    std::string ip_addr = "tcp://" + ip_ + ":" + std::to_string(kPort + id);

    const int expected_senders = kNumSender * num_machines_;
    if (!receiver.Wait(ip_addr, expected_senders, false)) {
      LOG(FATAL) << "Failed to wait on addr: " << ip_addr;
    }

    const int expected_messages = kNumSender * kNumMessage * num_machines_;
    for (int n = 0; n < expected_messages; ++n) {
      dgl::rpc::RPCMessage msg;
      if (receiver.Recv(&msg, 0) != dgl::rpc::kRPCSuccess) {
        LOG(FATAL) << "Failed to receive message on Server~" << id;
      }

      // A message is valid when the payload string matches, it carries
      // kNumTensor tensors, and every tensor holds kSizeTensor ints.
      bool valid = (msg.data == std::string("123456789")) &&
                   (msg.tensors.size() == kNumTensor);
      if (valid) {
        for (auto &tensor : msg.tensors) {
          if (tensor.ToVector<int>().size() != kSizeTensor) {
            valid = false;
            break;
          }
        }
      }
      if (!valid) {
        LOG(FATAL) << "Invalid received message";
      }

      // Progress report after every 1000 verified messages.
      const int received = n + 1;
      if (received % 1000 == 0) {
        LOG(INFO) << received << " RPCMessages have been received/verified on "
                  << ip_addr;
      }
    }
    receiver.Finalize();
  }

  // Address the receivers listen on.
  const std::string ip_;
  // Multiplier applied to the expected sender and message counts.
  const int num_machines_;
};
/**
 * Entry point of the rpc_server test binary.
 *
 * Usage: ./rpc_server <num_machines> <ip>
 *
 * Returns -1 when the argument count is wrong or <num_machines> is not a
 * positive integer; otherwise runs the server to completion and returns 0.
 */
int main(int argc, char **argv) {
  if (argc != 3) {
    LOG(FATAL)
        << "Invalid call. Please call like this: ./rpc_server 4 127.0.0.1";
    return -1;
  }

  // Parse strictly: a value of zero or less would make the receive loops in
  // RPCServer run zero iterations, so the test would exit without verifying
  // any message. Non-numeric input and trailing garbage are rejected too.
  int num_machines = 0;
  try {
    const std::string arg{argv[1]};
    std::size_t consumed = 0;
    const int parsed = std::stoi(arg, &consumed);
    if (consumed == arg.size()) {
      num_machines = parsed;
    }
  } catch (const std::logic_error &) {
    // std::invalid_argument or std::out_of_range: leave num_machines at 0 so
    // the check below reports the bad argument.
  }
  if (num_machines <= 0) {
    LOG(FATAL) << "Invalid num_machines (expected a positive integer): "
               << argv[1];
    return -1;
  }

  const std::string ip{argv[2]};
  RPCServer server(ip, num_machines);
  server.run();
  return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment