/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifdef USE_SOCKET

#include <LightGBM/config.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/text_reader.h>

#include <algorithm>
#include <chrono>
#include <cstring>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "linkers.h"

namespace LightGBM {

/*!
 * \brief Set up the socket network for the local machine.
 *
 * Copies the network settings out of \a config, parses the machine list,
 * resolves the local rank (either from an explicit "rank=" entry handled by
 * ParseMachineList or by matching a local IP plus the listen port against the
 * machine list), binds the listener, builds the communication maps and
 * finally connects to the other machines via Construct().
 *
 * \param config Configuration providing num_machines, local_listen_port,
 *        time_out, machines and machine_list_filename.
 */
Linkers::Linkers(Config config) {
  is_init_ = false;
  // start up socket
  TcpSocket::Startup();
  network_time_ = std::chrono::duration<double, std::milli>(0);
  num_machines_ = config.num_machines;
  local_listen_port_ = config.local_listen_port;
  socket_timeout_ = config.time_out;
  rank_ = -1;
  // parse clients from file (may also set rank_ and adjust num_machines_)
  ParseMachineList(config.machines, config.machine_list_filename);

  if (rank_ == -1) {
    // no explicit rank was given: look for an entry that matches one of the
    // local IPs and the local listen port
    std::unordered_set<std::string> local_ip_list = TcpSocket::GetLocalIpList();
    for (size_t i = 0; i < client_ips_.size(); ++i) {
      if (local_ip_list.count(client_ips_[i]) > 0 && client_ports_[i] == local_listen_port_) {
        rank_ = static_cast<int>(i);
        break;
      }
    }
  }
  if (rank_ == -1) {
    Log::Fatal("Machine list file doesn't contain the local machine");
  }
  // An explicitly configured rank is used below as an index into per-machine
  // containers, so it has to be a valid machine index.
  if (rank_ < 0 || rank_ >= num_machines_) {
    Log::Fatal("Rank %d is out of range, it should be in [0, %d)", rank_, num_machines_);
  }
  // construct listener
  listener_ = std::unique_ptr<TcpSocket>(new TcpSocket());
  TryBind(local_listen_port_);

  // one (initially empty) slot per machine
  for (int i = 0; i < num_machines_; ++i) {
    linkers_.push_back(nullptr);
  }

  // construct communication topo
  bruck_map_ = BruckMap::Construct(rank_, num_machines_);
  recursive_halving_map_ = RecursiveHalvingMap::Construct(rank_, num_machines_);

  // construct linkers
  Construct();
  // free listener
  listener_->Close();
  is_init_ = true;
}

/*!
 * \brief Close all established links and finalize the socket layer.
 *
 * Does nothing when the constructor did not run to completion
 * (is_init_ is still false).
 */
Linkers::~Linkers() {
  if (!is_init_) {
    return;
  }
  for (auto& linker : linkers_) {
    if (linker != nullptr) {
      linker->Close();
    }
  }
  TcpSocket::Finalize();
  // "%f" expects a double: convert the accumulated duration to seconds and
  // pass its numeric value instead of the duration object itself.
  Log::Info("Finished linking network in %f seconds",
            std::chrono::duration<double>(network_time_).count());
}

void Linkers::ParseMachineList(const std::string& machines, const std::string& filename) {
  std::vector<std::string> lines;
  if (machines.empty()) {
    TextReader<size_t> machine_list_reader(filename.c_str(), false);
    machine_list_reader.ReadAllLines();
    if (machine_list_reader.Lines().empty()) {
      Log::Fatal("Machine list file %s doesn't exist", filename.c_str());
    }
    lines = machine_list_reader.Lines();
  } else {
    lines = Common::Split(machines.c_str(), ',');
Guolin Ke's avatar
Guolin Ke committed
92
  }
93
  for (auto& line : lines) {
Guolin Ke's avatar
Guolin Ke committed
94
    line = Common::Trim(line);
95
    if (line.find("rank=") != std::string::npos) {
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
      std::vector<std::string> str_after_split = Common::Split(line.c_str(), '=');
      Common::Atoi(str_after_split[1].c_str(), &rank_);
      continue;
    }
100
    std::vector<std::string> str_after_split = Common::Split(line.c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
101
    if (str_after_split.size() != 2) {
102
103
104
105
      str_after_split = Common::Split(line.c_str(), ':');
      if (str_after_split.size() != 2) {
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
106
107
    }
    if (client_ips_.size() >= static_cast<size_t>(num_machines_)) {
108
      Log::Warning("machine_list size is larger than the parameter num_machines, ignoring redundant entries");
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
      break;
    }
    str_after_split[0] = Common::Trim(str_after_split[0]);
    str_after_split[1] = Common::Trim(str_after_split[1]);
    client_ips_.push_back(str_after_split[0]);
    client_ports_.push_back(atoi(str_after_split[1].c_str()));
  }
Guolin Ke's avatar
Guolin Ke committed
116
  if (client_ips_.empty()) {
117
118
    Log::Fatal("Cannot find any ip and port.\n"
               "Please check machine_list_filename or machines parameter");
119
  }
Guolin Ke's avatar
Guolin Ke committed
120
  if (client_ips_.size() != static_cast<size_t>(num_machines_)) {
121
    Log::Warning("World size is larger than the machine_list size, change world size to %zu", client_ips_.size());
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
126
    num_machines_ = static_cast<int>(client_ips_.size());
  }
}

void Linkers::TryBind(int port) {
127
  Log::Info("Trying to bind port %d...", port);
Guolin Ke's avatar
Guolin Ke committed
128
  if (listener_->Bind(port)) {
129
    Log::Info("Binding port %d succeeded", port);
Guolin Ke's avatar
Guolin Ke committed
130
  } else {
131
    Log::Fatal("Binding port %d failed", port);
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
  }
}

void Linkers::SetLinker(int rank, const TcpSocket& socket) {
Guolin Ke's avatar
Guolin Ke committed
136
  linkers_[rank].reset(new TcpSocket(socket));
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
  // set timeout
  linkers_[rank]->SetTimeout(socket_timeout_ * 1000 * 60);
}

void Linkers::ListenThread(int incoming_cnt) {
142
  Log::Info("Listening...");
Guolin Ke's avatar
Guolin Ke committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
  char buffer[100];
  int connected_cnt = 0;
  while (connected_cnt < incoming_cnt) {
    // accept incoming socket
    TcpSocket handler = listener_->Accept();
    if (handler.IsClosed()) {
      continue;
    }
    // receive rank
    int read_cnt = 0;
    int size_of_int = static_cast<int>(sizeof(int));
    while (read_cnt < size_of_int) {
      int cur_read_cnt = handler.Recv(buffer + read_cnt, size_of_int - read_cnt);
      read_cnt += cur_read_cnt;
    }
    int* ptr_in_rank = reinterpret_cast<int*>(buffer);
    int in_rank = *ptr_in_rank;
    // add new socket
    SetLinker(in_rank, handler);
    ++connected_cnt;
  }
}

void Linkers::Construct() {
  // save ranks that need to connect with
  std::unordered_map<int, int> need_connect;
Guolin Ke's avatar
Guolin Ke committed
169
170
171
  for (int i = 0; i < num_machines_; ++i) {
    if (i != rank_) {
      need_connect[i] = 1;
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
180
181
182
183
184
    }
  }
  int need_connect_cnt = 0;
  int incoming_cnt = 0;
  for (auto it = need_connect.begin(); it != need_connect.end(); ++it) {
    int machine_rank = it->first;
    if (machine_rank >= 0 && machine_rank != rank_) {
      ++need_connect_cnt;
    }
    if (machine_rank < rank_) {
      ++incoming_cnt;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
185

Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
  // start listener
  listener_->SetTimeout(socket_timeout_);
  listener_->Listen(incoming_cnt);
  std::thread listen_thread(&Linkers::ListenThread, this, incoming_cnt);
190
191
192
  const int connect_fail_constant_factor = 20;
  const int connect_fail_retries_scale_factor = static_cast<int>(num_machines_ / connect_fail_constant_factor);
  const int connect_fail_retry_cnt = std::max(connect_fail_constant_factor, connect_fail_retries_scale_factor);
193
  const int connect_fail_retry_first_delay_interval = 200;  // 0.2 s
194
  const float connect_fail_retry_delay_factor = 1.3f;
Guolin Ke's avatar
Guolin Ke committed
195
196
197
198
199
  // start connect
  for (auto it = need_connect.begin(); it != need_connect.end(); ++it) {
    int out_rank = it->first;
    // let smaller rank connect to larger rank
    if (out_rank > rank_) {
200
      int connect_fail_delay_time = connect_fail_retry_first_delay_interval;
Guolin Ke's avatar
Guolin Ke committed
201
      for (int i = 0; i < connect_fail_retry_cnt; ++i) {
202
        TcpSocket cur_socket;
Guolin Ke's avatar
Guolin Ke committed
203
        if (cur_socket.Connect(client_ips_[out_rank].c_str(), client_ports_[out_rank])) {
204
205
206
          // send local rank
          cur_socket.Send(reinterpret_cast<const char*>(&rank_), sizeof(rank_));
          SetLinker(out_rank, cur_socket);
Guolin Ke's avatar
Guolin Ke committed
207
208
          break;
        } else {
209
          Log::Warning("Connecting to rank %d failed, waiting for %d milliseconds", out_rank, connect_fail_delay_time);
210
          cur_socket.Close();
Guolin Ke's avatar
Guolin Ke committed
211
          std::this_thread::sleep_for(std::chrono::milliseconds(connect_fail_delay_time));
212
          connect_fail_delay_time = static_cast<int>(connect_fail_delay_time * connect_fail_retry_delay_factor);
Guolin Ke's avatar
Guolin Ke committed
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
        }
      }
    }
  }
  // wait for listener
  listen_thread.join();
  // print connected linkers
  PrintLinkers();
}

/*!
 * \brief Tell whether a usable link to machine \a rank exists.
 * \param rank Index into linkers_.
 * \return true if the entry is set and its socket is not closed.
 */
bool Linkers::CheckLinker(int rank) {
  const auto& link = linkers_[rank];
  return link != nullptr && !link->IsClosed();
}

void Linkers::PrintLinkers() {
  for (int i = 0; i < num_machines_; ++i) {
    if (CheckLinker(i)) {
233
      Log::Info("Connected to rank %d", i);
Guolin Ke's avatar
Guolin Ke committed
234
235
236
237
238
239
240
    }
  }
}

}  // namespace LightGBM

#endif  // USE_SOCKET