rpc_server.cc 1.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <chrono>
#include <stdexcept>
#include <thread>
#include <vector>

#include "rpc_base.h"

/**
 * @brief Receiving side of the RPC test: starts kNumReceiver receiver
 * threads and checks every message that arrives on each of them.
 *
 * Receiver thread `id` listens on "tcp://<ip>:<kPort + id>".
 */
class RPCServer {
public:
  /**
   * @param ip IP address used to build the receivers' listening addresses.
   * @param num_machines Multiplier applied to both the sender count passed
   *        to Wait() and the number of messages each receiver consumes.
   */
  explicit RPCServer(const std::string &ip, int num_machines)
      : ip_(ip), num_machines_(num_machines) {}

  /**
   * @brief Launches one StartServer() thread per receiver id in
   * [0, kNumReceiver) and returns once all of them have been joined.
   */
  void run() {
    std::vector<std::thread> workers;
    for (int receiver_id = 0; receiver_id < kNumReceiver; ++receiver_id) {
      workers.emplace_back(&RPCServer::StartServer, this, receiver_id);
    }
    for (std::thread &worker : workers) {
      worker.join();
    }
  }

private:
  /**
   * @brief Body of a single receiver thread.
   *
   * Waits for kNumSender * num_machines_ senders on this receiver's address,
   * then receives kNumSender * kNumMessage * num_machines_ messages. Each
   * message must carry the payload "123456789" and exactly kNumTensor
   * tensors, each of which converts to an int vector of kSizeTensor
   * elements; anything else is reported through LOG(FATAL).
   *
   * @param id Receiver index; added to kPort to pick the listening port.
   */
  void StartServer(int id) {
    dgl::rpc::TPReceiver receiver(InitTPContext());
    std::string ip_addr = "tcp://" + ip_ + ":" + std::to_string(kPort + id);

    const bool connected =
        receiver.Wait(ip_addr, kNumSender * num_machines_, false);
    if (!connected) {
      LOG(FATAL) << "Failed to wait on addr: " << ip_addr;
    }

    // The expected message total does not change while receiving.
    const auto total_messages = kNumSender * kNumMessage * num_machines_;
    for (int received = 1; received <= total_messages; ++received) {
      dgl::rpc::RPCMessage msg;
      receiver.Recv(&msg);

      // Check payload and tensor count first; tensors are only indexed
      // while every earlier check has passed.
      bool valid = (msg.data == std::string("123456789")) &&
                   (msg.tensors.size() == kNumTensor);
      for (int t = 0; valid && t < kNumTensor; ++t) {
        valid = (msg.tensors[t].ToVector<int>().size() == kSizeTensor);
      }
      if (!valid) {
        LOG(FATAL) << "Invalid received message";
      }

      // Progress report after every 1000 verified messages.
      if (received % 1000 == 0) {
        LOG(INFO) << received
                  << " RPCMessages have been received/verified on "
                  << ip_addr;
      }
    }
    receiver.Finalize();
  }

  const std::string ip_;    // IP address the receivers listen on.
  const int num_machines_;  // Scales expected sender and message counts.
};

/**
 * @brief Entry point of the RPC test server.
 *
 * Usage: ./rpc_server <num_machines> <ip>
 *   e.g. ./rpc_server 4 127.0.0.1
 *
 * @param argc Must be 3 (program name plus the two arguments above).
 * @param argv argv[1] is the number of machines, argv[2] the IP address
 *        handed to RPCServer.
 * @return 0 once the server has finished; -1 on invalid arguments (if
 *         LOG(FATAL) returns at all).
 */
int main(int argc, char **argv) {
  if (argc != 3) {
    LOG(FATAL)
        << "Invalid call. Please call like this: ./rpc_server 4 127.0.0.1";
    return -1;
  }
  const int num_machines = std::atoi(argv[1]);
  // std::atoi yields 0 for non-numeric input. The machine count scales both
  // the number of senders waited for and the number of messages received, so
  // a zero or negative value would presumably let the server finish without
  // verifying anything. Reject it explicitly instead.
  if (num_machines <= 0) {
    LOG(FATAL) << "Invalid number of machines: " << argv[1]
               << ". It must be a positive integer.";
    return -1;
  }
  const std::string ip{argv[2]};
  RPCServer server(ip, num_machines);
  server.run();

  return 0;
}