config.cpp 14.2 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
12
#include <limits>

Guolin Ke's avatar
Guolin Ke committed
13
14
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
wxchan's avatar
wxchan committed
23
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
24
25
26
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
27
      } else {
28
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
29
30
31
32
33
34
35
36
37
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
38
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
39
  std::unordered_map<std::string, std::string> params;
40
  auto args = Common::Split(parameters, " \t\n\r");
41
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
42
    KV2Map(&params, Common::Trim(arg).c_str());
43
44
  }
  ParameterAlias::KeyAliasTransform(&params);
45
  return params;
46
47
}

Guolin Ke's avatar
Guolin Ke committed
48
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
49
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
50
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
51
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
52
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
53
      *boosting = "gbdt";
54
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
55
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
56
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
57
      *boosting = "goss";
58
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
59
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
60
    } else {
61
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
62
63
64
65
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
71
72
73
74
75
76
77
78
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  // Split a comma-separated metric list, resolve each entry with ParseMetricAlias,
  // and keep only the first occurrence of each resolved name, in input order.
  // *out_metric is cleared first.
  out_metric->clear();
  std::unordered_set<std::string> seen;
  for (const std::string& raw : Common::Split(value.c_str(), ',')) {
    std::string resolved = ParseMetricAlias(raw);
    // insert() reports whether the name was new
    if (seen.insert(resolved).second) {
      out_metric->push_back(resolved);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
79
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
80
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
81
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
82
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
83
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
84
85
86
  }
}

Guolin Ke's avatar
Guolin Ke committed
87
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
88
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
89
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
90
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
91
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
92
  }
93
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
94
95
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
96
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
97
      ParseMetrics(value, metric);
98
99
    }
  }
Guolin Ke's avatar
Guolin Ke committed
100
101
}

Guolin Ke's avatar
Guolin Ke committed
102
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
103
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
104
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
105
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
106
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
107
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
108
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
109
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
110
      *task = TaskType::kPredict;
111
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
112
      *task = TaskType::kConvertModel;
113
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
114
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
115
    } else {
116
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
117
118
119
120
    }
  }
}

wxchan's avatar
wxchan committed
121
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
122
  std::string value;
123
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
124
125
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
126
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
127
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
128
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
129
130
131
132
133
134
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
135
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
136
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
137
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
138
139
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
140
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
141
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
142
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
143
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
144
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
145
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
146
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
152
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Belinda Trotta's avatar
Belinda Trotta committed
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
void Config::GetAucMuWeights() {
  if (auc_mu_weights.empty()) {
    // equal weights for all classes
    auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 1));
    for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
      auc_mu_weights_matrix[i][i] = 0;
    }
  } else {
    auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 0));
    if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
      Log::Fatal("auc_mu_weights must have %d elements, but found %d", num_class * num_class, auc_mu_weights.size());
    }
    for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
      for (size_t j = 0; j < static_cast<size_t>(num_class); ++j) {
        if (i == j) {
          auc_mu_weights_matrix[i][j] = 0;
          if (std::fabs(auc_mu_weights[i * num_class + j]) > kZeroThreshold) {
            Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.", i * num_class + j);
          }
        } else {
          if (std::fabs(auc_mu_weights[i * num_class + j]) < kZeroThreshold) {
            Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.", i * num_class + j);
          }
          auc_mu_weights_matrix[i][j] = auc_mu_weights[i * num_class + j];
        }
      }
    }
  }
181
}
Belinda Trotta's avatar
Belinda Trotta committed
182

183
184
185
186
187
188
189
190
void Config::GetInteractionConstraints() {
  // Parse interaction_constraints as '['/']'-delimited groups of ','-separated ints
  // into interaction_constraints_vector; an empty string yields no constraints.
  interaction_constraints_vector = interaction_constraints.empty()
      ? std::vector<std::vector<int>>()
      : Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
}

Guolin Ke's avatar
Guolin Ke committed
191
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
192
193
194
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
195
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
196
197
198
199
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
200
    objective_seed = static_cast<int>(rand.NextShort(0, int_max));
201
    extra_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
202
203
  }

Guolin Ke's avatar
Guolin Ke committed
204
205
206
207
208
209
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
210

Guolin Ke's avatar
Guolin Ke committed
211
  GetMembersFromString(params);
212

Belinda Trotta's avatar
Belinda Trotta committed
213
214
  GetAucMuWeights();

215
216
  GetInteractionConstraints();

Guolin Ke's avatar
Guolin Ke committed
217
218
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
219

220
221
222
223
224
225
226
  std::vector<std::string> new_valid;
  for (size_t i = 0; i < valid.size(); ++i) {
    if (valid[i] != data) {
      // Only push the non-training data
      new_valid.push_back(valid[i]);
    } else {
      is_provide_training_metric = true;
227
228
    }
  }
229
  valid = new_valid;
230

Guolin Ke's avatar
Guolin Ke committed
231
232
  // check for conflicts
  CheckParamConflict();
233

Guolin Ke's avatar
Guolin Ke committed
234
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
235
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
236
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
237
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
238
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
239
240
241
242
243
244
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
245
bool CheckMultiClassObjective(const std::string& objective) {
  // True exactly for the multiclass objective names (exact, case-sensitive match).
  static const char* const kMultiClassNames[] = {"multiclass", "multiclassova"};
  for (const char* name : kMultiClassNames) {
    if (objective == name) {
      return true;
    }
  }
  return false;
}

Guolin Ke's avatar
Guolin Ke committed
249
250
251
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
252
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
253

254
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
255
256
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
257
258
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
259
    if (task == TaskType::kTrain && num_class_check != 1) {
260
261
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
262
  }
Guolin Ke's avatar
Guolin Ke committed
263
  for (std::string metric_type : metric) {
264
265
266
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Belinda Trotta's avatar
Belinda Trotta committed
267
                                   || metric_type == std::string("auc_mu")
Guolin Ke's avatar
Guolin Ke committed
268
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
269
    if ((objective_type_multiclass && !metric_type_multiclass)
270
271
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
272
    }
273
  }
274

Guolin Ke's avatar
Guolin Ke committed
275
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
276
277
278
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
279
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
280
281
  }

Guolin Ke's avatar
Guolin Ke committed
282
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
283
284

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
285
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
286
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
287
288
  }

Guolin Ke's avatar
Guolin Ke committed
289
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
290
    is_data_based_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
291
292
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
293
    is_data_based_parallel = true;
Guolin Ke's avatar
Guolin Ke committed
294
295
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
296
297
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
298
                   histogram_pool_size);
tks's avatar
tks committed
299
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
300
      histogram_pool_size = -1;
301
    }
Guolin Ke's avatar
Guolin Ke committed
302
  }
303
304
305
306
307
308
  if (is_data_based_parallel) {
    if (!forcedsplits_filename.empty()) {
      Log::Fatal("Don't support forcedsplits in %s tree learner",
                 tree_learner.c_str());
    }
  }
309
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
310
  if (max_depth > 0) {
311
    double full_num_leaves = std::pow(2, max_depth);
312
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
313
        && num_leaves == kDefaultNumLeaves) {
314
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
315
    }
316
317
318
319
320

    if (full_num_leaves < num_leaves) {
      // Fits in an int, and is more restrictive than the current num_leaves
      num_leaves = static_cast<int>(full_num_leaves);
    }
321
  }
322
323
324
325
326
  // force col-wise for gpu
  if (device_type == std::string("gpu")) {
    force_col_wise = true;
    force_row_wise = false;
  }
Belinda Trotta's avatar
Belinda Trotta committed
327
328
329
330
331
332
333
334
  // min_data_in_leaf must be at least 2 if path smoothing is active. This is because when the split is calculated
  // the count is calculated using the proportion of hessian in the leaf which is rounded up to nearest int, so it can
  // be 1 when there is actually no data in the leaf. In rare cases this can cause a bug because with path smoothing the
  // calculated split gain can be positive even with zero gradient and hessian.
  if (path_smooth > kEpsilon && min_data_in_leaf < 2) {
    min_data_in_leaf = 2;
    Log::Warning("min_data_in_leaf has been increased to 2 because this is required when path smoothing is active.");
  }
335
336
337
338
339
340
341
342
343
344
345
  if (is_parallel && monotone_constraints_method == std::string("intermediate")) {
    // In distributed mode, local node doesn't have histograms on all features, cannot perform "intermediate" monotone constraints.
    Log::Warning("Cannot use \"intermediate\" monotone constraints in parallel learning, auto set to \"basic\" method.");
    monotone_constraints_method = "basic";
  }
  if (feature_fraction_bynode != 1.0 && monotone_constraints_method == std::string("intermediate")) {
    // "intermediate" monotone constraints need to recompute splits. If the features are sampled when computing the
    // split initially, then the sampling needs to be recorded or done once again, which is currently not supported
    Log::Warning("Cannot use \"intermediate\" monotone constraints with feature fraction different from 1, auto set monotone constraints to \"basic\" method.");
    monotone_constraints_method = "basic";
  }
346
347
348
  if (max_depth > 0 && monotone_penalty >= max_depth) {
    Log::Warning("Monotone penalty greater than tree depth. Monotone features won't be used.");
  }
Guolin Ke's avatar
Guolin Ke committed
349
350
}

Guolin Ke's avatar
Guolin Ke committed
351
352
353
354
355
356
357
358
359
std::string Config::ToString() const {
  // Render the enum-like settings as "[name: value]" lines (metric list joined
  // with ','), followed by the text produced by SaveMembersToString().
  std::stringstream out;
  out << "[boosting: " << boosting << "]\n"
      << "[objective: " << objective << "]\n"
      << "[metric: " << Common::Join(metric, ",") << "]\n"
      << "[tree_learner: " << tree_learner << "]\n"
      << "[device_type: " << device_type << "]\n"
      << SaveMembersToString();
  return out.str();
}

}  // namespace LightGBM