linkers_socket.cpp 7.55 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#ifdef USE_SOCKET

7
8
9
10
#include <LightGBM/config.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/text_reader.h>

11
#include <algorithm>
12
#include <chrono>
Guolin Ke's avatar
Guolin Ke committed
13
#include <cstring>
14
#include <string>
15
#include <thread>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
#include <unordered_map>
#include <unordered_set>
#include <vector>
19
20

#include "linkers.h"
Guolin Ke's avatar
Guolin Ke committed
21
22
23

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
24
/*!
 * \brief Build the socket linkers for the local machine.
 *
 * Copies the network settings out of config, parses the machine list,
 * determines the local rank, binds the listening socket, constructs the
 * Bruck and recursive-halving maps and then connects to the other machines.
 * \param config Supplies num_machines, local_listen_port, time_out, machines
 *        and machine_list_filename
 */
Linkers::Linkers(Config config) {
  is_init_ = false;
  // start up socket
  TcpSocket::Startup();
  network_time_ = std::chrono::duration<double, std::milli>(0);
  num_machines_ = config.num_machines;
  local_listen_port_ = config.local_listen_port;
  socket_timeout_ = config.time_out;
  rank_ = -1;

  // fills client_ips_ / client_ports_; a "rank=" entry may also set rank_
  ParseMachineList(config.machines, config.machine_list_filename);

  if (rank_ == -1) {
    // no rank was given explicitly: search the machine list for an entry whose
    // ip belongs to this machine and whose port equals the local listen port
    const std::unordered_set<std::string> local_ips = TcpSocket::GetLocalIpList();
    for (size_t idx = 0; idx < client_ips_.size(); ++idx) {
      const bool ip_is_local = local_ips.find(client_ips_[idx]) != local_ips.end();
      if (ip_is_local && client_ports_[idx] == local_listen_port_) {
        rank_ = static_cast<int>(idx);
        break;
      }
    }
    if (rank_ == -1) {
      Log::Fatal("Machine list file doesn't contain the local machine");
    }
  }

  // construct listener and bind it to the local port
  listener_.reset(new TcpSocket());
  TryBind(local_listen_port_);

  // one (initially empty) linker slot per machine
  for (int slot = 0; slot < num_machines_; ++slot) {
    linkers_.push_back(nullptr);
  }

  // construct communication topo
  bruck_map_ = BruckMap::Construct(rank_, num_machines_);
  recursive_halving_map_ = RecursiveHalvingMap::Construct(rank_, num_machines_);

  // construct linkers
  Construct();
  // the listener is only needed while the linkers are being constructed
  listener_->Close();
  is_init_ = true;
}

/*!
 * \brief Close every open linker socket and shut the socket layer down.
 *
 * Nothing is done unless the constructor ran to completion (is_init_).
 */
Linkers::~Linkers() {
  if (!is_init_) {
    return;
  }
  for (auto& linker : linkers_) {
    if (linker != nullptr) {
      linker->Close();
    }
  }
  TcpSocket::Finalize();
  // network_time_ is a std::chrono::duration (it is initialised from
  // duration<double, std::milli> in the constructor). Multiplying it by a
  // scalar yields another duration object, which must not be passed through
  // the printf-style varargs of Log::Info for a "%f" conversion. Pass the
  // numeric tick count instead; scaling by 1e-3 turns milliseconds into seconds.
  Log::Info("Finished linking network in %f seconds", network_time_.count() * 1e-3);
}

81
82
83
84
85
86
87
88
89
90
91
void Linkers::ParseMachineList(const std::string& machines, const std::string& filename) {
  std::vector<std::string> lines;
  if (machines.empty()) {
    TextReader<size_t> machine_list_reader(filename.c_str(), false);
    machine_list_reader.ReadAllLines();
    if (machine_list_reader.Lines().empty()) {
      Log::Fatal("Machine list file %s doesn't exist", filename.c_str());
    }
    lines = machine_list_reader.Lines();
  } else {
    lines = Common::Split(machines.c_str(), ',');
Guolin Ke's avatar
Guolin Ke committed
92
  }
93
  for (auto& line : lines) {
Guolin Ke's avatar
Guolin Ke committed
94
    line = Common::Trim(line);
95
    if (line.find("rank=") != std::string::npos) {
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
      std::vector<std::string> str_after_split = Common::Split(line.c_str(), '=');
      Common::Atoi(str_after_split[1].c_str(), &rank_);
      continue;
    }
100
    std::vector<std::string> str_after_split = Common::Split(line.c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
101
    if (str_after_split.size() != 2) {
102
103
104
105
      str_after_split = Common::Split(line.c_str(), ':');
      if (str_after_split.size() != 2) {
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
106
107
    }
    if (client_ips_.size() >= static_cast<size_t>(num_machines_)) {
108
      Log::Warning("machine_list size is larger than the parameter num_machines, ignoring redundant entries");
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
      break;
    }
    str_after_split[0] = Common::Trim(str_after_split[0]);
    str_after_split[1] = Common::Trim(str_after_split[1]);
    client_ips_.push_back(str_after_split[0]);
    client_ports_.push_back(atoi(str_after_split[1].c_str()));
  }
Guolin Ke's avatar
Guolin Ke committed
116
  if (client_ips_.empty()) {
117
118
    Log::Fatal("Cannot find any ip and port.\n"
               "Please check machine_list_filename or machines parameter");
119
  }
Guolin Ke's avatar
Guolin Ke committed
120
  if (client_ips_.size() != static_cast<size_t>(num_machines_)) {
121
    Log::Warning("World size is larger than the machine_list size, change world size to %zu", client_ips_.size());
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
126
    num_machines_ = static_cast<int>(client_ips_.size());
  }
}

void Linkers::TryBind(int port) {
127
  Log::Info("Trying to bind port %d...", port);
Guolin Ke's avatar
Guolin Ke committed
128
  if (listener_->Bind(port)) {
129
    Log::Info("Binding port %d succeeded", port);
Guolin Ke's avatar
Guolin Ke committed
130
  } else {
131
    Log::Fatal("Binding port %d failed", port);
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
  }
}

void Linkers::SetLinker(int rank, const TcpSocket& socket) {
Guolin Ke's avatar
Guolin Ke committed
136
  linkers_[rank].reset(new TcpSocket(socket));
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
  // set timeout
  linkers_[rank]->SetTimeout(socket_timeout_ * 1000 * 60);
}

void Linkers::ListenThread(int incoming_cnt) {
142
  Log::Info("Listening...");
Guolin Ke's avatar
Guolin Ke committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
  char buffer[100];
  int connected_cnt = 0;
  while (connected_cnt < incoming_cnt) {
    // accept incoming socket
    TcpSocket handler = listener_->Accept();
    if (handler.IsClosed()) {
      continue;
    }
    // receive rank
    int read_cnt = 0;
    int size_of_int = static_cast<int>(sizeof(int));
    while (read_cnt < size_of_int) {
      int cur_read_cnt = handler.Recv(buffer + read_cnt, size_of_int - read_cnt);
      read_cnt += cur_read_cnt;
    }
    int* ptr_in_rank = reinterpret_cast<int*>(buffer);
    int in_rank = *ptr_in_rank;
160
161
162
    if (in_rank < 0 || in_rank >= num_machines_) {
      Log::Fatal("Invalid rank %d found during initialization of linkers. The world size is %d.", in_rank, num_machines_);
    }
Guolin Ke's avatar
Guolin Ke committed
163
164
165
166
167
168
169
170
171
    // add new socket
    SetLinker(in_rank, handler);
    ++connected_cnt;
  }
}

void Linkers::Construct() {
  // save ranks that need to connect with
  std::unordered_map<int, int> need_connect;
Guolin Ke's avatar
Guolin Ke committed
172
173
174
  for (int i = 0; i < num_machines_; ++i) {
    if (i != rank_) {
      need_connect[i] = 1;
Guolin Ke's avatar
Guolin Ke committed
175
176
177
178
179
180
181
182
183
    }
  }
  int incoming_cnt = 0;
  for (auto it = need_connect.begin(); it != need_connect.end(); ++it) {
    int machine_rank = it->first;
    if (machine_rank < rank_) {
      ++incoming_cnt;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
184

Guolin Ke's avatar
Guolin Ke committed
185
186
187
188
  // start listener
  listener_->SetTimeout(socket_timeout_);
  listener_->Listen(incoming_cnt);
  std::thread listen_thread(&Linkers::ListenThread, this, incoming_cnt);
189
190
191
  const int connect_fail_constant_factor = 20;
  const int connect_fail_retries_scale_factor = static_cast<int>(num_machines_ / connect_fail_constant_factor);
  const int connect_fail_retry_cnt = std::max(connect_fail_constant_factor, connect_fail_retries_scale_factor);
192
  const int connect_fail_retry_first_delay_interval = 200;  // 0.2 s
193
  const float connect_fail_retry_delay_factor = 1.3f;
Guolin Ke's avatar
Guolin Ke committed
194
195
196
197
198
  // start connect
  for (auto it = need_connect.begin(); it != need_connect.end(); ++it) {
    int out_rank = it->first;
    // let smaller rank connect to larger rank
    if (out_rank > rank_) {
199
      int connect_fail_delay_time = connect_fail_retry_first_delay_interval;
Guolin Ke's avatar
Guolin Ke committed
200
      for (int i = 0; i < connect_fail_retry_cnt; ++i) {
201
        TcpSocket cur_socket;
Guolin Ke's avatar
Guolin Ke committed
202
        if (cur_socket.Connect(client_ips_[out_rank].c_str(), client_ports_[out_rank])) {
203
204
205
          // send local rank
          cur_socket.Send(reinterpret_cast<const char*>(&rank_), sizeof(rank_));
          SetLinker(out_rank, cur_socket);
Guolin Ke's avatar
Guolin Ke committed
206
207
          break;
        } else {
208
          Log::Warning("Connecting to rank %d failed, waiting for %d milliseconds", out_rank, connect_fail_delay_time);
209
          cur_socket.Close();
Guolin Ke's avatar
Guolin Ke committed
210
          std::this_thread::sleep_for(std::chrono::milliseconds(connect_fail_delay_time));
211
          connect_fail_delay_time = static_cast<int>(connect_fail_delay_time * connect_fail_retry_delay_factor);
Guolin Ke's avatar
Guolin Ke committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
        }
      }
    }
  }
  // wait for listener
  listen_thread.join();
  // print connected linkers
  PrintLinkers();
}

bool Linkers::CheckLinker(int rank) {
  if (linkers_[rank] == nullptr || linkers_[rank]->IsClosed()) {
    return false;
  }
  return true;
}

void Linkers::PrintLinkers() {
  for (int i = 0; i < num_machines_; ++i) {
    if (CheckLinker(i)) {
232
      Log::Info("Connected to rank %d", i);
Guolin Ke's avatar
Guolin Ke committed
233
234
235
236
237
238
239
    }
  }
}

}  // namespace LightGBM

#endif  // USE_SOCKET