application.cpp 10.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
#include <LightGBM/application.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/text_reader.h>

#include <LightGBM/network.h>
#include <LightGBM/dataset.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
9
10
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
11
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
12
13
14
15
#include <LightGBM/metric.h>

#include "predictor.hpp"

16
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20
21
22
23
24
25
26
27
28
29

#include <cstdio>
#include <ctime>

#include <chrono>
#include <fstream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
30
// Builds the application from the raw command line: parses the parameters,
// applies the requested OpenMP thread count and checks that a data file was
// supplied for every task except model conversion.
Application::Application(int argc, char** argv) {
  // fill config_ from the command line (and the optional config file)
  LoadParameters(argc, argv);

  // only override the OpenMP thread count when a positive value was requested
  const auto requested_threads = config_.num_threads;
  if (requested_threads > 0) {
    omp_set_num_threads(requested_threads);
  }

  // model conversion is the only task that may run without a data file
  const bool is_convert_task = (config_.task_type == TaskType::kConvertModel);
  const bool has_data_file = !config_.io_config.data_filename.empty();
  if (!has_data_file && !is_convert_task) {
    Log::Fatal("No training/prediction data, application quit");
  }

  // turn nested OpenMP parallelism off
  omp_set_nested(0);
}

// Disposes the network layer; nothing to do unless running in parallel mode.
Application::~Application() {
  if (!config_.is_parallel) {
    return;
  }
  Network::Dispose();
}

void Application::LoadParameters(int argc, char** argv) {
  std::unordered_map<std::string, std::string> params;
Guolin Ke's avatar
Guolin Ke committed
50
  for (int i = 1; i < argc; ++i) {
wxchan's avatar
wxchan committed
51
    ConfigBase::KV2Map(params, argv[i]);
Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
  }
  // check for alias
  ParameterAlias::KeyAliasTransform(&params);
  // read parameters from config file
  if (params.count("config_file") > 0) {
Guolin Ke's avatar
Guolin Ke committed
57
    TextReader<size_t> config_reader(params["config_file"].c_str(), false);
Guolin Ke's avatar
Guolin Ke committed
58
    config_reader.ReadAllLines();
Guolin Ke's avatar
Guolin Ke committed
59
    if (!config_reader.Lines().empty()) {
Guolin Ke's avatar
Guolin Ke committed
60
      for (auto& line : config_reader.Lines()) {
61
        // remove str after "#"
Guolin Ke's avatar
Guolin Ke committed
62
63
64
        if (line.size() > 0 && std::string::npos != line.find_first_of("#")) {
          line.erase(line.find_first_of("#"));
        }
Guolin Ke's avatar
Guolin Ke committed
65
        line = Common::Trim(line);
Guolin Ke's avatar
Guolin Ke committed
66
        if (line.size() == 0) {
Guolin Ke's avatar
Guolin Ke committed
67
68
          continue;
        }
wxchan's avatar
wxchan committed
69
        ConfigBase::KV2Map(params, line.c_str());
Guolin Ke's avatar
Guolin Ke committed
70
71
      }
    } else {
72
      Log::Warning("Config file %s doesn't exist, will ignore",
Guolin Ke's avatar
Guolin Ke committed
73
                   params["config_file"].c_str());
Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
78
79
    }
  }
  // check for alias again
  ParameterAlias::KeyAliasTransform(&params);
  // load configs
  config_.Set(params);
80
  Log::Info("Finished loading parameters");
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
}

void Application::LoadData() {
  auto start_time = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
85
  std::unique_ptr<Predictor> predictor;
86
  // prediction is needed if using input initial model(continued train)
Guolin Ke's avatar
Guolin Ke committed
87
  PredictFunction predict_fun = nullptr;
88
  PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
89
  // need to continue training
Guolin Ke's avatar
Guolin Ke committed
90
  if (boosting_->NumberOfTotalModel() > 0 && config_.task_type != TaskType::KRefitTree) {
91
    predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
Guolin Ke's avatar
Guolin Ke committed
92
    predict_fun = predictor->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
93
  }
94

Guolin Ke's avatar
Guolin Ke committed
95
96
  // sync up random seed for data partition
  if (config_.is_parallel_find_bin) {
Guolin Ke's avatar
Guolin Ke committed
97
    config_.io_config.data_random_seed = Network::GlobalSyncUpByMin(config_.io_config.data_random_seed);
Guolin Ke's avatar
Guolin Ke committed
98
  }
Guolin Ke's avatar
Guolin Ke committed
99

100
  DatasetLoader dataset_loader(config_.io_config, predict_fun,
Guolin Ke's avatar
Guolin Ke committed
101
                               config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
102
103
104
  // load Training data
  if (config_.is_parallel_find_bin) {
    // load data for parallel training
Guolin Ke's avatar
Guolin Ke committed
105
    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
106
                                                  config_.io_config.initscore_filename.c_str(),
Guolin Ke's avatar
Guolin Ke committed
107
                                                  Network::rank(), Network::num_machines()));
Guolin Ke's avatar
Guolin Ke committed
108
109
  } else {
    // load data for single machine
110
111
    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
                                                  0, 1));
Guolin Ke's avatar
Guolin Ke committed
112
113
114
  }
  // need save binary file
  if (config_.io_config.is_save_binary_file) {
Guolin Ke's avatar
Guolin Ke committed
115
    train_data_->SaveBinaryFile(nullptr);
Guolin Ke's avatar
Guolin Ke committed
116
117
  }
  // create training metric
Guolin Ke's avatar
Guolin Ke committed
118
  if (config_.boosting_config.is_provide_training_metric) {
Guolin Ke's avatar
Guolin Ke committed
119
    for (auto metric_type : config_.metric_types) {
Guolin Ke's avatar
Guolin Ke committed
120
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
Guolin Ke's avatar
Guolin Ke committed
121
      if (metric == nullptr) { continue; }
Guolin Ke's avatar
Guolin Ke committed
122
      metric->Init(train_data_->metadata(), train_data_->num_data());
Guolin Ke's avatar
Guolin Ke committed
123
      train_metric_.push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
124
125
    }
  }
Guolin Ke's avatar
Guolin Ke committed
126
  train_metric_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
127

128

Guolin Ke's avatar
Guolin Ke committed
129
  if (!config_.metric_types.empty()) {
130
131
132
133
134
135
136
137
    // only when have metrics then need to construct validation data

    // Add validation data, if it exists
    for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
      // add
      auto new_dataset = std::unique_ptr<Dataset>(
        dataset_loader.LoadFromFileAlignWithOtherDataset(
          config_.io_config.valid_data_filenames[i].c_str(),
138
          config_.io_config.valid_data_initscores[i].c_str(),
139
140
141
142
143
144
145
146
147
148
149
150
151
152
          train_data_.get())
        );
      valid_datas_.push_back(std::move(new_dataset));
      // need save binary file
      if (config_.io_config.is_save_binary_file) {
        valid_datas_.back()->SaveBinaryFile(nullptr);
      }

      // add metric for validation data
      valid_metrics_.emplace_back();
      for (auto metric_type : config_.metric_types) {
        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
        if (metric == nullptr) { continue; }
        metric->Init(valid_datas_.back()->metadata(),
Guolin Ke's avatar
Guolin Ke committed
153
                     valid_datas_.back()->num_data());
154
155
156
        valid_metrics_.back().push_back(std::move(metric));
      }
      valid_metrics_.back().shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
157
    }
158
159
    valid_datas_.shrink_to_fit();
    valid_metrics_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
160
161
162
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  // output used time on each iteration
163
  Log::Info("Finished loading data in %f seconds",
Guolin Ke's avatar
Guolin Ke committed
164
            std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
165
166
167
168
169
170
}

void Application::InitTrain() {
  if (config_.is_parallel) {
    // need init network
    Network::Init(config_.network_config);
171
    Log::Info("Finished initializing network");
Guolin Ke's avatar
Guolin Ke committed
172
    config_.boosting_config.tree_config.feature_fraction_seed =
Guolin Ke's avatar
Guolin Ke committed
173
      Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction_seed);
Guolin Ke's avatar
Guolin Ke committed
174
    config_.boosting_config.tree_config.feature_fraction =
Guolin Ke's avatar
Guolin Ke committed
175
      Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction);
Guolin Ke's avatar
Guolin Ke committed
176
    config_.boosting_config.drop_seed =
Guolin Ke's avatar
Guolin Ke committed
177
      Network::GlobalSyncUpByMin(config_.boosting_config.drop_seed);
Guolin Ke's avatar
Guolin Ke committed
178
  }
Guolin Ke's avatar
Guolin Ke committed
179

Guolin Ke's avatar
Guolin Ke committed
180
  // create boosting
Guolin Ke's avatar
Guolin Ke committed
181
  boosting_.reset(
182
    Boosting::CreateBoosting(config_.boosting_type,
Guolin Ke's avatar
Guolin Ke committed
183
                             config_.io_config.input_model.c_str()));
Guolin Ke's avatar
Guolin Ke committed
184
  // create objective function
Guolin Ke's avatar
Guolin Ke committed
185
  objective_fun_.reset(
Guolin Ke's avatar
Guolin Ke committed
186
    ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
Guolin Ke's avatar
Guolin Ke committed
187
                                               config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
  // load training data
  LoadData();
  // initialize the objective function
  objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
  // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
193
  boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
194
                  Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
195
196
  // add validation data into boosting
  for (size_t i = 0; i < valid_datas_.size(); ++i) {
197
    boosting_->AddValidDataset(valid_datas_[i].get(),
Guolin Ke's avatar
Guolin Ke committed
198
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
Guolin Ke's avatar
Guolin Ke committed
199
  }
200
  Log::Info("Finished initializing training");
Guolin Ke's avatar
Guolin Ke committed
201
202
203
}

void Application::Train() {
204
  Log::Info("Started training...");
Guolin Ke's avatar
Guolin Ke committed
205
  boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
206
  boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
207
208
209
210
  // convert model to if-else statement code
  if (config_.convert_model_language == std::string("cpp")) {
    boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
  }
211
  Log::Info("Finished training");
Guolin Ke's avatar
Guolin Ke committed
212
213
214
}

void Application::Predict() {
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

  if (config_.task_type == TaskType::KRefitTree) {
    // create predictor
    Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
    predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str(), config_.io_config.has_header);
    TextReader<int> result_reader(config_.io_config.output_result.c_str(), false);
    result_reader.ReadAllLines();
    std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(result_reader.Lines().size()); ++i) {
      pred_leaf[i] = Common::StringToArray<int>(result_reader.Lines()[i], '\t');
      // Free memory
      result_reader.Lines()[i].clear();
    }
    DatasetLoader dataset_loader(config_.io_config, nullptr,
                                 config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
                                                  0, 1));
    train_metric_.clear();
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
                                                                    config_.objective_config));
    objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    boosting_->RefitTree(pred_leaf);
    boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
    Log::Info("Finished RefitTree");
  } else {
    // create predictor
    Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
                        config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
                        config_.io_config.pred_early_stop, config_.io_config.pred_early_stop_freq,
                        config_.io_config.pred_early_stop_margin);
    predictor.Predict(config_.io_config.data_filename.c_str(),
                      config_.io_config.output_result.c_str(), config_.io_config.has_header);
    Log::Info("Finished prediction");
  }
Guolin Ke's avatar
Guolin Ke committed
252
253
254
}

void Application::InitPredict() {
Guolin Ke's avatar
Guolin Ke committed
255
  boosting_.reset(
256
    Boosting::CreateBoosting("gbdt", config_.io_config.input_model.c_str()));
257
  Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
Guolin Ke's avatar
Guolin Ke committed
258
259
}

260
261
void Application::ConvertModel() {
  boosting_.reset(
262
    Boosting::CreateBoosting(config_.boosting_type, config_.io_config.input_model.c_str()));
263
264
265
  boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}

Guolin Ke's avatar
Guolin Ke committed
266
267

}  // namespace LightGBM