application.cpp 10.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/application.h>

7
#include <LightGBM/boosting.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/dataset.h>
Guolin Ke's avatar
Guolin Ke committed
9
#include <LightGBM/dataset_loader.h>
10
11
#include <LightGBM/metric.h>
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
12
#include <LightGBM/objective_function.h>
13
#include <LightGBM/prediction_early_stop.h>
14
#include <LightGBM/cuda/vector_cudahost.h>
15
#include <LightGBM/utils/common.h>
16
#include <LightGBM/utils/openmp_wrapper.h>
17
#include <LightGBM/utils/text_reader.h>
Guolin Ke's avatar
Guolin Ke committed
18

19
20
21
22
#include <chrono>
#include <cstdio>
#include <ctime>
#include <fstream>
23
#include <memory>
24
#include <sstream>
25
26
#include <string>
#include <unordered_map>
27
#include <utility>
28
#include <vector>
29

30
#include "predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
31
32
33

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
34
/*!
 * \brief Build the application from command-line arguments.
 *
 * Loads the parameters, applies the configured OpenMP thread count, aborts
 * through Log::Fatal when no data file is given for a task other than model
 * conversion, and selects the CUDA device when device_type is "cuda".
 *
 * \param argc Number of command-line arguments.
 * \param argv Command-line arguments; forwarded to LoadParameters.
 */
Application::Application(int argc, char** argv) {
  LoadParameters(argc, argv);

  // Hand the configured thread count to OpenMP.
  OMP_SET_NUM_THREADS(config_.num_threads);

  // Model conversion is the only task allowed to run without a data file.
  const bool missing_data = config_.data.empty();
  const bool converting_model = (config_.task == TaskType::kConvertModel);
  if (missing_data && !converting_model) {
    Log::Fatal("No training/prediction data, application quit");
  }

  const std::string cuda_device("cuda");
  if (config_.device_type == cuda_device) {
    LGBM_config_::current_device = lgbm_device_cuda;
  }
}

/*!
 * \brief Tear down the application; disposes the network layer when the
 *        configuration is marked as parallel.
 */
Application::~Application() {
  if (!config_.is_parallel) {
    return;
  }
  Network::Dispose();
}

void Application::LoadParameters(int argc, char** argv) {
54
  std::unordered_map<std::string, std::vector<std::string>> all_params;
Guolin Ke's avatar
Guolin Ke committed
55
  std::unordered_map<std::string, std::string> params;
Guolin Ke's avatar
Guolin Ke committed
56
  for (int i = 1; i < argc; ++i) {
57
    Config::KV2Map(&all_params, argv[i]);
Guolin Ke's avatar
Guolin Ke committed
58
59
  }
  // read parameters from config file
60
61
62
  bool config_file_ok = true;
  if (all_params.count("config") > 0) {
    TextReader<size_t> config_reader(all_params["config"][0].c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
63
    config_reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
64
    if (!config_reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
65
      for (auto& line : config_reader.Lines()) {
66
        // remove str after "#"
Guolin Ke's avatar
Guolin Ke committed
67
68
69
        if (line.size() > 0 && std::string::npos != line.find_first_of("#")) {
          line.erase(line.find_first_of("#"));
        }
Guolin Ke's avatar
Guolin Ke committed
70
        line = Common::Trim(line);
Guolin Ke's avatar
Guolin Ke committed
71
        if (line.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
72
73
          continue;
        }
74
        Config::KV2Map(&all_params, line.c_str());
Guolin Ke's avatar
Guolin Ke committed
75
76
      }
    } else {
77
      config_file_ok = false;
Guolin Ke's avatar
Guolin Ke committed
78
79
    }
  }
80
81
82
83
84
85
  Config::SetVerbosity(all_params);
  // de-duplicate params
  Config::KeepFirstValues(all_params, &params);
  if (!config_file_ok) {
    Log::Warning("Config file %s doesn't exist, will ignore", params["config"].c_str());
  }
Guolin Ke's avatar
Guolin Ke committed
86
87
  ParameterAlias::KeyAliasTransform(&params);
  config_.Set(params);
88
  Log::Info("Finished loading parameters");
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
}

void Application::LoadData() {
  auto start_time = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
93
  std::unique_ptr<Predictor> predictor;
94
  // prediction is needed if using input initial model(continued train)
Guolin Ke's avatar
Guolin Ke committed
95
  PredictFunction predict_fun = nullptr;
96
  // need to continue training
Guolin Ke's avatar
Guolin Ke committed
97
  if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) {
98
    predictor.reset(new Predictor(boosting_.get(), 0, -1, true, false, false, false, -1, -1));
Guolin Ke's avatar
Guolin Ke committed
99
    predict_fun = predictor->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
100
  }
101

Guolin Ke's avatar
Guolin Ke committed
102
  // sync up random seed for data partition
103
  if (config_.is_data_based_parallel) {
Guolin Ke's avatar
Guolin Ke committed
104
    config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
Guolin Ke's avatar
Guolin Ke committed
105
  }
Guolin Ke's avatar
Guolin Ke committed
106

107
  Log::Debug("Loading train file...");
Guolin Ke's avatar
Guolin Ke committed
108
109
  DatasetLoader dataset_loader(config_, predict_fun,
                               config_.num_class, config_.data.c_str());
Guolin Ke's avatar
Guolin Ke committed
110
  // load Training data
111
  if (config_.is_data_based_parallel) {
112
    // load data for distributed training
Guolin Ke's avatar
Guolin Ke committed
113
    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
Guolin Ke's avatar
Guolin Ke committed
114
                                                  Network::rank(), Network::num_machines()));
Guolin Ke's avatar
Guolin Ke committed
115
116
  } else {
    // load data for single machine
117
    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), 0, 1));
Guolin Ke's avatar
Guolin Ke committed
118
119
  }
  // need save binary file
Guolin Ke's avatar
Guolin Ke committed
120
  if (config_.save_binary) {
Guolin Ke's avatar
Guolin Ke committed
121
    train_data_->SaveBinaryFile(nullptr);
Guolin Ke's avatar
Guolin Ke committed
122
123
  }
  // create training metric
Guolin Ke's avatar
Guolin Ke committed
124
125
126
  if (config_.is_provide_training_metric) {
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
127
128
129
      if (metric == nullptr) {
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
130
      metric->Init(train_data_->metadata(), train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
131
      train_metric_.push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
132
133
    }
  }
Guolin Ke's avatar
Guolin Ke committed
134
  train_metric_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
135

Guolin Ke's avatar
Guolin Ke committed
136
  if (!config_.metric.empty()) {
137
138
139
    // only when have metrics then need to construct validation data

    // Add validation data, if it exists
Guolin Ke's avatar
Guolin Ke committed
140
    for (size_t i = 0; i < config_.valid.size(); ++i) {
141
      Log::Debug("Loading validation file #%zu...", (i + 1));
142
143
144
      // add
      auto new_dataset = std::unique_ptr<Dataset>(
        dataset_loader.LoadFromFileAlignWithOtherDataset(
Guolin Ke's avatar
Guolin Ke committed
145
          config_.valid[i].c_str(),
146
          train_data_.get()));
147
148
      valid_datas_.push_back(std::move(new_dataset));
      // need save binary file
Guolin Ke's avatar
Guolin Ke committed
149
      if (config_.save_binary) {
150
151
152
153
154
        valid_datas_.back()->SaveBinaryFile(nullptr);
      }

      // add metric for validation data
      valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
155
156
      for (auto metric_type : config_.metric) {
        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
157
158
159
        if (metric == nullptr) {
          continue;
        }
160
        metric->Init(valid_datas_.back()->metadata(),
Guolin Ke's avatar
Guolin Ke committed
161
                     valid_datas_.back()->num_data());
162
163
164
        valid_metrics_.back().push_back(std::move(metric));
      }
      valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
165
    }
166
167
    valid_datas_.shrink_to_fit();
    valid_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
168
169
170
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  // output used time on each iteration
171
  Log::Info("Finished loading data in %f seconds",
Guolin Ke's avatar
Guolin Ke committed
172
            std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
}

void Application::InitTrain() {
  if (config_.is_parallel) {
    // need init network
Guolin Ke's avatar
Guolin Ke committed
178
    Network::Init(config_);
179
    Log::Info("Finished initializing network");
Guolin Ke's avatar
Guolin Ke committed
180
181
182
183
184
185
    config_.feature_fraction_seed =
      Network::GlobalSyncUpByMin(config_.feature_fraction_seed);
    config_.feature_fraction =
      Network::GlobalSyncUpByMin(config_.feature_fraction);
    config_.drop_seed =
      Network::GlobalSyncUpByMin(config_.drop_seed);
Guolin Ke's avatar
Guolin Ke committed
186
  }
Guolin Ke's avatar
Guolin Ke committed
187

Guolin Ke's avatar
Guolin Ke committed
188
  // create boosting
Guolin Ke's avatar
Guolin Ke committed
189
  boosting_.reset(
Guolin Ke's avatar
Guolin Ke committed
190
191
    Boosting::CreateBoosting(config_.boosting,
                             config_.input_model.c_str()));
Guolin Ke's avatar
Guolin Ke committed
192
  // create objective function
Guolin Ke's avatar
Guolin Ke committed
193
  objective_fun_.reset(
Guolin Ke's avatar
Guolin Ke committed
194
195
    ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                               config_));
Guolin Ke's avatar
Guolin Ke committed
196
197
  // load training data
  LoadData();
198
199
200
201
  if (config_.task == TaskType::kSaveBinary) {
    Log::Info("Save data as binary finished, exit");
    exit(0);
  }
Guolin Ke's avatar
Guolin Ke committed
202
203
204
  // initialize the objective function
  objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
  // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
205
  boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
206
                  Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
207
208
  // add validation data into boosting
  for (size_t i = 0; i < valid_datas_.size(); ++i) {
209
    boosting_->AddValidDataset(valid_datas_[i].get(),
Guolin Ke's avatar
Guolin Ke committed
210
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
211
    Log::Debug("Number of data points in validation set #%zu: %d", i + 1, valid_datas_[i]->num_data());
Guolin Ke's avatar
Guolin Ke committed
212
  }
213
  Log::Info("Finished initializing training");
Guolin Ke's avatar
Guolin Ke committed
214
215
216
}

void Application::Train() {
217
  Log::Info("Started training...");
Guolin Ke's avatar
Guolin Ke committed
218
  boosting_->Train(config_.snapshot_freq, config_.output_model);
219
220
  boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
                             config_.output_model.c_str());
221
222
  // convert model to if-else statement code
  if (config_.convert_model_language == std::string("cpp")) {
Guolin Ke's avatar
Guolin Ke committed
223
    boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
224
  }
225
  Log::Info("Finished training");
Guolin Ke's avatar
Guolin Ke committed
226
227
228
}

void Application::Predict() {
Guolin Ke's avatar
Guolin Ke committed
229
  if (config_.task == TaskType::KRefitTree) {
230
    // create predictor
231
    Predictor predictor(boosting_.get(), 0, -1, false, true, false, false, 1, 1);
Chen Yufei's avatar
Chen Yufei committed
232
233
    predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check,
                      config_.precise_float_parser);
Guolin Ke's avatar
Guolin Ke committed
234
    TextReader<int> result_reader(config_.output_result.c_str(), false);
235
    result_reader.ReadAllLines();
236
237
238
239
240
241
242
243
244

    size_t nrow = result_reader.Lines().size();
    size_t ncol = 0;
    if (nrow > 0) {
      ncol = Common::StringToArray<int>(result_reader.Lines()[0], '\t').size();
    }
    std::vector<int> pred_leaf;
    pred_leaf.resize(nrow * ncol);

245
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
246
247
248
249
250
251
    for (int irow = 0; irow < static_cast<int>(nrow); ++irow) {
      auto line_vec = Common::StringToArray<int>(result_reader.Lines()[irow], '\t');
      CHECK_EQ(line_vec.size(), ncol);
      for (int i_row_item = 0; i_row_item < static_cast<int>(ncol); ++i_row_item) {
        pred_leaf[irow * ncol + i_row_item] = line_vec[i_row_item];
      }
252
      // Free memory
253
      result_reader.Lines()[irow].clear();
254
    }
Guolin Ke's avatar
Guolin Ke committed
255
256
    DatasetLoader dataset_loader(config_, nullptr,
                                 config_.num_class, config_.data.c_str());
257
    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), 0, 1));
258
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
259
260
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
261
    objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
262
    boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
263
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
264
265

    boosting_->RefitTree(pred_leaf.data(), nrow, ncol);
266
267
    boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
                               config_.output_model.c_str());
268
269
270
    Log::Info("Finished RefitTree");
  } else {
    // create predictor
271
    Predictor predictor(boosting_.get(), config_.start_iteration_predict, config_.num_iteration_predict, config_.predict_raw_score,
Guolin Ke's avatar
Guolin Ke committed
272
273
274
275
                        config_.predict_leaf_index, config_.predict_contrib,
                        config_.pred_early_stop, config_.pred_early_stop_freq,
                        config_.pred_early_stop_margin);
    predictor.Predict(config_.data.c_str(),
Chen Yufei's avatar
Chen Yufei committed
276
277
                      config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check,
                      config_.precise_float_parser);
278
279
    Log::Info("Finished prediction");
  }
Guolin Ke's avatar
Guolin Ke committed
280
281
282
}

void Application::InitPredict() {
Guolin Ke's avatar
Guolin Ke committed
283
  boosting_.reset(
Guolin Ke's avatar
Guolin Ke committed
284
    Boosting::CreateBoosting("gbdt", config_.input_model.c_str()));
285
  Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
Guolin Ke's avatar
Guolin Ke committed
286
287
}

288
289
void Application::ConvertModel() {
  boosting_.reset(
Guolin Ke's avatar
Guolin Ke committed
290
291
    Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str()));
  boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
292
293
}

Guolin Ke's avatar
Guolin Ke committed
294
295

}  // namespace LightGBM