config.cpp 10.6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <vector>
#include <string>
#include <unordered_map>
#include <algorithm>

namespace LightGBM {

13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/*!
 * \brief Parse a whitespace-separated list of "key=value" tokens and apply them.
 *
 * Tokens are separated by spaces, tabs, or line breaks. Each token is split on
 * '=', and its key and value are trimmed and stripped of quotation symbols.
 * Tokens with an empty key are skipped. Tokens that do not split into exactly
 * two '='-separated parts are reported through Log::Error and ignored. When a
 * key appears more than once, the last occurrence wins. Keys are passed
 * through ParameterAlias::KeyAliasTransform before the map is handed to Set().
 * \param str Null-terminated parameter string, e.g. "num_leaves=31 task=train"
 */
void OverallConfig::LoadFromString(const char* str) {
  std::unordered_map<std::string, std::string> params;
  const auto args = Common::Split(str, " \t\n\r");
  // Bind by const reference: iterating by value copied every token.
  for (const auto& arg : args) {
    std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
      std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
      if (key.empty()) {
        continue;
      }
      params[key] = value;
    } else {
      Log::Error("Unknown parameter %s", arg.c_str());
    }
  }
  ParameterAlias::KeyAliasTransform(&params);
  Set(params);
}

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // load main config types
  GetInt(params, "num_threads", &num_threads);
  GetTaskType(params);
wxchan's avatar
wxchan committed
37
38
  
  GetBool(params, "predict_leaf_index", &predict_leaf_index);
Guolin Ke's avatar
Guolin Ke committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

  GetBoostingType(params);
  GetObjectiveType(params);
  GetMetricType(params);

  // construct boosting configs
  if (boosting_type == BoostingType::kGBDT) {
    boosting_config = new GBDTConfig();
  }


  // sub-config setup
  network_config.Set(params);
  io_config.Set(params);

  boosting_config->Set(params);
  objective_config.Set(params);
  metric_config.Set(params);
  // check for conflicts
  CheckParamConflict();
Qiwei Ye's avatar
Qiwei Ye committed
59

Guolin Ke's avatar
Guolin Ke committed
60
  if (io_config.verbosity == 1) {
Qiwei Ye's avatar
Qiwei Ye committed
61
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
62
63
  }
  else if (io_config.verbosity == 0) {
Qiwei Ye's avatar
Qiwei Ye committed
64
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Error);
Guolin Ke's avatar
Guolin Ke committed
65
66
  }
  else if (io_config.verbosity >= 2) {
Qiwei Ye's avatar
Qiwei Ye committed
67
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
Guolin Ke's avatar
Guolin Ke committed
68
69
  }
  else {
Qiwei Ye's avatar
Qiwei Ye committed
70
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
Guolin Ke's avatar
Guolin Ke committed
71
  }
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
}

void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "boosting_type", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
      boosting_type = BoostingType::kGBDT;
    } else {
Qiwei Ye's avatar
Qiwei Ye committed
81
      Log::Fatal("Boosting type %s error", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    }
  }
}

/*!
 * \brief Store the "objective" parameter, lower-cased, into objective_type.
 *
 * The value is not validated here. If the key is absent, objective_type is
 * unchanged.
 */
void OverallConfig::GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
  std::string name;
  if (!GetString(params, "objective", &name)) {
    return;
  }
  std::transform(name.begin(), name.end(), name.begin(), ::tolower);
  objective_type = name;
}

/*!
 * \brief Parse the comma-separated "metric" parameter into metric_types.
 *
 * The value is lower-cased and split on ','. Duplicate names are dropped, and
 * the first occurrence of each name keeps its user-given position. When
 * "metric" is present, any previously configured metrics are replaced. When
 * it is absent, metric_types is left unchanged.
 */
void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "metric", &value)) {
    // replace, rather than extend, the old metric list
    metric_types.clear();
    // lower-case once here; no per-metric conversion is needed afterwards
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    const std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
    // De-duplicate while preserving the user's order. The previous
    // unordered_map-based version emitted metrics in unspecified hash order.
    // A linear search is fine for the handful of metrics a user lists.
    for (const auto& metric : metrics) {
      if (std::find(metric_types.begin(), metric_types.end(), metric) == metric_types.end()) {
        metric_types.push_back(metric);
      }
    }
  }
}


void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "task", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    if (value == std::string("train") || value == std::string("training")) {
      task_type = TaskType::kTrain;
    } else if (value == std::string("predict") || value == std::string("prediction")
      || value == std::string("test")) {
      task_type = TaskType::kPredict;
    } else {
Qiwei Ye's avatar
Qiwei Ye committed
129
      Log::Fatal("Task type error");
Guolin Ke's avatar
Guolin Ke committed
130
131
132
133
134
    }
  }
}

void OverallConfig::CheckParamConflict() {
135
  GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config);
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
  if (network_config.num_machines > 1) {
    is_parallel = true;
  } else {
    is_parallel = false;
140
    gbdt_config->tree_learner_type = TreeLearnerType::kSerialTreeLearner;
Guolin Ke's avatar
Guolin Ke committed
141
142
  }

143
  if (gbdt_config->tree_learner_type == TreeLearnerType::kSerialTreeLearner) {
Guolin Ke's avatar
Guolin Ke committed
144
145
146
147
    is_parallel = false;
    network_config.num_machines = 1;
  }

148
149
  if (gbdt_config->tree_learner_type == TreeLearnerType::kSerialTreeLearner ||
    gbdt_config->tree_learner_type == TreeLearnerType::kFeatureParallelTreelearner) {
Guolin Ke's avatar
Guolin Ke committed
150
    is_parallel_find_bin = false;
151
  } else if (gbdt_config->tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) {
Guolin Ke's avatar
Guolin Ke committed
152
    is_parallel_find_bin = true;
153
    if (gbdt_config->tree_config.histogram_pool_size >= 0) {
Guolin Ke's avatar
Guolin Ke committed
154
      Log::Error("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this for reducing communication cost."
Guolin Ke's avatar
Guolin Ke committed
155
                 , gbdt_config->tree_config.histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
156
      // Change pool size to -1(not limit) when using data parallel for reducing communication cost
157
158
159
      gbdt_config->tree_config.histogram_pool_size = -1;
    }

Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
164
165
166
167
168
  }
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin > 0);
  GetInt(params, "data_random_seed", &data_random_seed);

  if (!GetString(params, "data", &data_filename)) {
Qiwei Ye's avatar
Qiwei Ye committed
169
    Log::Fatal("No training/prediction data, application quit");
Guolin Ke's avatar
Guolin Ke committed
170
  }
Qiwei Ye's avatar
Qiwei Ye committed
171
  GetInt(params, "verbose", &verbosity);
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185
  GetInt(params, "num_model_predict", &num_model_predict);
  GetBool(params, "is_pre_partition", &is_pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
  GetBool(params, "use_two_round_loading", &use_two_round_loading);
  GetBool(params, "is_save_binary_file", &is_save_binary_file);
  GetBool(params, "is_sigmoid", &is_sigmoid);
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
  GetString(params, "output_result", &output_result);
  GetString(params, "input_init_score", &input_init_score);
  std::string tmp_str = "";
  if (GetString(params, "valid_data", &tmp_str)) {
    valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
  }
Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
190
  GetBool(params, "has_header", &has_header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
195
}


void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
196
  GetFloat(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
197
198
199
200
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
201
    label_gain = Common::StringToFloatArray(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
202
203
204
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
205
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
206
    for (int i = 1; i < max_label; ++i) {
207
      label_gain.push_back(static_cast<float>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
208
209
210
211
212
213
    }
  }
}


void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
214
  GetFloat(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
215
216
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
217
    label_gain = Common::StringToFloatArray(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
218
219
220
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
221
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
222
    for (int i = 1; i < max_label; ++i) {
223
      label_gain.push_back(static_cast<float>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
    }
  }
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
    eval_at = Common::StringToIntArray(tmp_str, ',');
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
}


void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
243
  GetFloat(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
244
  CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
Guolin Ke's avatar
Guolin Ke committed
245
  GetInt(params, "num_leaves", &num_leaves);
246
  CHECK(num_leaves > 1);
Guolin Ke's avatar
Guolin Ke committed
247
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
248
249
250
  GetFloat(params, "feature_fraction", &feature_fraction);
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
  GetFloat(params, "histogram_pool_size", &histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
251
252
  GetInt(params, "max_depth", &max_depth);
  CHECK(max_depth > 1 || max_depth < 0);
Guolin Ke's avatar
Guolin Ke committed
253
254
255
256
257
258
259
260
261
}


/*!
 * \brief Read options shared by all boosting algorithms and validate them.
 *
 * "metric_freq" is stored in output_freq, and "is_training_metric" is stored
 * in is_provide_training_metric. Read/check order is fixed.
 */
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // iteration count
  GetInt(params, "num_iterations", &num_iterations);
  CHECK(num_iterations >= 0);

  // bagging (row sub-sampling)
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  CHECK(bagging_freq >= 0);
  GetFloat(params, "bagging_fraction", &bagging_fraction);
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);

  // step size
  GetFloat(params, "learning_rate", &learning_rate);
  CHECK(learning_rate > 0.0f);

  // early stopping
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);

  // metric reporting
  GetInt(params, "metric_freq", &output_freq);
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
}

void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
  std::string value;
  if (GetString(params, "tree_learner", &value)) {
    std::transform(value.begin(), value.end(), value.begin(), ::tolower);
    if (value == std::string("serial")) {
      tree_learner_type = TreeLearnerType::kSerialTreeLearner;
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
      tree_learner_type = TreeLearnerType::kFeatureParallelTreelearner;
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
      tree_learner_type = TreeLearnerType::kDataParallelTreeLearner;
    }
    else {
Qiwei Ye's avatar
Qiwei Ye committed
285
      Log::Fatal("Tree learner type error");
Guolin Ke's avatar
Guolin Ke committed
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    }
  }
}

void GBDTConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  BoostingConfig::Set(params);
  GetTreeLearnerType(params);
  tree_config.Set(params);
}

/*!
 * \brief Read distributed-training options and validate them.
 *
 * num_machines must be >= 1, and local_listen_port and time_out must be > 0.
 * machine_list_file is read as-is without validation.
 */
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // cluster size
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);

  // socket settings
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);

  // machine list path
  GetString(params, "machine_list_file", &machine_list_filename);
}

}  // namespace LightGBM