"python-package/vscode:/vscode.git/clone" did not exist on "619c06d8f85149cd437df96cf9de9f22f7d48e8d"
config.cpp 15.9 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
4
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/utils/log.h>

#include <vector>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
9
#include <unordered_set>
Guolin Ke's avatar
Guolin Ke committed
10
#include <algorithm>
Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

15
std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* parameters) {
16
  std::unordered_map<std::string, std::string> params;
17
  auto args = Common::Split(parameters, " \t\n\r");
18
19
20
21
22
23
24
25
26
  for (auto arg : args) {
    std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
      std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
      if (key.size() <= 0) {
        continue;
      }
      params[key] = value;
27
    } else if (Common::Trim(arg).size() > 0) {
Qiwei Ye's avatar
Qiwei Ye committed
28
      Log::Warning("Unknown parameter %s", arg.c_str());
29
30
31
    }
  }
  ParameterAlias::KeyAliasTransform(&params);
32
  return params;
33
34
}

Guolin Ke's avatar
Guolin Ke committed
35
std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
36
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
37
38
  std::string boosting_type = kDefaultBoostingType;
  if (ConfigBase::GetString(params, "boosting_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
39
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
40
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
41
      boosting_type = "gbdt";
42
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
43
      boosting_type = "dart";
Guolin Ke's avatar
Guolin Ke committed
44
45
    } else if (value == std::string("goss")) {
      boosting_type = "goss";
Guolin Ke's avatar
Guolin Ke committed
46
    } else {
47
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
48
49
    }
  }
Guolin Ke's avatar
Guolin Ke committed
50
  return boosting_type;
Guolin Ke's avatar
Guolin Ke committed
51
52
}

Guolin Ke's avatar
Guolin Ke committed
53
std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
54
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
55
56
  std::string objective_type = kDefaultObjectiveType;
  if (ConfigBase::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
57
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
58
59
    objective_type = value;
  }
Guolin Ke's avatar
Guolin Ke committed
60
  return objective_type;
Guolin Ke's avatar
Guolin Ke committed
61
62
}

Guolin Ke's avatar
Guolin Ke committed
63
std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
64
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
65
66
  std::vector<std::string> metric_types;
  if (ConfigBase::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
67
68
69
    // clear old metrics
    metric_types.clear();
    // to lower
Guolin Ke's avatar
Guolin Ke committed
70
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
71
72
    // split
    std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
73
    // remove duplicate
Guolin Ke's avatar
Guolin Ke committed
74
    std::unordered_set<std::string> metric_sets;
Guolin Ke's avatar
Guolin Ke committed
75
    for (auto& metric : metrics) {
Guolin Ke's avatar
Guolin Ke committed
76
      std::transform(metric.begin(), metric.end(), metric.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
77
78
      if (metric_sets.count(metric) <= 0) {
        metric_sets.insert(metric);
Guolin Ke's avatar
Guolin Ke committed
79
80
      }
    }
Guolin Ke's avatar
Guolin Ke committed
81
82
    for (auto& metric : metric_sets) {
      metric_types.push_back(metric);
Guolin Ke's avatar
Guolin Ke committed
83
    }
Guolin Ke's avatar
Guolin Ke committed
84
    metric_types.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
85
  }
Guolin Ke's avatar
Guolin Ke committed
86
  return metric_types;
Guolin Ke's avatar
Guolin Ke committed
87
88
}

Guolin Ke's avatar
Guolin Ke committed
89
// Map the "task" parameter (case-insensitive) to a TaskType.
// Accepted spellings are:
//   "train" / "training"              -> kTrain
//   "predict" / "prediction" / "test" -> kPredict
//   "convert_model"                   -> kConvertModel
// Anything else is reported through Log::Fatal. When absent, kTrain is returned.
TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) {
  std::string task;
  if (!ConfigBase::GetString(params, "task", &task)) {
    return TaskType::kTrain;
  }
  std::transform(task.begin(), task.end(), task.begin(), Common::tolower);
  if (task == "train" || task == "training") {
    return TaskType::kTrain;
  }
  if (task == "predict" || task == "prediction" || task == "test") {
    return TaskType::kPredict;
  }
  if (task == "convert_model") {
    return TaskType::kConvertModel;
  }
  Log::Fatal("Unknown task type %s", task.c_str());
  // only reached if Log::Fatal returns; mirrors the original fallback
  return TaskType::kTrain;
}

Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// Read the "device" parameter (case-insensitive). Only "cpu" and "gpu" are
// accepted; other values are reported through Log::Fatal. When the parameter is
// absent, kDefaultDevice is returned.
std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
  std::string device;
  if (!ConfigBase::GetString(params, "device", &device)) {
    return kDefaultDevice;
  }
  std::transform(device.begin(), device.end(), device.begin(), Common::tolower);
  if (device != "cpu" && device != "gpu") {
    Log::Fatal("Unknown device type %s", device.c_str());
    // only reached if Log::Fatal returns; mirrors the original fallback
    return kDefaultDevice;
  }
  return device;
}

// Resolve the "tree_learner" parameter (case-insensitive) to a canonical learner
// name: "serial", "feature", "data" or "voting". The "*_parallel" spellings are
// accepted as aliases. Unknown values are reported through Log::Fatal. When the
// parameter is absent, kDefaultTreeLearnerType is returned.
std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
  std::string learner;
  if (!ConfigBase::GetString(params, "tree_learner", &learner)) {
    return kDefaultTreeLearnerType;
  }
  std::transform(learner.begin(), learner.end(), learner.begin(), Common::tolower);
  // accepted spelling -> canonical learner name
  static const std::unordered_map<std::string, std::string> kLearnerAliases = {
    {"serial", "serial"},
    {"feature", "feature"}, {"feature_parallel", "feature"},
    {"data", "data"},       {"data_parallel", "data"},
    {"voting", "voting"},   {"voting_parallel", "voting"},
  };
  const auto found = kLearnerAliases.find(learner);
  if (found == kLearnerAliases.end()) {
    Log::Fatal("Unknown tree learner type %s", learner.c_str());
    // only reached if Log::Fatal returns; mirrors the original fallback
    return kDefaultTreeLearnerType;
  }
  return found->second;
}

void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // load main config types
  GetInt(params, "num_threads", &num_threads);
  GetString(params, "convert_model_language", &convert_model_language);

  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
    int int_max = std::numeric_limits<short>::max();
    io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
  }
  task_type = GetTaskType(params);
  boosting_type = GetBoostingType(params);

  metric_types = GetMetricType(params);

  // sub-config setup
  network_config.Set(params);
  io_config.Set(params);

  boosting_config.Set(params);
  objective_type = GetObjectiveType(params);
  objective_config.Set(params);
  metric_config.Set(params);

  // check for conflicts
  CheckParamConflict();
174

Guolin Ke's avatar
Guolin Ke committed
175
176
177
178
179
180
181
182
183
184
185
186
  if (io_config.verbosity == 1) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
  } else if (io_config.verbosity == 0) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
  } else if (io_config.verbosity >= 2) {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

void OverallConfig::CheckParamConflict() {
187
  // check if objective_type, metric_type, and num_class match
Guolin Ke's avatar
Guolin Ke committed
188
189
  bool objective_type_multiclass = (objective_type == std::string("multiclass") 
                                    || objective_type == std::string("multiclassova"));
Guolin Ke's avatar
Guolin Ke committed
190
  int num_class_check = boosting_config.num_class;
191
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
192
193
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
194
195
196
197
198
    }
  } else {
    if (task_type == TaskType::kTrain && num_class_check != 1) {
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
199
  }
wxchan's avatar
wxchan committed
200
201
  if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
    for (std::string metric_type : metric_types) {
Guolin Ke's avatar
Guolin Ke committed
202
      bool metric_type_multiclass = (metric_type == std::string("multi_logloss") 
203
                                     || metric_type == std::string("multi_error"));
wxchan's avatar
wxchan committed
204
205
206
207
      if ((objective_type_multiclass && !metric_type_multiclass)
        || (!objective_type_multiclass && metric_type_multiclass)) {
        Log::Fatal("Objective and metrics don't match");
      }
208
    }
209
  }
Guolin Ke's avatar
Guolin Ke committed
210
  
Guolin Ke's avatar
Guolin Ke committed
211
212
213
214
  if (network_config.num_machines > 1) {
    is_parallel = true;
  } else {
    is_parallel = false;
215
    boosting_config.tree_learner_type = "serial";
Guolin Ke's avatar
Guolin Ke committed
216
217
  }

Guolin Ke's avatar
Guolin Ke committed
218
219
220
  bool is_single_tree_learner = boosting_config.tree_learner_type == std::string("serial");

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
221
222
223
224
    is_parallel = false;
    network_config.num_machines = 1;
  }

Guolin Ke's avatar
Guolin Ke committed
225
  if (is_single_tree_learner || boosting_config.tree_learner_type == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
226
    is_parallel_find_bin = false;
227
228
  } else if (boosting_config.tree_learner_type == std::string("data")
             || boosting_config.tree_learner_type == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
229
    is_parallel_find_bin = true;
230
231
    if (boosting_config.tree_config.histogram_pool_size >= 0 
        && boosting_config.tree_learner_type == std::string("data")) {
232
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
233
        , boosting_config.tree_config.histogram_pool_size);
tks's avatar
tks committed
234
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
235
      boosting_config.tree_config.histogram_pool_size = -1;
236
    }
Guolin Ke's avatar
Guolin Ke committed
237
238
239
240
241
242
  }
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "max_bin", &max_bin);
  CHECK(max_bin > 0);
243
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
244
  GetInt(params, "data_random_seed", &data_random_seed);
245
  GetString(params, "data", &data_filename);
Qiwei Ye's avatar
Qiwei Ye committed
246
  GetInt(params, "verbose", &verbosity);
Guolin Ke's avatar
Guolin Ke committed
247
  GetInt(params, "num_iteration_predict", &num_iteration_predict);
Guolin Ke's avatar
Guolin Ke committed
248
  GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
249
250
  GetBool(params, "is_pre_partition", &is_pre_partition);
  GetBool(params, "is_enable_sparse", &is_enable_sparse);
251
  GetDouble(params, "sparse_threshold", &sparse_threshold);
Guolin Ke's avatar
Guolin Ke committed
252
253
  GetBool(params, "use_two_round_loading", &use_two_round_loading);
  GetBool(params, "is_save_binary_file", &is_save_binary_file);
Guolin Ke's avatar
Guolin Ke committed
254
  GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
Guolin Ke's avatar
Guolin Ke committed
255
256
  GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
  GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
Guolin Ke's avatar
Guolin Ke committed
257
  GetInt(params, "snapshot_freq", &snapshot_freq);
Guolin Ke's avatar
Guolin Ke committed
258
259
  GetString(params, "output_model", &output_model);
  GetString(params, "input_model", &input_model);
260
  GetString(params, "convert_model", &convert_model);
Guolin Ke's avatar
Guolin Ke committed
261
262
263
264
265
  GetString(params, "output_result", &output_result);
  std::string tmp_str = "";
  if (GetString(params, "valid_data", &tmp_str)) {
    valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
  }
Guolin Ke's avatar
Guolin Ke committed
266
267
268
269
270
  GetBool(params, "has_header", &has_header);
  GetString(params, "label_column", &label_column);
  GetString(params, "weight_column", &weight_column);
  GetString(params, "group_column", &group_column);
  GetString(params, "ignore_column", &ignore_column);
271
  GetString(params, "categorical_column", &categorical_column);
Guolin Ke's avatar
Guolin Ke committed
272
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
273
274
  GetInt(params, "min_data_in_bin", &min_data_in_bin);
  CHECK(min_data_in_bin > 0);
Guolin Ke's avatar
Guolin Ke committed
275
276
  GetDouble(params, "max_conflict_rate", &max_conflict_rate);
  GetBool(params, "enable_bundle", &enable_bundle);
277
278
279
280
281

  GetBool(params, "pred_early_stop", &pred_early_stop);
  GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
  GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);

Guolin Ke's avatar
Guolin Ke committed
282
  device_type = GetDeviceType(params);
Guolin Ke's avatar
Guolin Ke committed
283
}
Guolin Ke's avatar
Guolin Ke committed
284
285
286

void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetBool(params, "is_unbalance", &is_unbalance);
287
  GetDouble(params, "sigmoid", &sigmoid);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
288
  GetDouble(params, "huber_delta", &huber_delta);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
289
  GetDouble(params, "fair_c", &fair_c);
290
  GetDouble(params, "gaussian_eta", &gaussian_eta);
291
  GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
Guolin Ke's avatar
Guolin Ke committed
292
293
  GetInt(params, "max_position", &max_position);
  CHECK(max_position > 0);
294
295
  GetInt(params, "num_class", &num_class);
  CHECK(num_class >= 1);
Guolin Ke's avatar
Guolin Ke committed
296
  GetDouble(params, "scale_pos_weight", &scale_pos_weight);
Guolin Ke's avatar
Guolin Ke committed
297
298
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
299
    label_gain = Common::StringToArray<double>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
300
301
302
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
303
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
304
    for (int i = 1; i < max_label; ++i) {
305
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
306
307
    }
  }
Guolin Ke's avatar
Guolin Ke committed
308
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
309
310
311
312
}


void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
313
  GetDouble(params, "sigmoid", &sigmoid);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
314
  GetDouble(params, "huber_delta", &huber_delta);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
315
  GetDouble(params, "fair_c", &fair_c);
316
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
317
318
  std::string tmp_str = "";
  if (GetString(params, "label_gain", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
319
    label_gain = Common::StringToArray<double>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
320
321
322
  } else {
    // label_gain = 2^i - 1, may overflow, so we use 31 here
    const int max_label = 31;
323
    label_gain.push_back(0.0f);
Guolin Ke's avatar
Guolin Ke committed
324
    for (int i = 1; i < max_label; ++i) {
325
      label_gain.push_back(static_cast<double>((1 << i) - 1));
Guolin Ke's avatar
Guolin Ke committed
326
327
    }
  }
Guolin Ke's avatar
Guolin Ke committed
328
  label_gain.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
329
  if (GetString(params, "ndcg_eval_at", &tmp_str)) {
Guolin Ke's avatar
Guolin Ke committed
330
    eval_at = Common::StringToArray<int>(tmp_str, ',');
Guolin Ke's avatar
Guolin Ke committed
331
332
333
334
335
336
337
338
339
340
    std::sort(eval_at.begin(), eval_at.end());
    for (size_t i = 0; i < eval_at.size(); ++i) {
      CHECK(eval_at[i] > 0);
    }
  } else {
    // default eval ndcg @[1-5]
    for (int i = 1; i <= 5; ++i) {
      eval_at.push_back(i);
    }
  }
Guolin Ke's avatar
Guolin Ke committed
341
  eval_at.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
342
343
344
345
346
}


void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
347
  GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
348
  CHECK(min_sum_hessian_in_leaf > 0 || min_data_in_leaf > 0);
349
  GetDouble(params, "lambda_l1", &lambda_l1);
350
351
352
353
354
355
  CHECK(lambda_l1 >= 0.0f);
  GetDouble(params, "lambda_l2", &lambda_l2);
  CHECK(lambda_l2 >= 0.0f);
  GetDouble(params, "min_gain_to_split", &min_gain_to_split);
  CHECK(min_gain_to_split >= 0.0f);
  GetInt(params, "num_leaves", &num_leaves);
356
  CHECK(num_leaves > 1);
Guolin Ke's avatar
Guolin Ke committed
357
  GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
358
  GetDouble(params, "feature_fraction", &feature_fraction);
359
  CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
360
  GetDouble(params, "histogram_pool_size", &histogram_pool_size);
Guolin Ke's avatar
Guolin Ke committed
361
  GetInt(params, "max_depth", &max_depth);
Guolin Ke's avatar
Guolin Ke committed
362
  GetInt(params, "top_k", &top_k);
363
364
365
  GetInt(params, "gpu_platform_id", &gpu_platform_id);
  GetInt(params, "gpu_device_id", &gpu_device_id);
  GetBool(params, "gpu_use_dp", &gpu_use_dp);
366
  GetBool(params, "use_missing", &use_missing);
Guolin Ke's avatar
Guolin Ke committed
367
368
369
370
}

void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  GetInt(params, "num_iterations", &num_iterations);
Guolin Ke's avatar
Guolin Ke committed
371
  GetDouble(params, "sigmoid", &sigmoid);
Guolin Ke's avatar
Guolin Ke committed
372
373
374
375
  CHECK(num_iterations >= 0);
  GetInt(params, "bagging_seed", &bagging_seed);
  GetInt(params, "bagging_freq", &bagging_freq);
  CHECK(bagging_freq >= 0);
376
  GetDouble(params, "bagging_fraction", &bagging_fraction);
377
  CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
378
  GetDouble(params, "learning_rate", &learning_rate);
379
  CHECK(learning_rate > 0.0f);
wxchan's avatar
wxchan committed
380
381
  GetInt(params, "early_stopping_round", &early_stopping_round);
  CHECK(early_stopping_round >= 0);
382
383
384
  GetInt(params, "metric_freq", &output_freq);
  CHECK(output_freq >= 0);
  GetBool(params, "is_training_metric", &is_provide_training_metric);
385
  GetInt(params, "num_class", &num_class);
Guolin Ke's avatar
Guolin Ke committed
386
  GetInt(params, "drop_seed", &drop_seed);
387
  GetDouble(params, "drop_rate", &drop_rate);
388
389
390
391
  GetDouble(params, "skip_drop", &skip_drop);
  GetInt(params, "max_drop", &max_drop);
  GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
  GetBool(params, "uniform_drop", &uniform_drop);
Guolin Ke's avatar
Guolin Ke committed
392
393
  GetDouble(params, "top_rate", &top_rate);
  GetDouble(params, "other_rate", &other_rate);
394
  GetBool(params, "boost_from_average", &boost_from_average);
395
  CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
396
  CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
Guolin Ke's avatar
Guolin Ke committed
397
398
  device_type = GetDeviceType(params);
  tree_learner_type = GetTreeLearnerType(params);
Guolin Ke's avatar
Guolin Ke committed
399
  tree_config.Set(params);
Guolin Ke's avatar
Guolin Ke committed
400
401
402
}


403

Guolin Ke's avatar
Guolin Ke committed
404
405
406
407
408
409
410
411
412
413
414
// Read distributed-training network options.
// num_machines must be at least 1; local_listen_port and time_out must be positive.
// "machine_list_file" is stored in machine_list_filename.
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
  // cluster size
  GetInt(params, "num_machines", &num_machines);
  CHECK(num_machines >= 1);

  // socket settings
  GetInt(params, "local_listen_port", &local_listen_port);
  CHECK(local_listen_port > 0);
  GetInt(params, "time_out", &time_out);
  CHECK(time_out > 0);

  // peer list
  GetString(params, "machine_list_file", &machine_list_filename);
}

}  // namespace LightGBM