application.cpp 10.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
#include <LightGBM/application.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/text_reader.h>

#include <LightGBM/network.h>
#include <LightGBM/dataset.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
9
10
11
12
13
14
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include "predictor.hpp"

15
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
19
20
21
22
23
24
25
26
27
28

#include <cstdio>
#include <ctime>

#include <chrono>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
29
// Builds the application from the command-line arguments:
//   1. parses parameters into config_ via LoadParameters,
//   2. applies config_.num_threads to OpenMP when it is positive,
//   3. reports a fatal error through Log::Fatal when no data file is
//      configured for any task other than TaskType::kConvertModel,
//   4. turns off nested OpenMP parallelism.
Application::Application(int argc, char** argv) {
  LoadParameters(argc, argv);

  // a non-positive num_threads leaves the OpenMP default untouched
  if (config_.num_threads > 0) {
    omp_set_num_threads(config_.num_threads);
  }

  // model conversion is the only task that may run without a data file
  const bool is_convert_task = (config_.task_type == TaskType::kConvertModel);
  if (!is_convert_task && config_.io_config.data_filename.empty()) {
    Log::Fatal("No training/prediction data, application quit");
  }

  omp_set_nested(0);
}

// Disposes the network layer, but only when config_.is_parallel is set;
// otherwise there is nothing to release here.
Application::~Application() {
  if (!config_.is_parallel) {
    return;
  }
  Network::Dispose();
}

void Application::LoadParameters(int argc, char** argv) {
  std::unordered_map<std::string, std::string> params;
Guolin Ke's avatar
Guolin Ke committed
49
  for (int i = 1; i < argc; ++i) {
Guolin Ke's avatar
Guolin Ke committed
50
51
    std::vector<std::string> tmp_strs = Common::Split(argv[i], '=');
    if (tmp_strs.size() == 2) {
52
53
      std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
      std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
Guolin Ke's avatar
Guolin Ke committed
54
55
56
57
      if (key.size() <= 0) {
        continue;
      }
      params[key] = value;
Guolin Ke's avatar
Guolin Ke committed
58
    } else {
Qiwei Ye's avatar
Qiwei Ye committed
59
      Log::Warning("Unknown parameter in command line: %s", argv[i]);
Guolin Ke's avatar
Guolin Ke committed
60
    }
Guolin Ke's avatar
Guolin Ke committed
61
62
63
64
65
  }
  // check for alias
  ParameterAlias::KeyAliasTransform(&params);
  // read parameters from config file
  if (params.count("config_file") > 0) {
Guolin Ke's avatar
Guolin Ke committed
66
    TextReader<size_t> config_reader(params["config_file"].c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
67
    config_reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
68
    if (!config_reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
69
      for (auto& line : config_reader.Lines()) {
70
        // remove str after "#"
Guolin Ke's avatar
Guolin Ke committed
71
72
73
        if (line.size() > 0 && std::string::npos != line.find_first_of("#")) {
          line.erase(line.find_first_of("#"));
        }
Guolin Ke's avatar
Guolin Ke committed
74
        line = Common::Trim(line);
Guolin Ke's avatar
Guolin Ke committed
75
        if (line.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
          continue;
        }
        std::vector<std::string> tmp_strs = Common::Split(line.c_str(), '=');
        if (tmp_strs.size() == 2) {
80
81
          std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
          std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
Guolin Ke's avatar
Guolin Ke committed
82
83
84
          if (key.size() <= 0) {
            continue;
          }
85
          // Command-line has higher priority
Guolin Ke's avatar
Guolin Ke committed
86
87
88
          if (params.count(key) == 0) {
            params[key] = value;
          }
Guolin Ke's avatar
Guolin Ke committed
89
        } else {
Qiwei Ye's avatar
Qiwei Ye committed
90
          Log::Warning("Unknown parameter in config file: %s", line.c_str());
Guolin Ke's avatar
Guolin Ke committed
91
        }
Guolin Ke's avatar
Guolin Ke committed
92
93
      }
    } else {
94
      Log::Warning("Config file %s doesn't exist, will ignore",
Guolin Ke's avatar
Guolin Ke committed
95
                   params["config_file"].c_str());
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
100
101
    }
  }
  // check for alias again
  ParameterAlias::KeyAliasTransform(&params);
  // load configs
  config_.Set(params);
102
  Log::Info("Finished loading parameters");
Guolin Ke's avatar
Guolin Ke committed
103
104
105
106
}

void Application::LoadData() {
  auto start_time = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
107
  std::unique_ptr<Predictor> predictor;
108
  // prediction is needed if using input initial model(continued train)
Guolin Ke's avatar
Guolin Ke committed
109
  PredictFunction predict_fun = nullptr;
110
  // need to continue training
Guolin Ke's avatar
Guolin Ke committed
111
  if (boosting_->NumberOfTotalModel() > 0) {
Guolin Ke's avatar
Guolin Ke committed
112
    predictor.reset(new Predictor(boosting_.get(), -1, true, false));
Guolin Ke's avatar
Guolin Ke committed
113
    predict_fun = predictor->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
114
  }
115

Guolin Ke's avatar
Guolin Ke committed
116
117
118
  // sync up random seed for data partition
  if (config_.is_parallel_find_bin) {
    config_.io_config.data_random_seed =
Guolin Ke's avatar
Guolin Ke committed
119
      GlobalSyncUpByMin<int>(config_.io_config.data_random_seed);
Guolin Ke's avatar
Guolin Ke committed
120
  }
Guolin Ke's avatar
Guolin Ke committed
121

122
  DatasetLoader dataset_loader(config_.io_config, predict_fun,
Guolin Ke's avatar
Guolin Ke committed
123
                               config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
124
125
126
  // load Training data
  if (config_.is_parallel_find_bin) {
    // load data for parallel training
Guolin Ke's avatar
Guolin Ke committed
127
    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
Guolin Ke's avatar
Guolin Ke committed
128
                                                  Network::rank(), Network::num_machines()));
Guolin Ke's avatar
Guolin Ke committed
129
130
  } else {
    // load data for single machine
Guolin Ke's avatar
Guolin Ke committed
131
    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), 0, 1));
Guolin Ke's avatar
Guolin Ke committed
132
133
134
  }
  // need save binary file
  if (config_.io_config.is_save_binary_file) {
Guolin Ke's avatar
Guolin Ke committed
135
    train_data_->SaveBinaryFile(nullptr);
Guolin Ke's avatar
Guolin Ke committed
136
137
  }
  // create training metric
Guolin Ke's avatar
Guolin Ke committed
138
  if (config_.boosting_config.is_provide_training_metric) {
Guolin Ke's avatar
Guolin Ke committed
139
    for (auto metric_type : config_.metric_types) {
Guolin Ke's avatar
Guolin Ke committed
140
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
Guolin Ke's avatar
Guolin Ke committed
141
      if (metric == nullptr) { continue; }
Guolin Ke's avatar
Guolin Ke committed
142
      metric->Init(train_data_->metadata(), train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
143
      train_metric_.push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
144
145
    }
  }
Guolin Ke's avatar
Guolin Ke committed
146
  train_metric_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
147

148

Guolin Ke's avatar
Guolin Ke committed
149
  if (!config_.metric_types.empty()) {
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    // only when have metrics then need to construct validation data

    // Add validation data, if it exists
    for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
      // add
      auto new_dataset = std::unique_ptr<Dataset>(
        dataset_loader.LoadFromFileAlignWithOtherDataset(
          config_.io_config.valid_data_filenames[i].c_str(),
          train_data_.get())
        );
      valid_datas_.push_back(std::move(new_dataset));
      // need save binary file
      if (config_.io_config.is_save_binary_file) {
        valid_datas_.back()->SaveBinaryFile(nullptr);
      }

      // add metric for validation data
      valid_metrics_.emplace_back();
      for (auto metric_type : config_.metric_types) {
        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
        if (metric == nullptr) { continue; }
        metric->Init(valid_datas_.back()->metadata(),
Guolin Ke's avatar
Guolin Ke committed
172
                     valid_datas_.back()->num_data());
173
174
175
        valid_metrics_.back().push_back(std::move(metric));
      }
      valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
176
    }
177
178
    valid_datas_.shrink_to_fit();
    valid_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
179
180
181
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  // output used time on each iteration
182
  Log::Info("Finished loading data in %f seconds",
Guolin Ke's avatar
Guolin Ke committed
183
            std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
188
189
}

void Application::InitTrain() {
  if (config_.is_parallel) {
    // need init network
    Network::Init(config_.network_config);
190
    Log::Info("Finished initializing network");
Guolin Ke's avatar
Guolin Ke committed
191
    // sync global random seed for feature patition
Guolin Ke's avatar
Guolin Ke committed
192
193
194
195
196
197
    config_.boosting_config.tree_config.feature_fraction_seed =
      GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed);
    config_.boosting_config.tree_config.feature_fraction =
      GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction);
    config_.boosting_config.drop_seed =
      GlobalSyncUpByMin<int>(config_.boosting_config.drop_seed);
Guolin Ke's avatar
Guolin Ke committed
198
  }
Guolin Ke's avatar
Guolin Ke committed
199

Guolin Ke's avatar
Guolin Ke committed
200
  // create boosting
Guolin Ke's avatar
Guolin Ke committed
201
  boosting_.reset(
202
    Boosting::CreateBoosting(config_.boosting_type,
Guolin Ke's avatar
Guolin Ke committed
203
                             config_.io_config.input_model.c_str()));
Guolin Ke's avatar
Guolin Ke committed
204
  // create objective function
Guolin Ke's avatar
Guolin Ke committed
205
  objective_fun_.reset(
Guolin Ke's avatar
Guolin Ke committed
206
    ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
Guolin Ke's avatar
Guolin Ke committed
207
                                               config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
208
209
210
211
212
  // load training data
  LoadData();
  // initialize the objective function
  objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
  // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
213
  boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
214
                  Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
215
216
  // add validation data into boosting
  for (size_t i = 0; i < valid_datas_.size(); ++i) {
217
    boosting_->AddValidDataset(valid_datas_[i].get(),
Guolin Ke's avatar
Guolin Ke committed
218
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
Guolin Ke's avatar
Guolin Ke committed
219
  }
220
  Log::Info("Finished initializing training");
Guolin Ke's avatar
Guolin Ke committed
221
222
223
}

void Application::Train() {
224
  Log::Info("Started training...");
Guolin Ke's avatar
Guolin Ke committed
225
  int total_iter = config_.boosting_config.num_iterations;
226
227
  bool is_finished = false;
  bool need_eval = true;
Guolin Ke's avatar
Guolin Ke committed
228
  auto start_time = std::chrono::steady_clock::now();
229
230
  for (int iter = 0; iter < total_iter && !is_finished; ++iter) {
    is_finished = boosting_->TrainOneIter(nullptr, nullptr, need_eval);
Guolin Ke's avatar
Guolin Ke committed
231
    auto end_time = std::chrono::steady_clock::now();
232
    // output used time per iteration
233
    Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
Guolin Ke's avatar
Guolin Ke committed
234
              std::milli>(end_time - start_time) * 1e-3, iter + 1);
Guolin Ke's avatar
Guolin Ke committed
235
236
237
238
239
    if (config_.io_config.snapshot_freq > 0 
        && (iter+1) % config_.io_config.snapshot_freq == 0) {
      std::string snapshot_out = config_.io_config.output_model + ".snapshot_iter_" + std::to_string(iter + 1);
      boosting_->SaveModelToFile(-1, snapshot_out.c_str());
    }
240
241
  }
  // save model to file
242
  boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
243
244
245
246
  // convert model to if-else statement code
  if (config_.convert_model_language == std::string("cpp")) {
    boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
  }
247
  Log::Info("Finished training");
Guolin Ke's avatar
Guolin Ke committed
248
249
250
251
}

void Application::Predict() {
  // create predictor
Guolin Ke's avatar
Guolin Ke committed
252
253
  Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
                      config_.io_config.is_predict_leaf_index);
254
  predictor.Predict(config_.io_config.data_filename.c_str(),
Guolin Ke's avatar
Guolin Ke committed
255
                    config_.io_config.output_result.c_str(), config_.io_config.has_header);
256
  Log::Info("Finished prediction");
Guolin Ke's avatar
Guolin Ke committed
257
258
259
}

void Application::InitPredict() {
Guolin Ke's avatar
Guolin Ke committed
260
261
  boosting_.reset(
    Boosting::CreateBoosting(config_.io_config.input_model.c_str()));
262
  Log::Info("Finished initializing prediction");
Guolin Ke's avatar
Guolin Ke committed
263
264
}

265
266
267
268
269
270
271
// Creates the boosting object from config_.boosting_type and the
// input_model file, then exports it through SaveModelToIfElse to the path
// given by config_.io_config.convert_model.
void Application::ConvertModel() {
  const auto& io = config_.io_config;
  boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, io.input_model.c_str()));
  boosting_->SaveModelToIfElse(-1, io.convert_model.c_str());
}

Guolin Ke's avatar
Guolin Ke committed
272
273
274
275
// Returns the value of `local` reduced across machines with a "keep the
// smaller one" reducer passed to Network::Allreduce. When
// config_.is_parallel is false, no communication happens and the local
// value is returned unchanged.
//
// \param local this machine's value; used as the Allreduce input buffer
// \return the reduced value (or a copy of local when not parallel)
template<typename T>
T Application::GlobalSyncUpByMin(T& local) {
  T global = local;
  if (config_.is_parallel) {
    Network::Allreduce(reinterpret_cast<char*>(&local),
                       sizeof(local), sizeof(local),
                       reinterpret_cast<char*>(&global),
                       [](const char* src, char* dst, int len) {
      // walk both buffers one T at a time; overwrite the destination element
      // whenever the source element compares smaller
      const int type_size = sizeof(T);
      for (int offset = 0; offset < len; offset += type_size) {
        const T* incoming = reinterpret_cast<const T*>(src + offset);
        const T* current = reinterpret_cast<const T*>(dst + offset);
        if (*incoming < *current) {
          std::memcpy(dst + offset, src + offset, type_size);
        }
      }
    });
  }
  // no need to sync if not parallel learning: global still equals local
  return global;
}

}  // namespace LightGBM