config.cpp 19.3 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
4
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/log.h>

#include <vector>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
9
#include <unordered_set>
Guolin Ke's avatar
Guolin Ke committed
10
#include <algorithm>
Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

wxchan's avatar
wxchan committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// Parses a single "key=value" token and records it in `params`.
// The token is split on '='; both halves are passed through Common::Trim and
// Common::RemoveQuotationSymbol. A key that is already present keeps its first
// value and a warning is logged for the ignored one. Tokens that do not split
// into exactly two pieces are reported as unknown; empty keys are skipped.
void ConfigBase::KV2Map(std::unordered_map<std::string, std::string>& params, const char* kv) {
  const std::vector<std::string> parts = Common::Split(kv, '=');
  if (parts.size() != 2) {
    Log::Warning("Unknown parameter %s", kv);
    return;
  }

  const std::string key = Common::RemoveQuotationSymbol(Common::Trim(parts[0]));
  const std::string value = Common::RemoveQuotationSymbol(Common::Trim(parts[1]));
  if (key.empty()) {
    return;
  }

  const auto existing = params.find(key);
  if (existing != params.end()) {
    // first assignment wins; tell the user which value is being dropped
    Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s.",
      key.c_str(), existing->second.c_str(), key.c_str(), value.c_str(),
      key.c_str(), existing->second.c_str());
    return;
  }
  params.emplace(key, value);
}

35
std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* parameters) {
36
  std::unordered_map<std::string, std::string> params;
37
  auto args = Common::Split(parameters, " \t\n\r");
38
  for (auto arg : args) {
wxchan's avatar
wxchan committed
39
    KV2Map(params, Common::Trim(arg).c_str());
40
41
  }
  ParameterAlias::KeyAliasTransform(&params);
42
  return params;
43
44
}

wxchan's avatar
wxchan committed
45
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting_type) {
Guolin Ke's avatar
Guolin Ke committed
46
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
47
  if (ConfigBase::GetString(params, "boosting_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
48
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
49
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
wxchan's avatar
wxchan committed
50
      *boosting_type = "gbdt";
51
    } else if (value == std::string("dart")) {
wxchan's avatar
wxchan committed
52
      *boosting_type = "dart";
Guolin Ke's avatar
Guolin Ke committed
53
    } else if (value == std::string("goss")) {
wxchan's avatar
wxchan committed
54
      *boosting_type = "goss";
Guolin Ke's avatar
Guolin Ke committed
55
    } else if (value == std::string("rf") || value == std::string("randomforest")) {
wxchan's avatar
wxchan committed
56
      *boosting_type = "rf";
Guolin Ke's avatar
Guolin Ke committed
57
    } else {
58
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
59
60
61
62
    }
  }
}

wxchan's avatar
wxchan committed
63
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective_type) {
Guolin Ke's avatar
Guolin Ke committed
64
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
65
  if (ConfigBase::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
66
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
wxchan's avatar
wxchan committed
67
    *objective_type = value;
Guolin Ke's avatar
Guolin Ke committed
68
69
70
  }
}

wxchan's avatar
wxchan committed
71
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric_types) {
Guolin Ke's avatar
Guolin Ke committed
72
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
73
  if (ConfigBase::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
74
    // clear old metrics
wxchan's avatar
wxchan committed
75
    metric_types->clear();
Guolin Ke's avatar
Guolin Ke committed
76
    // to lower
Guolin Ke's avatar
Guolin Ke committed
77
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
78
79
    // split
    std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
80
    // remove duplicate
Guolin Ke's avatar
Guolin Ke committed
81
    std::unordered_set<std::string> metric_sets;
Guolin Ke's avatar
Guolin Ke committed
82
    for (auto& metric : metrics) {
Guolin Ke's avatar
Guolin Ke committed
83
      std::transform(metric.begin(), metric.end(), metric.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
84
85
      if (metric_sets.count(metric) <= 0) {
        metric_sets.insert(metric);
Guolin Ke's avatar
Guolin Ke committed
86
87
      }
    }
Guolin Ke's avatar
Guolin Ke committed
88
    for (auto& metric : metric_sets) {
wxchan's avatar
wxchan committed
89
      metric_types->push_back(metric);
Guolin Ke's avatar
Guolin Ke committed
90
    }
wxchan's avatar
wxchan committed
91
    metric_types->shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
92
  }
93
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
94
  if (metric_types->empty() && value.size() == 0) {
95
96
97
98
99
    if (ConfigBase::GetString(params, "objective", &value)) {
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
      metric_types->push_back(value);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
100
101
}

wxchan's avatar
wxchan committed
102
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
Guolin Ke's avatar
Guolin Ke committed
103
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
104
  if (ConfigBase::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
105
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
106
    if (value == std::string("train") || value == std::string("training")) {
wxchan's avatar
wxchan committed
107
      *task_type = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
108
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
109
               || value == std::string("test")) {
wxchan's avatar
wxchan committed
110
      *task_type = TaskType::kPredict;
111
    } else if (value == std::string("convert_model")) {
wxchan's avatar
wxchan committed
112
      *task_type = TaskType::kConvertModel;
113
114
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
      *task_type = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
115
    } else {
116
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
117
118
119
120
    }
  }
}

wxchan's avatar
wxchan committed
121
// Reads "device" from `params`, lower-cases it and writes "cpu" or "gpu" to
// *device_type. The output is left untouched when GetString reports no value;
// any other value is fatal.
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
  std::string requested;
  if (!ConfigBase::GetString(params, "device", &requested)) {
    return;
  }
  std::transform(requested.begin(), requested.end(), requested.begin(), Common::tolower);

  if (requested == "cpu") {
    *device_type = "cpu";
  } else if (requested == "gpu") {
    *device_type = "gpu";
  } else {
    Log::Fatal("Unknown device type %s", requested.c_str());
  }
}

wxchan's avatar
wxchan committed
135
// Reads "tree_learner" from `params`, lower-cases it and writes the canonical
// learner name ("serial", "feature", "data" or "voting") to *tree_learner_type.
// The "*_parallel" spellings are accepted as aliases. The output is left
// untouched when GetString reports no value; an unrecognised value is fatal.
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner_type) {
  std::string requested;
  if (!ConfigBase::GetString(params, "tree_learner", &requested)) {
    return;
  }
  std::transform(requested.begin(), requested.end(), requested.begin(), Common::tolower);

  // accepted spelling -> canonical learner name
  static const struct {
    const char* alias;
    const char* canonical;
  } kLearnerAliases[] = {
    {"serial", "serial"},
    {"feature", "feature"},
    {"feature_parallel", "feature"},
    {"data", "data"},
    {"data_parallel", "data"},
    {"voting", "voting"},
    {"voting_parallel", "voting"},
  };
  for (const auto& entry : kLearnerAliases) {
    if (requested == entry.alias) {
      *tree_learner_type = entry.canonical;
      return;
    }
  }
  Log::Fatal("Unknown tree learner type %s", requested.c_str());
}

void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // load main config types
  GetInt(params, "num_threads", &num_threads);
  GetString(params, "convert_model_language", &convert_model_language);

  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
    int int_max = std::numeric_limits<short>::max();
    io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
  }
wxchan's avatar
wxchan committed
167
168
  GetTaskType(params, &task_type);
  GetBoostingType(params, &boosting_type);
Guolin Ke's avatar
Guolin Ke committed
169

wxchan's avatar
wxchan committed
170
  GetMetricType(params, &metric_types);
Guolin Ke's avatar
Guolin Ke committed
171
172
173
174
175
176

  // sub-config setup
  network_config.Set(params);
  io_config.Set(params);

  boosting_config.Set(params);
wxchan's avatar
wxchan committed
177
  GetObjectiveType(params, &objective_type);
Guolin Ke's avatar
Guolin Ke committed
178
179
180
181
182
  objective_config.Set(params);
  metric_config.Set(params);

  // check for conflicts
  CheckParamConflict();
183

Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
188
189
190
191
192
193
194
  if (io_config.verbosity == 1) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
  } else if (io_config.verbosity == 0) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
  } else if (io_config.verbosity >= 2) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

195
196
197
198
199
200
201
202
203
// Returns true when `objective_type` is exactly one of the multiclass names
// listed below (the comparison is case-sensitive).
bool CheckMultiClassObjective(const std::string& objective_type) {
  static const char* const kMulticlassNames[] = {
    "multiclass", "multiclassova", "softmax", "multiclass_ova", "ova", "ovr",
  };
  for (const char* name : kMulticlassNames) {
    if (objective_type == name) {
      return true;
    }
  }
  return false;
}

Guolin Ke's avatar
Guolin Ke committed
204
void OverallConfig::CheckParamConflict() {
205
  // check if objective_type, metric_type, and num_class match
Guolin Ke's avatar
Guolin Ke committed
206
  int num_class_check = boosting_config.num_class;
207
208
  bool objective_custom = objective_type == std::string("none") || objective_type == std::string("null") || objective_type == std::string("custom");
  bool objective_type_multiclass = CheckMultiClassObjective(objective_type) || (objective_custom && num_class_check > 1);
209
  
210
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
211
212
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
213
214
215
216
217
    }
  } else {
    if (task_type == TaskType::kTrain && num_class_check != 1) {
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
218
  }
wxchan's avatar
wxchan committed
219
220
  if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
    for (std::string metric_type : metric_types) {
221
222
      bool metric_type_multiclass = (CheckMultiClassObjective(metric_type) 
                                     || metric_type == std::string("multi_logloss")
223
                                     || metric_type == std::string("multi_error"));
wxchan's avatar
wxchan committed
224
225
226
227
      if ((objective_type_multiclass && !metric_type_multiclass)
        || (!objective_type_multiclass && metric_type_multiclass)) {
        Log::Fatal("Objective and metrics don't match");
      }
228
    }
229
  }
230

Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
  if (network_config.num_machines > 1) {
    is_parallel = true;
  } else {
    is_parallel = false;
235
    boosting_config.tree_learner_type = "serial";
Guolin Ke's avatar
Guolin Ke committed
236
237
  }

Guolin Ke's avatar
Guolin Ke committed
238
239
240
  bool is_single_tree_learner = boosting_config.tree_learner_type == std::string("serial");

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
241
242
243
244
    is_parallel = false;
    network_config.num_machines = 1;
  }

Guolin Ke's avatar
Guolin Ke committed
245
  if (is_single_tree_learner || boosting_config.tree_learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
246
    is_parallel_find_bin = false;
247
248
  } else if (boosting_config.tree_learner_type == std::string("data")
             || boosting_config.tree_learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
249
    is_parallel_find_bin = true;
250
    if (boosting_config.tree_config.histogram_pool_size >= 0
251
        && boosting_config.tree_learner_type == std::string("data")) {
252
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
253
        , boosting_config.tree_config.histogram_pool_size);
tks's avatar
tks committed
254
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
255
      boosting_config.tree_config.histogram_pool_size = -1;
256
    }
Guolin Ke's avatar
Guolin Ke committed
257
  }
258
259
  // Check max_depth and num_leaves
  if (boosting_config.tree_config.max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
260
    int full_num_leaves = static_cast<int>(std::pow(2, boosting_config.tree_config.max_depth));
261
262
    if (full_num_leaves > boosting_config.tree_config.num_leaves 
        && boosting_config.tree_config.num_leaves == kDefaultNumLeaves) {
Nikita Titov's avatar
Nikita Titov committed
263
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves.");
264
265
    }
  }
Guolin Ke's avatar
Guolin Ke committed
266
267
268
269
270
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin > 0);
271
  GetInt(params, "num_class", &num_class);
272
  CHECK(num_class > 0);
Guolin Ke's avatar
Guolin Ke committed
273
  GetInt(params, "data_random_seed", &data_random_seed);
274
  GetString(params, "data", &data_filename);
275
  GetString(params, "init_score_file", &initscore_filename);
Qiwei Ye's avatar
Qiwei Ye committed
276
  GetInt(params, "verbose", &verbosity);
Guolin Ke's avatar
Guolin Ke committed
277
  GetInt(params, "num_iteration_predict", &num_iteration_predict);
Guolin Ke's avatar
Guolin Ke committed
278
  GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
279
  CHECK(bin_construct_sample_cnt > 0);
Guolin Ke's avatar
Guolin Ke committed
280
281
  GetBool(params, "is_pre_partition", &is_pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
282
  GetDouble(params, "sparse_threshold", &sparse_threshold);
Guolin Ke's avatar
Guolin Ke committed
283
284
  GetBool(params, "use_two_round_loading", &use_two_round_loading);
  GetBool(params, "is_save_binary_file", &is_save_binary_file);
Guolin Ke's avatar
Guolin Ke committed
285
  GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
Guolin Ke's avatar
Guolin Ke committed
286
287
  GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
  GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
288
  GetBool(params, "is_predict_contrib", &is_predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
289
  GetInt(params, "snapshot_freq", &snapshot_freq);
Guolin Ke's avatar
Guolin Ke committed
290
291
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
292
  GetString(params, "convert_model", &convert_model);
Guolin Ke's avatar
Guolin Ke committed
293
294
  GetString(params, "output_result", &output_result);
  std::string tmp_str = "";
Guolin Ke's avatar
Guolin Ke committed
295
296
297
  if (GetString(params, "monotone_constraints", &tmp_str)) {
    monotone_constraints = Common::StringToArray<int8_t>(tmp_str.c_str(), ',');
  }
Guolin Ke's avatar
Guolin Ke committed
298
299
300
  if (GetString(params, "valid_data", &tmp_str)) {
    valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
  }
301
302
303
304
305
306
  if (GetString(params, "valid_init_score_file", &tmp_str)) {
    valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
  } else {
    valid_data_initscores = std::vector<std::string>(valid_data_filenames.size(), "");
  }
  CHECK(valid_data_filenames.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
307
308
309
310
311
  GetBool(params, "has_header", &has_header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
312
  GetString(params, "categorical_column", &categorical_column);
Guolin Ke's avatar
Guolin Ke committed
313
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
314
315
  GetInt(params, "min_data_in_bin", &min_data_in_bin);
  CHECK(min_data_in_bin > 0);
316
  CHECK(min_data_in_leaf >= 0);
Guolin Ke's avatar
Guolin Ke committed
317
  GetDouble(params, "max_conflict_rate", &max_conflict_rate);
318
  CHECK(max_conflict_rate >= 0);
Guolin Ke's avatar
Guolin Ke committed
319
  GetBool(params, "enable_bundle", &enable_bundle);
320
321
322
  GetBool(params, "pred_early_stop", &pred_early_stop);
  GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
  GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
323
324
  GetBool(params, "use_missing", &use_missing);
  GetBool(params, "zero_as_missing", &zero_as_missing);
wxchan's avatar
wxchan committed
325
  GetDeviceType(params, &device_type);
Guolin Ke's avatar
Guolin Ke committed
326
}
Guolin Ke's avatar
Guolin Ke committed
327
328
329

void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
330
  GetDouble(params, "sigmoid", &sigmoid);
331
  CHECK(sigmoid > 0);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
332
  GetDouble(params, "fair_c", &fair_c);
333
  CHECK(fair_c > 0);
334
  GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
335
  CHECK(poisson_max_delta_step > 0);
Guolin Ke's avatar
Guolin Ke committed
336
337
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
338
  GetInt(params, "num_class", &num_class);
339
  CHECK(num_class > 0);
Guolin Ke's avatar
Guolin Ke committed
340
  GetDouble(params, "scale_pos_weight", &scale_pos_weight);
341
  CHECK(scale_pos_weight > 0);
342
343
  GetDouble(params, "alpha", &alpha);
  GetBool(params, "reg_sqrt", &reg_sqrt);
Guolin Ke's avatar
Guolin Ke committed
344
345
  GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
  CHECK(tweedie_variance_power >= 1 && tweedie_variance_power < 2);
Guolin Ke's avatar
Guolin Ke committed
346
347
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
348
    label_gain = Common::StringToArray<double>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
349
350
351
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
352
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
353
    for (int i = 1; i < max_label; ++i) {
354
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
355
356
    }
  }
Guolin Ke's avatar
Guolin Ke committed
357
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
358
359
360
361
}


void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
362
  GetDouble(params, "sigmoid", &sigmoid);
363
  CHECK(sigmoid > 0);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
364
  GetDouble(params, "fair_c", &fair_c);
365
  CHECK(fair_c > 0);
366
  GetInt(params, "num_class", &num_class);
367
  CHECK(num_class > 0);
368
  GetDouble(params, "alpha", &alpha);
Guolin Ke's avatar
Guolin Ke committed
369
370
  GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
  CHECK(tweedie_variance_power >= 1 && tweedie_variance_power < 2);
Guolin Ke's avatar
Guolin Ke committed
371
372
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
373
    label_gain = Common::StringToArray<double>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
374
375
376
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
377
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
378
    for (int i = 1; i < max_label; ++i) {
379
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
380
381
    }
  }
Guolin Ke's avatar
Guolin Ke committed
382
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
383
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
384
    eval_at = Common::StringToArray<int>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
389
390
391
392
393
394
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
395
  eval_at.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
396
397
398
399
400
}


void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
401
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
402
403
  CHECK(min_data_in_leaf > 0);
  CHECK(min_sum_hessian_in_leaf >= 0);
404
  GetDouble(params, "lambda_l1", &lambda_l1);
405
406
407
408
409
410
  CHECK(lambda_l1 >= 0.0f);
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >= 0.0f);
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  CHECK(min_gain_to_split >= 0.0f);
  GetInt(params, "num_leaves", &num_leaves);
411
  CHECK(num_leaves > 1);
Guolin Ke's avatar
Guolin Ke committed
412
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
413
  GetDouble(params, "feature_fraction", &feature_fraction);
414
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
415
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
416
  GetInt(params, "max_depth", &max_depth);
Guolin Ke's avatar
Guolin Ke committed
417
  GetInt(params, "top_k", &top_k);
418
  CHECK(top_k > 0);
419
420
421
  GetInt(params, "gpu_platform_id", &gpu_platform_id);
  GetInt(params, "gpu_device_id", &gpu_device_id);
  GetBool(params, "gpu_use_dp", &gpu_use_dp);
422
  GetInt(params, "max_cat_threshold", &max_cat_threshold);
Guolin Ke's avatar
Guolin Ke committed
423
  GetDouble(params, "cat_l2", &cat_l2);
424
  GetDouble(params, "cat_smooth", &cat_smooth);
ChenZhiyong's avatar
ChenZhiyong committed
425
  GetInt(params, "min_data_per_group", &min_data_per_group);
426
  GetInt(params, "max_cat_to_onehot", &max_cat_to_onehot);
ChenZhiyong's avatar
ChenZhiyong committed
427
  CHECK(max_cat_threshold > 0);
Guolin Ke's avatar
Guolin Ke committed
428
  CHECK(cat_l2 >= 0.0f);
429
  CHECK(cat_smooth >= 1);
ChenZhiyong's avatar
ChenZhiyong committed
430
  CHECK(min_data_per_group > 0);
431
  CHECK(max_cat_to_onehot > 0);
Guolin Ke's avatar
Guolin Ke committed
432
433
434
435
436
437
438
439
}

void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_iterations", &num_iterations);
  CHECK(num_iterations >= 0);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  CHECK(bagging_freq >= 0);
440
  GetDouble(params, "bagging_fraction", &bagging_fraction);
441
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
442
  GetDouble(params, "learning_rate", &learning_rate);
443
  CHECK(learning_rate > 0.0f);
wxchan's avatar
wxchan committed
444
445
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);
Guolin Ke's avatar
Guolin Ke committed
446
  GetInt(params, "output_freq", &output_freq);
447
448
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
449
  GetInt(params, "num_class", &num_class);
450
  CHECK(num_class > 0);
Guolin Ke's avatar
Guolin Ke committed
451
  GetInt(params, "drop_seed", &drop_seed);
452
  GetDouble(params, "drop_rate", &drop_rate);
453
  GetDouble(params, "skip_drop", &skip_drop);
454
455
  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
  CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
456
  GetInt(params, "max_drop", &max_drop);
457
  CHECK(max_drop > 0);
458
459
  GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
  GetBool(params, "uniform_drop", &uniform_drop);
Guolin Ke's avatar
Guolin Ke committed
460
461
  GetDouble(params, "top_rate", &top_rate);
  GetDouble(params, "other_rate", &other_rate);
462
  CHECK(top_rate > 0);
Nikita Titov's avatar
Nikita Titov committed
463
464
  CHECK(other_rate > 0);
  CHECK(top_rate + other_rate <= 1.0);
465
  GetBool(params, "boost_from_average", &boost_from_average);
wxchan's avatar
wxchan committed
466
467
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner_type);
Guolin Ke's avatar
Guolin Ke committed
468
  tree_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
469
470
471
472
473
474
475
476
477
478
}

void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);
  GetString(params, "machine_list_file", &machine_list_filename);
479
  GetString(params, "machines", &machines);
Guolin Ke's avatar
Guolin Ke committed
480
481
482
}

}  // namespace LightGBM