socket_communicator_test.cc 4.46 KB
Newer Older
Chao Ma's avatar
Chao Ma committed
1
2
/*!
 *  Copyright (c) 2019 by Contributors
 * \file socket_communicator_test.cc
 * \brief Test SocketCommunicator
 */
#include <gtest/gtest.h>
#include <string.h>
#include <string>
9
10
#include <thread>
#include <vector>
Chao Ma's avatar
Chao Ma committed
11

12
13
#include "../src/rpc/network/msg_queue.h"
#include "../src/rpc/network/socket_communicator.h"
Chao Ma's avatar
Chao Ma committed
14
15

using std::string;
16

Chao Ma's avatar
Chao Ma committed
17
18
using dgl::network::SocketSender;
using dgl::network::SocketReceiver;
19
20
using dgl::network::Message;
using dgl::network::DefaultMessageDeleter;
Chao Ma's avatar
Chao Ma committed
21

22
// Queue size handed to every SocketSender / SocketReceiver constructed in this
// test (presumably a byte count -- TODO confirm against msg_queue.h).
const int64_t kQueueSize = 500 * 1024;
VoVAllen's avatar
VoVAllen committed
23
24
25

#ifndef WIN32

26
27
28
29
30
31
32
33
34
35
36
37
// Number of client (sender) threads started by the test.
const int kNumSender = 3;
// Number of server (receiver) threads started by the test.
const int kNumReceiver = 3;
// Number of messages sent per (sender, receiver) pair in each of the two
// send passes performed by start_client().
const int kNumMessage = 10;

// Address of each receiver, indexed by receiver id. start_client() and
// start_server() index this array with ids in [0, kNumReceiver), so it must
// hold at least kNumReceiver entries.
const char* ip_addr[] = {
  "socket://127.0.0.1:50091",
  "socket://127.0.0.1:50092",
  "socket://127.0.0.1:50093"
};

// Thread entry points used by the test below; defined further down the file.
static void start_client();
static void start_server(int id);
VoVAllen's avatar
VoVAllen committed
38

Chao Ma's avatar
Chao Ma committed
39
/*!
 * \brief Run kNumSender client threads against kNumReceiver server threads
 *        and wait for every one of them to finish.
 *
 * The actual expectations live in start_client() and start_server(); this
 * test body only owns the threads.
 */
TEST(SocketCommunicatorTest, SendAndRecv) {
  // Threads are stored by value so that no heap-allocated std::thread object
  // is left behind once the test returns.
  std::vector<std::thread> client_thread;
  client_thread.reserve(kNumSender);
  // start kNumSender clients
  for (int i = 0; i < kNumSender; ++i) {
    client_thread.emplace_back(start_client);
  }
  std::vector<std::thread> server_thread;
  server_thread.reserve(kNumReceiver);
  // start kNumReceiver servers, each with its own receiver id
  for (int i = 0; i < kNumReceiver; ++i) {
    server_thread.emplace_back(start_server, i);
  }
  // Every thread must be joined before its std::thread object is destroyed.
  for (std::thread& t : client_thread) {
    t.join();
  }
  for (std::thread& t : server_thread) {
    t.join();
  }
}

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*!
 * \brief Client thread body: connect to every receiver listed in ip_addr and
 *        send two passes of kNumMessage messages to each of them.
 *
 * Each message carries the 9 bytes "123456789" (no terminating NUL) and has
 * DefaultMessageDeleter attached as its deallocator.
 */
void start_client() {
  sleep(2);  // wait server start
  SocketSender sender(kQueueSize);
  for (int recv_id = 0; recv_id < kNumReceiver; ++recv_id) {
    sender.AddReceiver(ip_addr[recv_id], recv_id);
  }
  sender.Connect();
  // Two identical passes: start_server() reads kNumMessage messages per sender
  // with RecvFrom() and then kNumSender * kNumMessage more with Recv().
  for (int pass = 0; pass < 2; ++pass) {
    for (int round = 0; round < kNumMessage; ++round) {
      for (int recv_id = 0; recv_id < kNumReceiver; ++recv_id) {
        char* payload = new char[9];
        memcpy(payload, "123456789", 9);
        Message msg = {payload, 9};
        msg.deallocator = DefaultMessageDeleter;
        EXPECT_EQ(sender.Send(msg, recv_id), ADD_SUCCESS);
      }
    }
  }
  sender.Finalize();
}

/*!
 * \brief Server thread body: wait on ip_addr[id] for kNumSender senders and
 *        check every received message.
 * \param id index into ip_addr selecting the address this receiver waits on
 *
 * Messages are read in two phases: first kNumMessage messages from each
 * sender in turn via RecvFrom(), then kNumSender * kNumMessage messages via
 * Recv(). Every message is expected to hold "123456789" and is released
 * through its own deallocator.
 */
void start_server(int id) {
  SocketReceiver receiver(kQueueSize);
  receiver.Wait(ip_addr[id], kNumSender);

  // Phase 1: read from one specific sender at a time.
  for (int round = 0; round < kNumMessage; ++round) {
    for (int sender_id = 0; sender_id < kNumSender; ++sender_id) {
      Message msg;
      EXPECT_EQ(receiver.RecvFrom(&msg, sender_id), REMOVE_SUCCESS);
      EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
      msg.deallocator(&msg);
    }
  }

  // Phase 2: read the remaining messages; Recv() also fills in an id, which
  // this test does not inspect.
  const int remaining = kNumSender * kNumMessage;
  for (int k = 0; k < remaining; ++k) {
    Message msg;
    int recv_id;
    EXPECT_EQ(receiver.Recv(&msg, &recv_id), REMOVE_SUCCESS);
    EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
    msg.deallocator(&msg);
  }

  receiver.Finalize();
}

#else
VoVAllen's avatar
VoVAllen committed
108
109
110
111
112
113
114
115
116
117

#include <windows.h>
#include <winsock2.h>

#pragma comment(lib, "ws2_32.lib")

/*!
 * \brief Stand-in for a seconds-based sleep() on Windows, built on Sleep().
 * \param seconds value that is multiplied by 1000 and passed to Sleep()
 */
void sleep(int seconds) {
  const int millis = seconds * 1000;
  Sleep(millis);
}

118
119
120
// Thread bodies for the Windows variant of the test; defined after the TEST.
static void start_client();
// Returns bool so that _ServerThreadFunc can turn the result into the
// thread's exit code.
static bool start_server();

VoVAllen's avatar
VoVAllen committed
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
/*!
 * \brief Thread entry point handed to CreateThread: runs start_client().
 * \param param unused
 * \return always 0
 */
DWORD WINAPI _ClientThreadFunc(LPVOID param) {
  static_cast<void>(param);  // the client needs no argument
  start_client();
  const DWORD exit_code = 0;
  return exit_code;
}

/*!
 * \brief Thread entry point handed to CreateThread: runs start_server().
 * \param param unused
 * \return 1 when start_server() returns true, 0 otherwise
 */
DWORD WINAPI _ServerThreadFunc(LPVOID param) {
  static_cast<void>(param);  // the server needs no argument
  if (start_server()) {
    return 1;
  }
  return 0;
}

/*!
 * \brief Windows variant: run one client thread and one server thread and
 *        check that the server thread reports success through its exit code.
 */
TEST(SocketCommunicatorTest, SendAndRecv) {
  HANDLE hThreads[2];
  WSADATA wsaData;
  DWORD retcode;
  // Initialized so that a failing GetExitCodeThread() can never lead to an
  // indeterminate value being read by the expectation below.
  DWORD exitcode = 0;

  ASSERT_EQ(::WSAStartup(MAKEWORD(2, 2), &wsaData), 0);

  hThreads[0] = ::CreateThread(NULL, 0, _ClientThreadFunc, NULL, 0, NULL);  // client
  ASSERT_TRUE(hThreads[0] != NULL);
  hThreads[1] = ::CreateThread(NULL, 0, _ServerThreadFunc, NULL, 0, NULL);  // server
  ASSERT_TRUE(hThreads[1] != NULL);

  // Block until both threads have finished.
  retcode = ::WaitForMultipleObjects(2, hThreads, TRUE, INFINITE);
  EXPECT_TRUE((retcode <= WAIT_OBJECT_0 + 1) && (retcode >= WAIT_OBJECT_0));

  // _ServerThreadFunc exits with 1 when start_server() returned true.
  EXPECT_EQ(::GetExitCodeThread(hThreads[1], &exitcode), TRUE);
  // Compare as DWORD to avoid a signed/unsigned comparison.
  EXPECT_EQ(exitcode, static_cast<DWORD>(1));

  EXPECT_EQ(::CloseHandle(hThreads[0]), TRUE);
  EXPECT_EQ(::CloseHandle(hThreads[1]), TRUE);

  ::WSACleanup();
}

154
static void start_client() {
VoVAllen's avatar
VoVAllen committed
155
  sleep(1);
156
157
  SocketSender sender(kQueueSize);
  sender.AddReceiver("socket://127.0.0.1:8001", 0);
VoVAllen's avatar
VoVAllen committed
158
  sender.Connect();
159
160
161
162
163
  char* str_data = new char[9];
  memcpy(str_data, "123456789", 9);
  Message msg = {str_data, 9};
  msg.deallocator = DefaultMessageDeleter;
  sender.Send(msg, 0);
VoVAllen's avatar
VoVAllen committed
164
165
166
  sender.Finalize();
}

167
168
169
170
171
static bool start_server() {
  SocketReceiver receiver(kQueueSize);
  receiver.Wait("socket://127.0.0.1:8001", 1);
  Message msg;
  EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS);
VoVAllen's avatar
VoVAllen committed
172
  receiver.Finalize();
173
  return string("123456789") == string(msg.data, msg.size);
VoVAllen's avatar
VoVAllen committed
174
}
175
176

#endif