config.cpp 16.9 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
4
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/log.h>

#include <vector>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
9
#include <unordered_set>
Guolin Ke's avatar
Guolin Ke committed
10
#include <algorithm>
Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

15
std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* parameters) {
16
  std::unordered_map<std::string, std::string> params;
17
  auto args = Common::Split(parameters, " \t\n\r");
18
19
20
21
22
23
24
25
26
  for (auto arg : args) {
    std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
      std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
      if (key.size() <= 0) {
        continue;
      }
      params[key] = value;
27
    } else if (Common::Trim(arg).size() > 0) {
Qiwei Ye's avatar
Qiwei Ye committed
28
      Log::Warning("Unknown parameter %s", arg.c_str());
29
30
31
    }
  }
  ParameterAlias::KeyAliasTransform(&params);
32
  return params;
33
34
}

Guolin Ke's avatar
Guolin Ke committed
35
std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
36
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
37
38
  std::string boosting_type = kDefaultBoostingType;
  if (ConfigBase::GetString(params, "boosting_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
39
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
40
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
41
      boosting_type = "gbdt";
42
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
43
      boosting_type = "dart";
Guolin Ke's avatar
Guolin Ke committed
44
45
    } else if (value == std::string("goss")) {
      boosting_type = "goss";
Guolin Ke's avatar
Guolin Ke committed
46
47
    } else if (value == std::string("rf") || value == std::string("randomforest")) {
      boosting_type = "rf";
Guolin Ke's avatar
Guolin Ke committed
48
    } else {
49
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
50
51
    }
  }
Guolin Ke's avatar
Guolin Ke committed
52
  return boosting_type;
Guolin Ke's avatar
Guolin Ke committed
53
54
}

Guolin Ke's avatar
Guolin Ke committed
55
std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
56
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
57
58
  std::string objective_type = kDefaultObjectiveType;
  if (ConfigBase::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
59
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
60
61
    objective_type = value;
  }
Guolin Ke's avatar
Guolin Ke committed
62
  return objective_type;
Guolin Ke's avatar
Guolin Ke committed
63
64
}

Guolin Ke's avatar
Guolin Ke committed
65
std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
66
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
67
68
  std::vector<std::string> metric_types;
  if (ConfigBase::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
69
70
71
    // clear old metrics
    metric_types.clear();
    // to lower
Guolin Ke's avatar
Guolin Ke committed
72
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
73
74
    // split
    std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
75
    // remove duplicate
Guolin Ke's avatar
Guolin Ke committed
76
    std::unordered_set<std::string> metric_sets;
Guolin Ke's avatar
Guolin Ke committed
77
    for (auto& metric : metrics) {
Guolin Ke's avatar
Guolin Ke committed
78
      std::transform(metric.begin(), metric.end(), metric.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
79
80
      if (metric_sets.count(metric) <= 0) {
        metric_sets.insert(metric);
Guolin Ke's avatar
Guolin Ke committed
81
82
      }
    }
Guolin Ke's avatar
Guolin Ke committed
83
84
    for (auto& metric : metric_sets) {
      metric_types.push_back(metric);
Guolin Ke's avatar
Guolin Ke committed
85
    }
Guolin Ke's avatar
Guolin Ke committed
86
    metric_types.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
87
  }
Guolin Ke's avatar
Guolin Ke committed
88
  return metric_types;
Guolin Ke's avatar
Guolin Ke committed
89
90
}

Guolin Ke's avatar
Guolin Ke committed
91
TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
92
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
93
94
  TaskType task_type = TaskType::kTrain;
  if (ConfigBase::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
95
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
96
97
98
    if (value == std::string("train") || value == std::string("training")) {
      task_type = TaskType::kTrain;
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
99
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
100
      task_type = TaskType::kPredict;
101
102
    } else if (value == std::string("convert_model")) {
      task_type = TaskType::kConvertModel;
Guolin Ke's avatar
Guolin Ke committed
103
    } else {
104
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
105
106
    }
  }
Guolin Ke's avatar
Guolin Ke committed
107
  return task_type;
Guolin Ke's avatar
Guolin Ke committed
108
109
}

Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
// Reads the "device" parameter (case-insensitive). Only "cpu" and "gpu" are
// accepted; they are returned in lower case. When the parameter is missing the
// default device is returned. Unrecognized values go to Log::Fatal.
std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
  std::string requested;
  if (!ConfigBase::GetString(params, "device", &requested)) {
    return kDefaultDevice;
  }
  std::transform(requested.begin(), requested.end(), requested.begin(), Common::tolower);
  const bool is_known = (requested == std::string("cpu")) || (requested == std::string("gpu"));
  if (is_known) {
    // the accepted names are already canonical once lower-cased
    return requested;
  }
  Log::Fatal("Unknown device type %s", requested.c_str());
  return kDefaultDevice;
}

// Reads the "tree_learner" parameter (case-insensitive) and maps every accepted
// spelling to its canonical learner name ("serial", "feature", "data" or
// "voting"). When the parameter is missing the default learner is returned.
// Unrecognized values go to Log::Fatal.
std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
  std::string requested;
  if (!ConfigBase::GetString(params, "tree_learner", &requested)) {
    return kDefaultTreeLearnerType;
  }
  std::transform(requested.begin(), requested.end(), requested.begin(), Common::tolower);

  // accepted spelling -> canonical learner name
  struct LearnerAlias {
    const char* spelling;
    const char* canonical;
  };
  static const LearnerAlias kLearnerAliases[] = {
    {"serial", "serial"},
    {"feature", "feature"}, {"feature_parallel", "feature"},
    {"data", "data"},       {"data_parallel", "data"},
    {"voting", "voting"},   {"voting_parallel", "voting"},
  };
  for (const auto& alias : kLearnerAliases) {
    if (requested == alias.spelling) {
      return alias.canonical;
    }
  }
  Log::Fatal("Unknown tree learner type %s", requested.c_str());
  return kDefaultTreeLearnerType;
}

// Populates the overall configuration and every sub-config from `params`,
// validates cross-config consistency via CheckParamConflict(), and finally sets
// the global log level from io_config.verbosity.
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // load main config types
  GetInt(params, "num_threads", &num_threads);
  GetString(params, "convert_model_language", &convert_model_language);

  // Derive the per-component seeds from "seed". This runs before the
  // sub-configs are loaded, so a seed given explicitly for a component
  // (e.g. "bagging_seed") is read afterwards and overrides the derived one.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
    int short_max = std::numeric_limits<short>::max();
    io_config.data_random_seed = static_cast<int>(rand.NextShort(0, short_max));
    boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, short_max));
    boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, short_max));
    boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, short_max));
  }
  task_type = GetTaskType(params);
  boosting_type = GetBoostingType(params);

  metric_types = GetMetricType(params);

  // sub-config setup
  network_config.Set(params);
  io_config.Set(params);

  boosting_config.Set(params);
  objective_type = GetObjectiveType(params);
  objective_config.Set(params);
  metric_config.Set(params);

  // check for conflicts
  CheckParamConflict();

  // verbosity: < 0 -> fatal only, 0 -> warnings, 1 -> info, >= 2 -> debug
  if (io_config.verbosity == 1) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
  } else if (io_config.verbosity == 0) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
  } else if (io_config.verbosity >= 2) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

void OverallConfig::CheckParamConflict() {
189
  // check if objective_type, metric_type, and num_class match
190
  bool objective_type_multiclass = (objective_type == std::string("multiclass")
Guolin Ke's avatar
Guolin Ke committed
191
                                    || objective_type == std::string("multiclassova"));
Guolin Ke's avatar
Guolin Ke committed
192
  int num_class_check = boosting_config.num_class;
193
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
194
195
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
196
197
198
199
200
    }
  } else {
    if (task_type == TaskType::kTrain && num_class_check != 1) {
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
201
  }
wxchan's avatar
wxchan committed
202
203
  if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
    for (std::string metric_type : metric_types) {
204
      bool metric_type_multiclass = (metric_type == std::string("multi_logloss")
205
                                     || metric_type == std::string("multi_error"));
wxchan's avatar
wxchan committed
206
207
208
209
      if ((objective_type_multiclass && !metric_type_multiclass)
        || (!objective_type_multiclass && metric_type_multiclass)) {
        Log::Fatal("Objective and metrics don't match");
      }
210
    }
211
  }
212

Guolin Ke's avatar
Guolin Ke committed
213
214
215
216
  if (network_config.num_machines > 1) {
    is_parallel = true;
  } else {
    is_parallel = false;
217
    boosting_config.tree_learner_type = "serial";
Guolin Ke's avatar
Guolin Ke committed
218
219
  }

Guolin Ke's avatar
Guolin Ke committed
220
221
222
  bool is_single_tree_learner = boosting_config.tree_learner_type == std::string("serial");

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
223
224
225
226
    is_parallel = false;
    network_config.num_machines = 1;
  }

Guolin Ke's avatar
Guolin Ke committed
227
  if (is_single_tree_learner || boosting_config.tree_learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
228
    is_parallel_find_bin = false;
229
230
  } else if (boosting_config.tree_learner_type == std::string("data")
             || boosting_config.tree_learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
231
    is_parallel_find_bin = true;
232
    if (boosting_config.tree_config.histogram_pool_size >= 0
233
        && boosting_config.tree_learner_type == std::string("data")) {
234
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
235
        , boosting_config.tree_config.histogram_pool_size);
tks's avatar
tks committed
236
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
237
      boosting_config.tree_config.histogram_pool_size = -1;
238
    }
Guolin Ke's avatar
Guolin Ke committed
239
  }
240
241
  // Check max_depth and num_leaves
  if (boosting_config.tree_config.max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
242
    int full_num_leaves = static_cast<int>(std::pow(2, boosting_config.tree_config.max_depth));
243
244
245
246
247
    if (full_num_leaves > boosting_config.tree_config.num_leaves 
        && boosting_config.tree_config.num_leaves == kDefaultNumLeaves) {
      Log::Warning("Accuarcy may be bad since you didn't set num_leaves.");
    }
  }
Guolin Ke's avatar
Guolin Ke committed
248
249
250
251
252
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin > 0);
253
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
254
  GetInt(params, "data_random_seed", &data_random_seed);
255
  GetString(params, "data", &data_filename);
256
  GetString(params, "init_score_file", &initscore_filename);
Qiwei Ye's avatar
Qiwei Ye committed
257
  GetInt(params, "verbose", &verbosity);
Guolin Ke's avatar
Guolin Ke committed
258
  GetInt(params, "num_iteration_predict", &num_iteration_predict);
Guolin Ke's avatar
Guolin Ke committed
259
  GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
260
261
  GetBool(params, "is_pre_partition", &is_pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
262
  GetDouble(params, "sparse_threshold", &sparse_threshold);
Guolin Ke's avatar
Guolin Ke committed
263
264
  GetBool(params, "use_two_round_loading", &use_two_round_loading);
  GetBool(params, "is_save_binary_file", &is_save_binary_file);
Guolin Ke's avatar
Guolin Ke committed
265
  GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
Guolin Ke's avatar
Guolin Ke committed
266
267
  GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
  GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
268
  GetBool(params, "is_predict_contrib", &is_predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
269
  GetInt(params, "snapshot_freq", &snapshot_freq);
Guolin Ke's avatar
Guolin Ke committed
270
271
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
272
  GetString(params, "convert_model", &convert_model);
Guolin Ke's avatar
Guolin Ke committed
273
274
275
276
277
  GetString(params, "output_result", &output_result);
  std::string tmp_str = "";
  if (GetString(params, "valid_data", &tmp_str)) {
    valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
  }
278
279
280
281
282
283
  if (GetString(params, "valid_init_score_file", &tmp_str)) {
    valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
  } else {
    valid_data_initscores = std::vector<std::string>(valid_data_filenames.size(), "");
  }
  CHECK(valid_data_filenames.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
284
285
286
287
288
  GetBool(params, "has_header", &has_header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
289
  GetString(params, "categorical_column", &categorical_column);
Guolin Ke's avatar
Guolin Ke committed
290
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
291
292
  GetInt(params, "min_data_in_bin", &min_data_in_bin);
  CHECK(min_data_in_bin > 0);
Guolin Ke's avatar
Guolin Ke committed
293
294
  GetDouble(params, "max_conflict_rate", &max_conflict_rate);
  GetBool(params, "enable_bundle", &enable_bundle);
295
296
297
298

  GetBool(params, "pred_early_stop", &pred_early_stop);
  GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
  GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
299
300
  GetBool(params, "use_missing", &use_missing);
  GetBool(params, "zero_as_missing", &zero_as_missing);
Guolin Ke's avatar
Guolin Ke committed
301
  device_type = GetDeviceType(params);
Guolin Ke's avatar
Guolin Ke committed
302
}
Guolin Ke's avatar
Guolin Ke committed
303
304
305

void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
306
  GetDouble(params, "sigmoid", &sigmoid);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
307
  GetDouble(params, "huber_delta", &huber_delta);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
308
  GetDouble(params, "fair_c", &fair_c);
309
  GetDouble(params, "gaussian_eta", &gaussian_eta);
310
  GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
Guolin Ke's avatar
Guolin Ke committed
311
312
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
313
314
  GetInt(params, "num_class", &num_class);
  CHECK(num_class >= 1);
Guolin Ke's avatar
Guolin Ke committed
315
  GetDouble(params, "scale_pos_weight", &scale_pos_weight);
Guolin Ke's avatar
Guolin Ke committed
316
317
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
318
    label_gain = Common::StringToArray<double>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
319
320
321
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
322
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
323
    for (int i = 1; i < max_label; ++i) {
324
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
325
326
    }
  }
Guolin Ke's avatar
Guolin Ke committed
327
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
328
329
330
331
}


void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
332
  GetDouble(params, "sigmoid", &sigmoid);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
333
  GetDouble(params, "huber_delta", &huber_delta);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
334
  GetDouble(params, "fair_c", &fair_c);
335
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
336
337
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
338
    label_gain = Common::StringToArray<double>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
339
340
341
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
342
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
343
    for (int i = 1; i < max_label; ++i) {
344
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
345
346
    }
  }
Guolin Ke's avatar
Guolin Ke committed
347
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
348
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
349
    eval_at = Common::StringToArray<int>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
350
351
352
353
354
355
356
357
358
359
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
360
  eval_at.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
361
362
363
364
365
}


void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
366
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
367
  CHECK(min_sum_hessian_in_leaf > 0 || min_data_in_leaf > 0);
368
  GetDouble(params, "lambda_l1", &lambda_l1);
369
370
371
372
373
374
  CHECK(lambda_l1 >= 0.0f);
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >= 0.0f);
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  CHECK(min_gain_to_split >= 0.0f);
  GetInt(params, "num_leaves", &num_leaves);
375
  CHECK(num_leaves > 1);
Guolin Ke's avatar
Guolin Ke committed
376
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
377
  GetDouble(params, "feature_fraction", &feature_fraction);
378
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
379
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
380
  GetInt(params, "max_depth", &max_depth);
Guolin Ke's avatar
Guolin Ke committed
381
  GetInt(params, "top_k", &top_k);
382
383
384
  GetInt(params, "gpu_platform_id", &gpu_platform_id);
  GetInt(params, "gpu_device_id", &gpu_device_id);
  GetBool(params, "gpu_use_dp", &gpu_use_dp);
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
}

void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_iterations", &num_iterations);
Guolin Ke's avatar
Guolin Ke committed
389
  GetDouble(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
390
391
392
393
  CHECK(num_iterations >= 0);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  CHECK(bagging_freq >= 0);
394
  GetDouble(params, "bagging_fraction", &bagging_fraction);
395
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
396
  GetDouble(params, "learning_rate", &learning_rate);
397
  CHECK(learning_rate > 0.0f);
wxchan's avatar
wxchan committed
398
399
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);
400
401
402
  GetInt(params, "metric_freq", &output_freq);
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
403
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
404
  GetInt(params, "drop_seed", &drop_seed);
405
  GetDouble(params, "drop_rate", &drop_rate);
406
407
408
409
  GetDouble(params, "skip_drop", &skip_drop);
  GetInt(params, "max_drop", &max_drop);
  GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
  GetBool(params, "uniform_drop", &uniform_drop);
Guolin Ke's avatar
Guolin Ke committed
410
411
  GetDouble(params, "top_rate", &top_rate);
  GetDouble(params, "other_rate", &other_rate);
412
  GetBool(params, "boost_from_average", &boost_from_average);
413
  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
414
  CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
Guolin Ke's avatar
Guolin Ke committed
415
416
  device_type = GetDeviceType(params);
  tree_learner_type = GetTreeLearnerType(params);
Guolin Ke's avatar
Guolin Ke committed
417
  tree_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
418
419
420
}


421

Guolin Ke's avatar
Guolin Ke committed
422
423
424
425
426
427
428
429
430
431
432
// Reads the distributed-training settings from `params`; parameters that are
// absent keep their current value. Each numeric setting is validated right
// after it is read.
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // cluster size: at least one machine
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);

  // listen port and timeout must both be strictly positive
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);

  // path of the machine list file; not validated here
  GetString(params, "machine_list_file", &machine_list_filename);
}

}  // namespace LightGBM