config.cpp 12.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <vector>
#include <string>
#include <unordered_map>
#include <algorithm>

namespace LightGBM {

13
14
15
16
17
18
19
20
21
22
23
24
25
void OverallConfig::LoadFromString(const char* str) {
  std::unordered_map<std::string, std::string> params;
  auto args = Common::Split(str, " \t\n\r");
  for (auto arg : args) {
    std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
      std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
      if (key.size() <= 0) {
        continue;
      }
      params[key] = value;
    } else {
Qiwei Ye's avatar
Qiwei Ye committed
26
      Log::Warning("Unknown parameter %s", arg.c_str());
27
28
29
30
31
32
    }
  }
  ParameterAlias::KeyAliasTransform(&params);
  Set(params);
}

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // load main config types
  GetInt(params, "num_threads", &num_threads);
  GetTaskType(params);
  GetBoostingType(params);
  GetObjectiveType(params);
  GetMetricType(params);


  // sub-config setup
  network_config.Set(params);
  io_config.Set(params);

Guolin Ke's avatar
Guolin Ke committed
46
  boosting_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
  objective_config.Set(params);
  metric_config.Set(params);
  // check for conflicts
  CheckParamConflict();
Qiwei Ye's avatar
Qiwei Ye committed
51

Guolin Ke's avatar
Guolin Ke committed
52
  if (io_config.verbosity == 1) {
Qiwei Ye's avatar
Qiwei Ye committed
53
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
54
55
  }
  else if (io_config.verbosity == 0) {
Qiwei Ye's avatar
Qiwei Ye committed
56
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
57
58
  }
  else if (io_config.verbosity >= 2) {
Qiwei Ye's avatar
Qiwei Ye committed
59
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
Guolin Ke's avatar
Guolin Ke committed
60
61
  }
  else {
Qiwei Ye's avatar
Qiwei Ye committed
62
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
Guolin Ke's avatar
Guolin Ke committed
63
  }
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
68
69
70
71
}

void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "boosting_type", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
      boosting_type = BoostingType::kGBDT;
72
73
    } else if (value == std::string("dart")) {
      boosting_type = BoostingType::kDART;
Guolin Ke's avatar
Guolin Ke committed
74
    } else {
75
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    }
  }
}

void OverallConfig::GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "objective", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    objective_type = value;
  }
}

// Reads "metric" as a comma-separated, case-insensitive list and replaces
// metric_types with its distinct entries.
// Entries are kept in the order they first appear in the input, so the
// resulting order is deterministic. If the key is absent, metric_types is left
// unchanged.
void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (!GetString(params, "metric", &value)) {
    return;
  }
  // Replace, rather than extend, any previously configured metrics.
  metric_types.clear();
  // Lowercase once up front; the split pieces inherit it.
  std::transform(value.begin(), value.end(), value.begin(), ::tolower);
  // Metric lists are short, so a linear scan over the kept entries is a cheap
  // way to drop duplicates while preserving first-occurrence order. Iterating
  // an unordered container instead would make the order hash-dependent.
  for (const auto& metric : Common::Split(value.c_str(), ',')) {
    if (std::find(metric_types.begin(), metric_types.end(), metric) == metric_types.end()) {
      metric_types.push_back(metric);
    }
  }
  metric_types.shrink_to_fit();
}


void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "task", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    if (value == std::string("train") || value == std::string("training")) {
      task_type = TaskType::kTrain;
    } else if (value == std::string("predict") || value == std::string("prediction")
      || value == std::string("test")) {
      task_type = TaskType::kPredict;
    } else {
124
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
    }
  }
}

void OverallConfig::CheckParamConflict() {
130

131
132
  // check if objective_type, metric_type, and num_class match
  bool objective_type_multiclass = (objective_type == std::string("multiclass"));
Guolin Ke's avatar
Guolin Ke committed
133
  int num_class_check = boosting_config.num_class;
134
135
  if (objective_type_multiclass){
      if (num_class_check <= 1){
136
          Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
137
138
139
140
      }
  }
  else {
      if (task_type == TaskType::kTrain && num_class_check != 1){
141
142
          Log::Fatal("Number of classes must be 1 for non-multiclass training");
      }
143
144
145
  }
  for (std::string metric_type : metric_types){
        bool metric_type_multiclass = ( metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
146
        if ((objective_type_multiclass && !metric_type_multiclass)
147
            || (!objective_type_multiclass && metric_type_multiclass)){
148
            Log::Fatal("Objective and metrics don't match");
149
        }
150
  }
151

Guolin Ke's avatar
Guolin Ke committed
152
153
154
155
  if (network_config.num_machines > 1) {
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
156
    boosting_config.tree_learner_type = TreeLearnerType::kSerialTreeLearner;
Guolin Ke's avatar
Guolin Ke committed
157
158
  }

Guolin Ke's avatar
Guolin Ke committed
159
  if (boosting_config.tree_learner_type == TreeLearnerType::kSerialTreeLearner) {
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
    is_parallel = false;
    network_config.num_machines = 1;
  }

Guolin Ke's avatar
Guolin Ke committed
164
165
  if (boosting_config.tree_learner_type == TreeLearnerType::kSerialTreeLearner ||
    boosting_config.tree_learner_type == TreeLearnerType::kFeatureParallelTreelearner) {
Guolin Ke's avatar
Guolin Ke committed
166
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
167
  } else if (boosting_config.tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) {
Guolin Ke's avatar
Guolin Ke committed
168
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
169
    if (boosting_config.tree_config.histogram_pool_size >= 0) {
170
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
Guolin Ke's avatar
Guolin Ke committed
171
                 , boosting_config.tree_config.histogram_pool_size);
172
      // Change pool size to -1 (not limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
173
      boosting_config.tree_config.histogram_pool_size = -1;
174
175
    }

Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
  }
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin > 0);
182
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
183
  GetInt(params, "data_random_seed", &data_random_seed);
184
  GetString(params, "data", &data_filename);
Qiwei Ye's avatar
Qiwei Ye committed
185
  GetInt(params, "verbose", &verbosity);
Guolin Ke's avatar
Guolin Ke committed
186
  GetInt(params, "num_model_predict", &num_model_predict);
Guolin Ke's avatar
Guolin Ke committed
187
  GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
  GetBool(params, "is_pre_partition", &is_pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
  GetBool(params, "use_two_round_loading", &use_two_round_loading);
  GetBool(params, "is_save_binary_file", &is_save_binary_file);
Guolin Ke's avatar
Guolin Ke committed
192
  GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
Guolin Ke's avatar
Guolin Ke committed
193
194
  GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
  GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
Guolin Ke's avatar
Guolin Ke committed
195
196
197
198
199
200
201
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
  GetString(params, "output_result", &output_result);
  std::string tmp_str = "";
  if (GetString(params, "valid_data", &tmp_str)) {
    valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
  }
Guolin Ke's avatar
Guolin Ke committed
202
203
204
205
206
  GetBool(params, "has_header", &has_header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
Guolin Ke's avatar
Guolin Ke committed
207
208
209
210
211
}


void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
212
  GetDouble(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
213
214
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
215
216
  GetInt(params, "num_class", &num_class);
  CHECK(num_class >= 1);
Guolin Ke's avatar
Guolin Ke committed
217
218
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
219
    label_gain = Common::StringToDoubleArray(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
220
221
222
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
223
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
224
    for (int i = 1; i < max_label; ++i) {
225
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
226
227
    }
  }
Guolin Ke's avatar
Guolin Ke committed
228
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
229
230
231
232
}


void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
233
  GetDouble(params, "sigmoid", &sigmoid);
234
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
235
236
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
237
    label_gain = Common::StringToDoubleArray(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
238
239
240
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
241
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
242
    for (int i = 1; i < max_label; ++i) {
243
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
244
245
    }
  }
Guolin Ke's avatar
Guolin Ke committed
246
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
247
248
249
250
251
252
253
254
255
256
257
258
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
    eval_at = Common::StringToIntArray(tmp_str, ',');
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
259
  eval_at.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
260
261
262
263
264
}


void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
265
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
266
  CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
267
268
269
270
271
272
  GetDouble(params, "lambda_l1", &lambda_l1);
  CHECK(lambda_l1 >= 0.0f)
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >= 0.0f)
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  CHECK(min_gain_to_split >= 0.0f)
Guolin Ke's avatar
Guolin Ke committed
273
  GetInt(params, "num_leaves", &num_leaves);
274
  CHECK(num_leaves > 1);
Guolin Ke's avatar
Guolin Ke committed
275
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
276
  GetDouble(params, "feature_fraction", &feature_fraction);
277
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
278
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
279
280
  GetInt(params, "max_depth", &max_depth);
  CHECK(max_depth > 1 || max_depth < 0);
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
}


void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_iterations", &num_iterations);
Guolin Ke's avatar
Guolin Ke committed
286
  GetDouble(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
287
288
289
290
  CHECK(num_iterations >= 0);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  CHECK(bagging_freq >= 0);
291
  GetDouble(params, "bagging_fraction", &bagging_fraction);
292
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
293
  GetDouble(params, "learning_rate", &learning_rate);
294
  CHECK(learning_rate > 0.0f);
wxchan's avatar
wxchan committed
295
296
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);
297
298
299
  GetInt(params, "metric_freq", &output_freq);
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
300
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
301
  GetInt(params, "drop_seed", &drop_seed);
302
303
  GetDouble(params, "drop_rate", &drop_rate);
  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
Guolin Ke's avatar
Guolin Ke committed
304
305
  GetTreeLearnerType(params);
  tree_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
306
307
}

Guolin Ke's avatar
Guolin Ke committed
308
void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
309
310
311
312
313
314
315
316
317
318
319
  std::string value;
  if (GetString(params, "tree_learner", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    if (value == std::string("serial")) {
      tree_learner_type = TreeLearnerType::kSerialTreeLearner;
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
      tree_learner_type = TreeLearnerType::kFeatureParallelTreelearner;
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
      tree_learner_type = TreeLearnerType::kDataParallelTreeLearner;
    }
    else {
320
      Log::Fatal("Unknown tree learner type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
    }
  }
}

// Reads the networking options from params.
// Keys that are absent leave the corresponding member unchanged. Each numeric
// value is range-checked immediately after it is read, and a failed CHECK is
// fatal.
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // Number of machines: at least one.
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);

  // Listening port: strictly positive.
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);

  // Timeout: strictly positive.
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);

  // Path of the machine list file (not validated here).
  GetString(params, "machine_list_file", &machine_list_filename);
}

}  // namespace LightGBM