config.cpp 13.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
12
#include <limits>

Guolin Ke's avatar
Guolin Ke committed
13
14
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
wxchan's avatar
wxchan committed
23
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
24
25
26
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
27
      } else {
28
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
29
30
31
32
33
34
35
36
37
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
38
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
39
  std::unordered_map<std::string, std::string> params;
40
  auto args = Common::Split(parameters, " \t\n\r");
41
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
42
    KV2Map(&params, Common::Trim(arg).c_str());
43
44
  }
  ParameterAlias::KeyAliasTransform(&params);
45
  return params;
46
47
}

Guolin Ke's avatar
Guolin Ke committed
48
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
49
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
50
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
51
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
52
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
53
      *boosting = "gbdt";
54
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
55
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
56
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
57
      *boosting = "goss";
58
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
59
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
60
    } else {
61
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
62
63
64
65
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
71
72
73
74
75
76
77
78
// Split a comma-separated metric list and resolve each entry through
// ParseMetricAlias. *out_metric is replaced with the distinct resolved names
// in first-seen order; duplicates are dropped.
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  out_metric->clear();
  std::unordered_set<std::string> seen;
  for (const auto& raw_name : Common::Split(value.c_str(), ',')) {
    const std::string canonical = ParseMetricAlias(raw_name);
    // insert(...).second is true only the first time a name is seen.
    if (seen.insert(canonical).second) {
      out_metric->push_back(canonical);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
79
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
80
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
81
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
82
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
83
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
84
85
86
  }
}

Guolin Ke's avatar
Guolin Ke committed
87
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
88
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
89
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
90
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
91
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
92
  }
93
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
94
95
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
96
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
97
      ParseMetrics(value, metric);
98
99
    }
  }
Guolin Ke's avatar
Guolin Ke committed
100
101
}

Guolin Ke's avatar
Guolin Ke committed
102
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
103
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
104
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
105
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
106
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
107
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
108
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
109
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
110
      *task = TaskType::kPredict;
111
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
112
      *task = TaskType::kConvertModel;
113
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
114
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
115
    } else {
116
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
117
118
119
120
    }
  }
}

wxchan's avatar
wxchan committed
121
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
122
  std::string value;
123
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
124
125
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
126
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
127
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
128
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
129
130
131
132
133
134
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
135
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
136
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
137
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
138
139
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
140
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
141
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
142
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
143
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
144
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
145
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
146
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
152
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Belinda Trotta's avatar
Belinda Trotta committed
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
void Config::GetAucMuWeights() {
  if (auc_mu_weights.empty()) {
    // equal weights for all classes
    auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 1));
    for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
      auc_mu_weights_matrix[i][i] = 0;
    }
  } else {
    auc_mu_weights_matrix = std::vector<std::vector<double>> (num_class, std::vector<double>(num_class, 0));
    if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
      Log::Fatal("auc_mu_weights must have %d elements, but found %d", num_class * num_class, auc_mu_weights.size());
    }
    for (size_t i = 0; i < static_cast<size_t>(num_class); ++i) {
      for (size_t j = 0; j < static_cast<size_t>(num_class); ++j) {
        if (i == j) {
          auc_mu_weights_matrix[i][j] = 0;
          if (std::fabs(auc_mu_weights[i * num_class + j]) > kZeroThreshold) {
            Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.", i * num_class + j);
          }
        } else {
          if (std::fabs(auc_mu_weights[i * num_class + j]) < kZeroThreshold) {
            Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.", i * num_class + j);
          }
          auc_mu_weights_matrix[i][j] = auc_mu_weights[i * num_class + j];
        }
      }
    }
  }
181
}
Belinda Trotta's avatar
Belinda Trotta committed
182

Guolin Ke's avatar
Guolin Ke committed
183
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
184
185
186
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
187
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
192
    objective_seed = static_cast<int>(rand.NextShort(0, int_max));
193
    extra_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
194
195
  }

Guolin Ke's avatar
Guolin Ke committed
196
197
198
199
200
201
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
202

Guolin Ke's avatar
Guolin Ke committed
203
  GetMembersFromString(params);
204

Belinda Trotta's avatar
Belinda Trotta committed
205
206
  GetAucMuWeights();

Guolin Ke's avatar
Guolin Ke committed
207
208
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
209

210
211
212
213
214
215
216
  std::vector<std::string> new_valid;
  for (size_t i = 0; i < valid.size(); ++i) {
    if (valid[i] != data) {
      // Only push the non-training data
      new_valid.push_back(valid[i]);
    } else {
      is_provide_training_metric = true;
217
218
    }
  }
219
  valid = new_valid;
220

Guolin Ke's avatar
Guolin Ke committed
221
222
  // check for conflicts
  CheckParamConflict();
223

Guolin Ke's avatar
Guolin Ke committed
224
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
225
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
226
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
227
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
228
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
229
230
231
232
233
234
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
235
// Return true when `objective` is exactly "multiclass" or "multiclassova".
// The comparison is case-sensitive.
bool CheckMultiClassObjective(const std::string& objective) {
  const bool is_multiclass = objective == "multiclass";
  const bool is_multiclassova = objective == "multiclassova";
  return is_multiclass || is_multiclassova;
}

Guolin Ke's avatar
Guolin Ke committed
239
240
241
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
242
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
243

244
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
245
246
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
247
248
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
249
    if (task == TaskType::kTrain && num_class_check != 1) {
250
251
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
252
  }
Guolin Ke's avatar
Guolin Ke committed
253
  for (std::string metric_type : metric) {
254
255
256
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Belinda Trotta's avatar
Belinda Trotta committed
257
                                   || metric_type == std::string("auc_mu")
Guolin Ke's avatar
Guolin Ke committed
258
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
259
    if ((objective_type_multiclass && !metric_type_multiclass)
260
261
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
262
    }
263
  }
264

Guolin Ke's avatar
Guolin Ke committed
265
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
266
267
268
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
269
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
270
271
  }

Guolin Ke's avatar
Guolin Ke committed
272
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
273
274

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
275
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
276
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
277
278
  }

Guolin Ke's avatar
Guolin Ke committed
279
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
280
    is_data_based_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
281
282
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
283
    is_data_based_parallel = true;
Guolin Ke's avatar
Guolin Ke committed
284
285
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
286
287
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
288
                   histogram_pool_size);
tks's avatar
tks committed
289
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
290
      histogram_pool_size = -1;
291
    }
Guolin Ke's avatar
Guolin Ke committed
292
  }
293
294
295
296
297
298
  if (is_data_based_parallel) {
    if (!forcedsplits_filename.empty()) {
      Log::Fatal("Don't support forcedsplits in %s tree learner",
                 tree_learner.c_str());
    }
  }
299
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
300
  if (max_depth > 0) {
301
    double full_num_leaves = std::pow(2, max_depth);
302
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
303
        && num_leaves == kDefaultNumLeaves) {
304
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
305
    }
306
307
308
309
310

    if (full_num_leaves < num_leaves) {
      // Fits in an int, and is more restrictive than the current num_leaves
      num_leaves = static_cast<int>(full_num_leaves);
    }
311
  }
312
313
314
315
316
  // force col-wise for gpu
  if (device_type == std::string("gpu")) {
    force_col_wise = true;
    force_row_wise = false;
  }
Belinda Trotta's avatar
Belinda Trotta committed
317
318
319
320
321
322
323
324
  // min_data_in_leaf must be at least 2 if path smoothing is active. This is because when the split is calculated
  // the count is calculated using the proportion of hessian in the leaf which is rounded up to nearest int, so it can
  // be 1 when there is actually no data in the leaf. In rare cases this can cause a bug because with path smoothing the
  // calculated split gain can be positive even with zero gradient and hessian.
  if (path_smooth > kEpsilon && min_data_in_leaf < 2) {
    min_data_in_leaf = 2;
    Log::Warning("min_data_in_leaf has been increased to 2 because this is required when path smoothing is active.");
  }
325
326
327
328
329
330
331
332
333
334
335
  if (is_parallel && monotone_constraints_method == std::string("intermediate")) {
    // In distributed mode, local node doesn't have histograms on all features, cannot perform "intermediate" monotone constraints.
    Log::Warning("Cannot use \"intermediate\" monotone constraints in parallel learning, auto set to \"basic\" method.");
    monotone_constraints_method = "basic";
  }
  if (feature_fraction_bynode != 1.0 && monotone_constraints_method == std::string("intermediate")) {
    // "intermediate" monotone constraints need to recompute splits. If the features are sampled when computing the
    // split initially, then the sampling needs to be recorded or done once again, which is currently not supported
    Log::Warning("Cannot use \"intermediate\" monotone constraints with feature fraction different from 1, auto set monotone constraints to \"basic\" method.");
    monotone_constraints_method = "basic";
  }
336
337
338
  if (max_depth > 0 && monotone_penalty >= max_depth) {
    Log::Warning("Monotone penalty greater than tree depth. Monotone features won't be used.");
  }
Guolin Ke's avatar
Guolin Ke committed
339
340
}

Guolin Ke's avatar
Guolin Ke committed
341
342
343
344
345
346
347
348
349
// Render the configuration as text. First come one "[name: value]" line each
// for boosting, objective, metric (comma-joined), tree_learner and
// device_type; then the output of SaveMembersToString().
std::string Config::ToString() const {
  std::stringstream text;
  text << "[boosting: " << boosting << "]\n"
       << "[objective: " << objective << "]\n"
       << "[metric: " << Common::Join(metric, ",") << "]\n"
       << "[tree_learner: " << tree_learner << "]\n"
       << "[device_type: " << device_type << "]\n"
       << SaveMembersToString();
  return text.str();
}

}  // namespace LightGBM