config.cpp 13.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
23
24
25
    if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) {
      Log::Fatal("Do not support non-ascii characters in config.");
    }
wxchan's avatar
wxchan committed
26
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
27
28
29
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
30
      } else {
31
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
32
33
34
35
36
37
38
39
40
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
41
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
42
  std::unordered_map<std::string, std::string> params;
43
  auto args = Common::Split(parameters, " \t\n\r");
44
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
45
    KV2Map(&params, Common::Trim(arg).c_str());
46
47
  }
  ParameterAlias::KeyAliasTransform(&params);
48
  return params;
49
50
}

Guolin Ke's avatar
Guolin Ke committed
51
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
52
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
53
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
54
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
55
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
56
      *boosting = "gbdt";
57
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
58
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
59
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
60
      *boosting = "goss";
61
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
62
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
63
    } else {
64
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
std::string ParseObjectiveAlias(const std::string& type) {
  // Maps each accepted objective spelling to its canonical objective name.
  // Matching is exact and case-sensitive. Names not listed are returned unchanged.
  static const std::unordered_map<std::string, std::string> kObjectiveAliases = {
    // L2 regression; the RMSE spellings map to the L2 objective as well
    {"regression", "regression"}, {"regression_l2", "regression"},
    {"mean_squared_error", "regression"}, {"mse", "regression"}, {"l2", "regression"},
    {"l2_root", "regression"}, {"root_mean_squared_error", "regression"}, {"rmse", "regression"},
    // L1 regression
    {"regression_l1", "regression_l1"}, {"mean_absolute_error", "regression_l1"},
    {"l1", "regression_l1"}, {"mae", "regression_l1"},
    // multiclass
    {"multiclass", "multiclass"}, {"softmax", "multiclass"},
    // multiclassova
    {"multiclassova", "multiclassova"}, {"multiclass_ova", "multiclassova"},
    {"ova", "multiclassova"}, {"ovr", "multiclassova"},
    // cross-entropy variants
    {"xentropy", "cross_entropy"}, {"cross_entropy", "cross_entropy"},
    {"xentlambda", "cross_entropy_lambda"}, {"cross_entropy_lambda", "cross_entropy_lambda"},
    // MAPE
    {"mean_absolute_percentage_error", "mape"}, {"mape", "mape"},
    // "no built-in objective" spellings all collapse to "custom"
    {"none", "custom"}, {"null", "custom"}, {"custom", "custom"}, {"na", "custom"},
  };
  const auto hit = kObjectiveAliases.find(type);
  return hit == kObjectiveAliases.end() ? type : hit->second;
}

std::string ParseMetricAlias(const std::string& type) {
  // Maps each accepted metric (or objective) spelling to its canonical metric
  // name. Matching is exact and case-sensitive. Names not listed are returned
  // unchanged.
  static const std::unordered_map<std::string, std::string> kMetricAliases = {
    {"regression", "l2"}, {"regression_l2", "l2"}, {"l2", "l2"},
    {"mean_squared_error", "l2"}, {"mse", "l2"},
    {"l2_root", "rmse"}, {"root_mean_squared_error", "rmse"}, {"rmse", "rmse"},
    {"regression_l1", "l1"}, {"l1", "l1"}, {"mean_absolute_error", "l1"}, {"mae", "l1"},
    {"binary_logloss", "binary_logloss"}, {"binary", "binary_logloss"},
    {"ndcg", "ndcg"}, {"lambdarank", "ndcg"},
    {"map", "map"}, {"mean_average_precision", "map"},
    // both multiclass objective families share the multi_logloss metric
    {"multi_logloss", "multi_logloss"}, {"multiclass", "multi_logloss"},
    {"softmax", "multi_logloss"}, {"multiclassova", "multi_logloss"},
    {"multiclass_ova", "multi_logloss"}, {"ova", "multi_logloss"}, {"ovr", "multi_logloss"},
    {"xentropy", "cross_entropy"}, {"cross_entropy", "cross_entropy"},
    {"xentlambda", "cross_entropy_lambda"}, {"cross_entropy_lambda", "cross_entropy_lambda"},
    {"kldiv", "kullback_leibler"}, {"kullback_leibler", "kullback_leibler"},
    {"mean_absolute_percentage_error", "mape"}, {"mape", "mape"},
    {"none", "custom"}, {"null", "custom"}, {"custom", "custom"}, {"na", "custom"},
  };
  const auto hit = kMetricAliases.find(type);
  return hit == kMetricAliases.end() ? type : hit->second;
}

void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  // Replaces *out_metric with the comma-separated entries of `value`, each
  // resolved through ParseMetricAlias. Duplicates after resolution are dropped
  // and first-seen order is kept.
  out_metric->clear();
  std::unordered_set<std::string> seen;
  for (const auto& raw : Common::Split(value.c_str(), ',')) {
    const std::string canonical = ParseMetricAlias(raw);
    if (seen.insert(canonical).second) {  // true only for a not-yet-seen name
      out_metric->push_back(canonical);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
135
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
136
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
137
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
138
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
139
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
140
141
142
  }
}

Guolin Ke's avatar
Guolin Ke committed
143
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
144
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
145
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
146
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
147
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
148
  }
149
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
150
151
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
152
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
153
      ParseMetrics(value, metric);
154
155
    }
  }
Guolin Ke's avatar
Guolin Ke committed
156
157
}

Guolin Ke's avatar
Guolin Ke committed
158
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
159
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
160
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
161
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
162
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
163
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
164
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
165
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
166
      *task = TaskType::kPredict;
167
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
168
      *task = TaskType::kConvertModel;
169
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
170
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
171
    } else {
172
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
    }
  }
}

wxchan's avatar
wxchan committed
177
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
178
  std::string value;
179
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
180
181
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
182
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
183
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
184
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
185
186
187
188
189
190
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
191
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
192
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
193
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
194
195
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
196
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
197
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
198
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
199
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
200
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
201
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
202
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
203
204
205
206
207
208
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
209
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
210
211
212
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
213
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
214
215
216
217
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
218
219
  }

Guolin Ke's avatar
Guolin Ke committed
220
221
222
223
224
225
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
226

Guolin Ke's avatar
Guolin Ke committed
227
  GetMembersFromString(params);
228

Guolin Ke's avatar
Guolin Ke committed
229
230
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
231

Guolin Ke's avatar
Guolin Ke committed
232
233
234
  if (valid_data_initscores.size() == 0 && valid.size() > 0) {
    valid_data_initscores = std::vector<std::string>(valid.size(), "");
  }
235
  CHECK(valid.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
236
237
238

  // check for conflicts
  CheckParamConflict();
239

Guolin Ke's avatar
Guolin Ke committed
240
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
241
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
242
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
243
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
244
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
245
246
247
248
249
250
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
251
bool CheckMultiClassObjective(const std::string& objective) {
  // True only for the exact strings "multiclass" and "multiclassova".
  // Aliases such as "softmax" or "ova" are not matched here.
  static const char* const kMulticlassNames[] = {"multiclass", "multiclassova"};
  for (const char* candidate : kMulticlassNames) {
    if (objective == candidate) {
      return true;
    }
  }
  return false;
}

Guolin Ke's avatar
Guolin Ke committed
255
256
257
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
258
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
259

260
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
261
262
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
263
264
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
265
    if (task == TaskType::kTrain && num_class_check != 1) {
266
267
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
268
  }
Guolin Ke's avatar
Guolin Ke committed
269
  for (std::string metric_type : metric) {
270
271
272
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Guolin Ke's avatar
Guolin Ke committed
273
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
274
    if ((objective_type_multiclass && !metric_type_multiclass)
275
276
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
277
    }
278
  }
279

Guolin Ke's avatar
Guolin Ke committed
280
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
281
282
283
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
284
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
285
286
  }

Guolin Ke's avatar
Guolin Ke committed
287
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
288
289

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
290
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
291
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
292
293
  }

Guolin Ke's avatar
Guolin Ke committed
294
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
295
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
296
297
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
298
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
299
300
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
301
302
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
303
                   histogram_pool_size);
tks's avatar
tks committed
304
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
305
      histogram_pool_size = -1;
306
    }
Guolin Ke's avatar
Guolin Ke committed
307
  }
308
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
309
310
  if (max_depth > 0) {
    int full_num_leaves = static_cast<int>(std::pow(2, max_depth));
311
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
312
        && num_leaves == kDefaultNumLeaves) {
313
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
314
    }
315
    num_leaves = std::min(num_leaves, 2 << max_depth);
316
  }
Guolin Ke's avatar
Guolin Ke committed
317
318
}

Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
323
324
325
326
327
std::string Config::ToString() const {
  // Human-readable dump. It starts with the boosting, objective, metric,
  // tree_learner and device_type settings, each as a "[name: value]" line,
  // followed by the text produced by SaveMembersToString().
  std::string text;
  text += "[boosting: " + boosting + "]\n";
  text += "[objective: " + objective + "]\n";
  text += "[metric: " + Common::Join(metric, ",") + "]\n";
  text += "[tree_learner: " + tree_learner + "]\n";
  text += "[device_type: " + device_type + "]\n";
  text += SaveMembersToString();
  return text;
}

}  // namespace LightGBM