application.cpp 9.73 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/application.h>

3
#include <LightGBM/boosting.h>
Guolin Ke's avatar
Guolin Ke committed
4
#include <LightGBM/dataset.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/dataset_loader.h>
6
7
#include <LightGBM/metric.h>
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/objective_function.h>
9
#include <LightGBM/prediction_early_stop.h>
10
#include <LightGBM/utils/common.h>
11
#include <LightGBM/utils/openmp_wrapper.h>
12
#include <LightGBM/utils/text_reader.h>
Guolin Ke's avatar
Guolin Ke committed
13

14
15
#include <string>
#include <chrono>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
19
20
#include <cstdio>
#include <ctime>
#include <fstream>
#include <sstream>
#include <utility>
21
22

#include "predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
23
24
25

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
26
// Builds the application from command-line arguments.
//
// argc/argv are forwarded unchanged to LoadParameters(), which fills config_.
// Afterwards the OpenMP runtime is configured and the presence of an input
// data path is validated.
Application::Application(int argc, char** argv) {
  LoadParameters(argc, argv);
  // set number of threads for openmp; a non-positive value leaves the
  // OpenMP default untouched
  if (config_.num_threads > 0) {
    omp_set_num_threads(config_.num_threads);
  }
  // every task except model conversion needs an input data file
  if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
    // NOTE(review): Log::Fatal presumably aborts/throws - confirm
    Log::Fatal("No training/prediction data, application quit");
  }
  // ask OpenMP not to nest parallel regions
  omp_set_nested(0);
}

// Calls Network::Dispose() on destruction, but only when the configuration
// requested parallel mode; otherwise there is nothing to do here.
Application::~Application() {
  if (!config_.is_parallel) {
    return;
  }
  Network::Dispose();
}

void Application::LoadParameters(int argc, char** argv) {
  std::unordered_map<std::string, std::string> params;
Guolin Ke's avatar
Guolin Ke committed
46
  for (int i = 1; i < argc; ++i) {
Guolin Ke's avatar
Guolin Ke committed
47
    Config::KV2Map(params, argv[i]);
Guolin Ke's avatar
Guolin Ke committed
48
49
50
51
  }
  // check for alias
  ParameterAlias::KeyAliasTransform(&params);
  // read parameters from config file
Guolin Ke's avatar
Guolin Ke committed
52
53
  if (params.count("config") > 0) {
    TextReader<size_t> config_reader(params["config"].c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
54
    config_reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
55
    if (!config_reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
56
      for (auto& line : config_reader.Lines()) {
57
        // remove str after "#"
Guolin Ke's avatar
Guolin Ke committed
58
59
60
        if (line.size() > 0 && std::string::npos != line.find_first_of("#")) {
          line.erase(line.find_first_of("#"));
        }
Guolin Ke's avatar
Guolin Ke committed
61
        line = Common::Trim(line);
Guolin Ke's avatar
Guolin Ke committed
62
        if (line.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
63
64
          continue;
        }
Guolin Ke's avatar
Guolin Ke committed
65
        Config::KV2Map(params, line.c_str());
Guolin Ke's avatar
Guolin Ke committed
66
67
      }
    } else {
68
      Log::Warning("Config file %s doesn't exist, will ignore",
Guolin Ke's avatar
Guolin Ke committed
69
                   params["config"].c_str());
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
74
75
    }
  }
  // check for alias again
  ParameterAlias::KeyAliasTransform(&params);
  // load configs
  config_.Set(params);
76
  Log::Info("Finished loading parameters");
Guolin Ke's avatar
Guolin Ke committed
77
78
79
80
}

void Application::LoadData() {
  auto start_time = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
81
  std::unique_ptr<Predictor> predictor;
82
  // prediction is needed if using input initial model(continued train)
Guolin Ke's avatar
Guolin Ke committed
83
  PredictFunction predict_fun = nullptr;
84
  PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
85
  // need to continue training
Guolin Ke's avatar
Guolin Ke committed
86
  if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) {
87
    predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
Guolin Ke's avatar
Guolin Ke committed
88
    predict_fun = predictor->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
89
  }
90

Guolin Ke's avatar
Guolin Ke committed
91
92
  // sync up random seed for data partition
  if (config_.is_parallel_find_bin) {
Guolin Ke's avatar
Guolin Ke committed
93
    config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
Guolin Ke's avatar
Guolin Ke committed
94
  }
Guolin Ke's avatar
Guolin Ke committed
95

Guolin Ke's avatar
Guolin Ke committed
96
97
  DatasetLoader dataset_loader(config_, predict_fun,
                               config_.num_class, config_.data.c_str());
Guolin Ke's avatar
Guolin Ke committed
98
99
100
  // load Training data
  if (config_.is_parallel_find_bin) {
    // load data for parallel training
Guolin Ke's avatar
Guolin Ke committed
101
102
    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
                                                  config_.initscore_filename.c_str(),
Guolin Ke's avatar
Guolin Ke committed
103
                                                  Network::rank(), Network::num_machines()));
Guolin Ke's avatar
Guolin Ke committed
104
105
  } else {
    // load data for single machine
Guolin Ke's avatar
Guolin Ke committed
106
    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
107
                                                  0, 1));
Guolin Ke's avatar
Guolin Ke committed
108
109
  }
  // need save binary file
Guolin Ke's avatar
Guolin Ke committed
110
  if (config_.save_binary) {
Guolin Ke's avatar
Guolin Ke committed
111
    train_data_->SaveBinaryFile(nullptr);
Guolin Ke's avatar
Guolin Ke committed
112
113
  }
  // create training metric
Guolin Ke's avatar
Guolin Ke committed
114
115
116
  if (config_.is_provide_training_metric) {
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
117
      if (metric == nullptr) { continue; }
Guolin Ke's avatar
Guolin Ke committed
118
      metric->Init(train_data_->metadata(), train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
119
      train_metric_.push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
120
121
    }
  }
Guolin Ke's avatar
Guolin Ke committed
122
  train_metric_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
123

124

Guolin Ke's avatar
Guolin Ke committed
125
  if (!config_.metric.empty()) {
126
127
128
    // only when have metrics then need to construct validation data

    // Add validation data, if it exists
Guolin Ke's avatar
Guolin Ke committed
129
    for (size_t i = 0; i < config_.valid.size(); ++i) {
130
131
132
      // add
      auto new_dataset = std::unique_ptr<Dataset>(
        dataset_loader.LoadFromFileAlignWithOtherDataset(
Guolin Ke's avatar
Guolin Ke committed
133
134
          config_.valid[i].c_str(),
          config_.valid_data_initscores[i].c_str(),
135
          train_data_.get()));
136
137
      valid_datas_.push_back(std::move(new_dataset));
      // need save binary file
Guolin Ke's avatar
Guolin Ke committed
138
      if (config_.save_binary) {
139
140
141
142
143
        valid_datas_.back()->SaveBinaryFile(nullptr);
      }

      // add metric for validation data
      valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
144
145
      for (auto metric_type : config_.metric) {
        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
146
147
        if (metric == nullptr) { continue; }
        metric->Init(valid_datas_.back()->metadata(),
Guolin Ke's avatar
Guolin Ke committed
148
                     valid_datas_.back()->num_data());
149
150
151
        valid_metrics_.back().push_back(std::move(metric));
      }
      valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
152
    }
153
154
    valid_datas_.shrink_to_fit();
    valid_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
155
156
157
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  // output used time on each iteration
158
  Log::Info("Finished loading data in %f seconds",
Guolin Ke's avatar
Guolin Ke committed
159
            std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
164
}

void Application::InitTrain() {
  if (config_.is_parallel) {
    // need init network
Guolin Ke's avatar
Guolin Ke committed
165
    Network::Init(config_);
166
    Log::Info("Finished initializing network");
Guolin Ke's avatar
Guolin Ke committed
167
168
169
170
171
172
    config_.feature_fraction_seed =
      Network::GlobalSyncUpByMin(config_.feature_fraction_seed);
    config_.feature_fraction =
      Network::GlobalSyncUpByMin(config_.feature_fraction);
    config_.drop_seed =
      Network::GlobalSyncUpByMin(config_.drop_seed);
Guolin Ke's avatar
Guolin Ke committed
173
  }
Guolin Ke's avatar
Guolin Ke committed
174

Guolin Ke's avatar
Guolin Ke committed
175
  // create boosting
Guolin Ke's avatar
Guolin Ke committed
176
  boosting_.reset(
Guolin Ke's avatar
Guolin Ke committed
177
178
    Boosting::CreateBoosting(config_.boosting,
                             config_.input_model.c_str()));
Guolin Ke's avatar
Guolin Ke committed
179
  // create objective function
Guolin Ke's avatar
Guolin Ke committed
180
  objective_fun_.reset(
Guolin Ke's avatar
Guolin Ke committed
181
182
    ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                               config_));
Guolin Ke's avatar
Guolin Ke committed
183
184
185
186
187
  // load training data
  LoadData();
  // initialize the objective function
  objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
  // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
188
  boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
189
                  Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
190
191
  // add validation data into boosting
  for (size_t i = 0; i < valid_datas_.size(); ++i) {
192
    boosting_->AddValidDataset(valid_datas_[i].get(),
Guolin Ke's avatar
Guolin Ke committed
193
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
Guolin Ke's avatar
Guolin Ke committed
194
  }
195
  Log::Info("Finished initializing training");
Guolin Ke's avatar
Guolin Ke committed
196
197
198
}

void Application::Train() {
199
  Log::Info("Started training...");
Guolin Ke's avatar
Guolin Ke committed
200
  boosting_->Train(config_.snapshot_freq, config_.output_model);
201
  boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
202
203
  // convert model to if-else statement code
  if (config_.convert_model_language == std::string("cpp")) {
Guolin Ke's avatar
Guolin Ke committed
204
    boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
205
  }
206
  Log::Info("Finished training");
Guolin Ke's avatar
Guolin Ke committed
207
208
209
}

void Application::Predict() {
Guolin Ke's avatar
Guolin Ke committed
210
  if (config_.task == TaskType::KRefitTree) {
211
212
    // create predictor
    Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
Guolin Ke's avatar
Guolin Ke committed
213
214
    predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header);
    TextReader<int> result_reader(config_.output_result.c_str(), false);
215
216
217
218
219
220
221
222
    result_reader.ReadAllLines();
    std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(result_reader.Lines().size()); ++i) {
      pred_leaf[i] = Common::StringToArray<int>(result_reader.Lines()[i], '\t');
      // Free memory
      result_reader.Lines()[i].clear();
    }
Guolin Ke's avatar
Guolin Ke committed
223
224
225
    DatasetLoader dataset_loader(config_, nullptr,
                                 config_.num_class, config_.data.c_str());
    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
226
227
                                                  0, 1));
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
228
229
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
230
    objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
231
    boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
232
233
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    boosting_->RefitTree(pred_leaf);
234
    boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
235
236
237
    Log::Info("Finished RefitTree");
  } else {
    // create predictor
Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
    Predictor predictor(boosting_.get(), config_.num_iteration_predict, config_.predict_raw_score,
                        config_.predict_leaf_index, config_.predict_contrib,
                        config_.pred_early_stop, config_.pred_early_stop_freq,
                        config_.pred_early_stop_margin);
    predictor.Predict(config_.data.c_str(),
                      config_.output_result.c_str(), config_.header);
244
245
    Log::Info("Finished prediction");
  }
Guolin Ke's avatar
Guolin Ke committed
246
247
248
}

void Application::InitPredict() {
Guolin Ke's avatar
Guolin Ke committed
249
  boosting_.reset(
Guolin Ke's avatar
Guolin Ke committed
250
    Boosting::CreateBoosting("gbdt", config_.input_model.c_str()));
251
  Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
Guolin Ke's avatar
Guolin Ke committed
252
253
}

254
255
void Application::ConvertModel() {
  boosting_.reset(
Guolin Ke's avatar
Guolin Ke committed
256
257
    Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str()));
  boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
258
259
}

Guolin Ke's avatar
Guolin Ke committed
260
261

}  // namespace LightGBM