config.cpp 12.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <vector>
#include <string>
#include <unordered_map>
#include <algorithm>

namespace LightGBM {

13
14
15
16
17
18
19
20
21
22
23
24
25
void OverallConfig::LoadFromString(const char* str) {
  std::unordered_map<std::string, std::string> params;
  auto args = Common::Split(str, " \t\n\r");
  for (auto arg : args) {
    std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
      std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
      if (key.size() <= 0) {
        continue;
      }
      params[key] = value;
    } else {
Qiwei Ye's avatar
Qiwei Ye committed
26
      Log::Warning("Unknown parameter %s", arg.c_str());
27
28
29
30
31
32
    }
  }
  ParameterAlias::KeyAliasTransform(&params);
  Set(params);
}

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // load main config types
  GetInt(params, "num_threads", &num_threads);
  GetTaskType(params);
  GetBoostingType(params);
  GetObjectiveType(params);
  GetMetricType(params);


  // sub-config setup
  network_config.Set(params);
  io_config.Set(params);

Guolin Ke's avatar
Guolin Ke committed
46
  boosting_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
  objective_config.Set(params);
  metric_config.Set(params);
  // check for conflicts
  CheckParamConflict();
Qiwei Ye's avatar
Qiwei Ye committed
51

Guolin Ke's avatar
Guolin Ke committed
52
  if (io_config.verbosity == 1) {
Qiwei Ye's avatar
Qiwei Ye committed
53
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
54
55
  }
  else if (io_config.verbosity == 0) {
Qiwei Ye's avatar
Qiwei Ye committed
56
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
57
58
  }
  else if (io_config.verbosity >= 2) {
Qiwei Ye's avatar
Qiwei Ye committed
59
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
Guolin Ke's avatar
Guolin Ke committed
60
61
  }
  else {
Qiwei Ye's avatar
Qiwei Ye committed
62
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
Guolin Ke's avatar
Guolin Ke committed
63
  }
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
68
}

void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "boosting_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
69
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
70
71
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
      boosting_type = BoostingType::kGBDT;
72
73
    } else if (value == std::string("dart")) {
      boosting_type = BoostingType::kDART;
Guolin Ke's avatar
Guolin Ke committed
74
    } else {
75
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
80
81
82
    }
  }
}

void OverallConfig::GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
83
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
84
85
86
87
88
89
90
91
92
93
    objective_type = value;
  }
}

void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "metric", &value)) {
    // clear old metrics
    metric_types.clear();
    // to lower
Guolin Ke's avatar
Guolin Ke committed
94
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
    // split
    std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
    // remove dumplicate
    std::unordered_map<std::string, int> metric_maps;
    for (auto& metric : metrics) {
Guolin Ke's avatar
Guolin Ke committed
100
      std::transform(metric.begin(), metric.end(), metric.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
101
102
103
104
105
106
107
108
      if (metric_maps.count(metric) <= 0) {
        metric_maps[metric] = 1;
      }
    }
    for (auto& pair : metric_maps) {
      std::string sub_metric_str = pair.first;
      metric_types.push_back(sub_metric_str);
    }
Guolin Ke's avatar
Guolin Ke committed
109
    metric_types.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
114
115
116
  }
}


void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
117
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
118
119
120
121
122
123
    if (value == std::string("train") || value == std::string("training")) {
      task_type = TaskType::kTrain;
    } else if (value == std::string("predict") || value == std::string("prediction")
      || value == std::string("test")) {
      task_type = TaskType::kPredict;
    } else {
124
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
    }
  }
}

void OverallConfig::CheckParamConflict() {
130

131
132
  // check if objective_type, metric_type, and num_class match
  bool objective_type_multiclass = (objective_type == std::string("multiclass"));
Guolin Ke's avatar
Guolin Ke committed
133
  int num_class_check = boosting_config.num_class;
134
135
  if (objective_type_multiclass){
      if (num_class_check <= 1){
136
          Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
137
138
139
140
      }
  }
  else {
      if (task_type == TaskType::kTrain && num_class_check != 1){
141
142
          Log::Fatal("Number of classes must be 1 for non-multiclass training");
      }
143
144
145
  }
  for (std::string metric_type : metric_types){
        bool metric_type_multiclass = ( metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
146
        if ((objective_type_multiclass && !metric_type_multiclass)
147
            || (!objective_type_multiclass && metric_type_multiclass)){
148
            Log::Fatal("Objective and metrics don't match");
149
        }
150
  }
151

Guolin Ke's avatar
Guolin Ke committed
152
153
154
155
  if (network_config.num_machines > 1) {
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
156
    boosting_config.tree_learner_type = TreeLearnerType::kSerialTreeLearner;
Guolin Ke's avatar
Guolin Ke committed
157
158
  }

Guolin Ke's avatar
Guolin Ke committed
159
  if (boosting_config.tree_learner_type == TreeLearnerType::kSerialTreeLearner) {
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
    is_parallel = false;
    network_config.num_machines = 1;
  }

Guolin Ke's avatar
Guolin Ke committed
164
165
  if (boosting_config.tree_learner_type == TreeLearnerType::kSerialTreeLearner ||
    boosting_config.tree_learner_type == TreeLearnerType::kFeatureParallelTreelearner) {
Guolin Ke's avatar
Guolin Ke committed
166
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
167
  } else if (boosting_config.tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) {
Guolin Ke's avatar
Guolin Ke committed
168
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
169
    if (boosting_config.tree_config.histogram_pool_size >= 0) {
170
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
Guolin Ke's avatar
Guolin Ke committed
171
                 , boosting_config.tree_config.histogram_pool_size);
172
      // Change pool size to -1 (not limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
173
      boosting_config.tree_config.histogram_pool_size = -1;
174
175
    }

Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
  }
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin > 0);
182
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
183
  GetInt(params, "data_random_seed", &data_random_seed);
184
  GetString(params, "data", &data_filename);
Qiwei Ye's avatar
Qiwei Ye committed
185
  GetInt(params, "verbose", &verbosity);
Guolin Ke's avatar
Guolin Ke committed
186
  GetInt(params, "num_model_predict", &num_model_predict);
Guolin Ke's avatar
Guolin Ke committed
187
  GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
  GetBool(params, "is_pre_partition", &is_pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
  GetBool(params, "use_two_round_loading", &use_two_round_loading);
  GetBool(params, "is_save_binary_file", &is_save_binary_file);
Guolin Ke's avatar
Guolin Ke committed
192
  GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
Guolin Ke's avatar
Guolin Ke committed
193
194
  GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
  GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
Guolin Ke's avatar
Guolin Ke committed
195
196
197
198
199
200
201
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
  GetString(params, "output_result", &output_result);
  std::string tmp_str = "";
  if (GetString(params, "valid_data", &tmp_str)) {
    valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
  }
Guolin Ke's avatar
Guolin Ke committed
202
203
204
205
206
  GetBool(params, "has_header", &has_header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
Guolin Ke's avatar
Guolin Ke committed
207
208
209
210
211
}


void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
212
  GetDouble(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
213
214
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
215
216
  GetInt(params, "num_class", &num_class);
  CHECK(num_class >= 1);
Guolin Ke's avatar
Guolin Ke committed
217
218
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
219
    label_gain = Common::StringToDoubleArray(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
220
221
222
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
223
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
224
    for (int i = 1; i < max_label; ++i) {
225
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
226
227
    }
  }
Guolin Ke's avatar
Guolin Ke committed
228
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
229
230
231
232
}


void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
233
  GetDouble(params, "sigmoid", &sigmoid);
234
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
235
236
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
237
    label_gain = Common::StringToDoubleArray(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
238
239
240
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
241
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
242
    for (int i = 1; i < max_label; ++i) {
243
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
244
245
    }
  }
Guolin Ke's avatar
Guolin Ke committed
246
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
247
248
249
250
251
252
253
254
255
256
257
258
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
    eval_at = Common::StringToIntArray(tmp_str, ',');
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
259
  eval_at.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
260
261
262
263
264
}


void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
265
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
266
  CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
267
268
269
270
271
272
  GetDouble(params, "lambda_l1", &lambda_l1);
  CHECK(lambda_l1 >= 0.0f)
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >= 0.0f)
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  CHECK(min_gain_to_split >= 0.0f)
Guolin Ke's avatar
Guolin Ke committed
273
  GetInt(params, "num_leaves", &num_leaves);
274
  CHECK(num_leaves > 1);
Guolin Ke's avatar
Guolin Ke committed
275
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
276
  GetDouble(params, "feature_fraction", &feature_fraction);
277
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
278
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
279
280
  GetInt(params, "max_depth", &max_depth);
  CHECK(max_depth > 1 || max_depth < 0);
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
}


void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_iterations", &num_iterations);
Guolin Ke's avatar
Guolin Ke committed
286
  GetDouble(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
287
288
289
290
  CHECK(num_iterations >= 0);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  CHECK(bagging_freq >= 0);
291
  GetDouble(params, "bagging_fraction", &bagging_fraction);
292
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
293
  GetDouble(params, "learning_rate", &learning_rate);
294
  CHECK(learning_rate > 0.0f);
wxchan's avatar
wxchan committed
295
296
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);
297
298
299
  GetInt(params, "metric_freq", &output_freq);
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
300
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
301
  GetInt(params, "drop_seed", &drop_seed);
302
303
  GetDouble(params, "drop_rate", &drop_rate);
  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
Guolin Ke's avatar
Guolin Ke committed
304
305
  GetTreeLearnerType(params);
  tree_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
306
307
}

Guolin Ke's avatar
Guolin Ke committed
308
void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
309
310
  std::string value;
  if (GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
311
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
312
313
314
315
316
317
318
319
    if (value == std::string("serial")) {
      tree_learner_type = TreeLearnerType::kSerialTreeLearner;
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
      tree_learner_type = TreeLearnerType::kFeatureParallelTreelearner;
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
      tree_learner_type = TreeLearnerType::kDataParallelTreeLearner;
    }
    else {
320
      Log::Fatal("Unknown tree learner type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
    }
  }
}

// Reads distributed-training network parameters and validates them.
// local_listen_port must be a valid TCP port (1..65535). A larger value cannot
// be bound and could be silently truncated when narrowed to a 16-bit port.
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);
  CHECK(local_listen_port <= 65535);
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);
  GetString(params, "machine_list_file", &machine_list_filename);
}

}  // namespace LightGBM