config.cpp 13.3 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
/*!
 * \brief Parse one "key=value" token and record it in the parameter map.
 *
 * The token is split on '='. Anything that does not yield exactly two pieces
 * is reported with a warning and skipped. Key and value are passed through
 * Common::Trim and Common::RemoveQuotationSymbol before use. A key that is
 * already present keeps its existing value; the new value is only reported.
 *
 * \param params Map that receives the parsed pair
 * \param kv     Token to parse
 */
void Config::KV2Map(std::unordered_map<std::string, std::string>& params, const char* kv) {
  const std::vector<std::string> pieces = Common::Split(kv, '=');
  if (pieces.size() != 2) {
    Log::Warning("Unknown parameter %s", kv);
    return;
  }

  const std::string key = Common::RemoveQuotationSymbol(Common::Trim(pieces[0]));
  const std::string value = Common::RemoveQuotationSymbol(Common::Trim(pieces[1]));
  if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) {
    Log::Fatal("Do not support non-ascii characters in config.");
  }

  // An empty key carries no information; silently drop it.
  if (key.empty()) {
    return;
  }

  const auto existing = params.find(key);
  if (existing != params.end()) {
    // First occurrence wins: keep the stored value and report the ignored one.
    Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
      key.c_str(), existing->second.c_str(), key.c_str(), value.c_str(),
      key.c_str(), existing->second.c_str());
    return;
  }
  params.emplace(key, value);
}

Guolin Ke's avatar
Guolin Ke committed
38
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
39
  std::unordered_map<std::string, std::string> params;
40
  auto args = Common::Split(parameters, " \t\n\r");
41
  for (auto arg : args) {
wxchan's avatar
wxchan committed
42
    KV2Map(params, Common::Trim(arg).c_str());
43
44
  }
  ParameterAlias::KeyAliasTransform(&params);
45
  return params;
46
47
}

Guolin Ke's avatar
Guolin Ke committed
48
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
49
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
50
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
51
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
52
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
53
      *boosting = "gbdt";
54
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
55
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
56
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
57
      *boosting = "goss";
58
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
59
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
60
    } else {
61
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
62
63
64
65
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*!
 * \brief Translate an objective name or alias into its canonical name.
 *
 * \param type Objective name as given by the user (matched exactly, so the
 *             caller is responsible for any case normalization)
 * \return The canonical objective name, or `type` itself when it is not a
 *         known alias
 */
std::string ParseObjectiveAlias(const std::string& type) {
  // alias -> canonical objective name
  static const std::unordered_map<std::string, std::string> alias_to_objective = {
    {"regression", "regression"},
    {"regression_l2", "regression"},
    {"mean_squared_error", "regression"},
    {"mse", "regression"},
    {"l2", "regression"},
    {"l2_root", "regression"},
    {"root_mean_squared_error", "regression"},
    {"rmse", "regression"},
    {"regression_l1", "regression_l1"},
    {"mean_absolute_error", "regression_l1"},
    {"l1", "regression_l1"},
    {"mae", "regression_l1"},
    {"multiclass", "multiclass"},
    {"softmax", "multiclass"},
    {"multiclassova", "multiclassova"},
    {"multiclass_ova", "multiclassova"},
    {"ova", "multiclassova"},
    {"ovr", "multiclassova"},
    {"xentropy", "cross_entropy"},
    {"cross_entropy", "cross_entropy"},
    {"xentlambda", "cross_entropy_lambda"},
    {"cross_entropy_lambda", "cross_entropy_lambda"},
    {"mean_absolute_percentage_error", "mape"},
    {"mape", "mape"},
    {"none", "custom"},
    {"null", "custom"},
    {"custom", "custom"},
    {"na", "custom"},
  };
  const auto found = alias_to_objective.find(type);
  if (found == alias_to_objective.end()) {
    // Not an alias: pass the name through unchanged.
    return type;
  }
  return found->second;
}

/*!
 * \brief Translate a metric name or alias into its canonical name.
 *
 * \param type Metric name as given by the user (matched exactly, so the
 *             caller is responsible for any case normalization)
 * \return The canonical metric name, or `type` itself when it is not a
 *         known alias
 */
std::string ParseMetricAlias(const std::string& type) {
  // alias -> canonical metric name
  static const std::unordered_map<std::string, std::string> alias_to_metric = {
    {"regression", "l2"},
    {"regression_l2", "l2"},
    {"l2", "l2"},
    {"mean_squared_error", "l2"},
    {"mse", "l2"},
    {"l2_root", "rmse"},
    {"root_mean_squared_error", "rmse"},
    {"rmse", "rmse"},
    {"regression_l1", "l1"},
    {"l1", "l1"},
    {"mean_absolute_error", "l1"},
    {"mae", "l1"},
    {"binary_logloss", "binary_logloss"},
    {"binary", "binary_logloss"},
    {"ndcg", "ndcg"},
    {"lambdarank", "ndcg"},
    {"map", "map"},
    {"mean_average_precision", "map"},
    {"multi_logloss", "multi_logloss"},
    {"multiclass", "multi_logloss"},
    {"softmax", "multi_logloss"},
    {"multiclassova", "multi_logloss"},
    {"multiclass_ova", "multi_logloss"},
    {"ova", "multi_logloss"},
    {"ovr", "multi_logloss"},
    {"xentropy", "cross_entropy"},
    {"cross_entropy", "cross_entropy"},
    {"xentlambda", "cross_entropy_lambda"},
    {"cross_entropy_lambda", "cross_entropy_lambda"},
    {"kldiv", "kullback_leibler"},
    {"kullback_leibler", "kullback_leibler"},
    {"mean_absolute_percentage_error", "mape"},
    {"mape", "mape"},
    {"none", "custom"},
    {"null", "custom"},
    {"custom", "custom"},
    {"na", "custom"},
  };
  const auto found = alias_to_metric.find(type);
  if (found == alias_to_metric.end()) {
    // Not an alias: pass the name through unchanged.
    return type;
  }
  return found->second;
}

/*!
 * \brief Split a comma-separated metric list into canonical metric names.
 *
 * Each entry is mapped through ParseMetricAlias. Names are appended in their
 * order of first appearance; repeated canonical names are dropped.
 *
 * \param value      Comma-separated metric names
 * \param out_metric Output: cleared first, then filled with the unique
 *                   canonical names
 */
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  out_metric->clear();
  std::unordered_set<std::string> seen;
  for (const auto& raw_name : Common::Split(value.c_str(), ',')) {
    const std::string canonical = ParseMetricAlias(raw_name);
    // insert() reports whether the name was new, so one lookup is enough.
    if (seen.insert(canonical).second) {
      out_metric->push_back(canonical);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
132
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
133
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
134
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
135
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
136
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
137
138
139
  }
}

Guolin Ke's avatar
Guolin Ke committed
140
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
141
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
142
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
143
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
144
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
145
  }
146
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
147
148
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
149
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
150
      ParseMetrics(value, metric);
151
152
    }
  }
Guolin Ke's avatar
Guolin Ke committed
153
154
}

Guolin Ke's avatar
Guolin Ke committed
155
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
156
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
157
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
158
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
159
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
160
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
161
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
162
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
163
      *task = TaskType::kPredict;
164
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
165
      *task = TaskType::kConvertModel;
166
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
167
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
168
    } else {
169
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
170
171
172
173
    }
  }
}

wxchan's avatar
wxchan committed
174
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
175
  std::string value;
176
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
177
178
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
179
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
180
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
181
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
182
183
184
185
186
187
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
188
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
189
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
190
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
191
192
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
193
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
194
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
195
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
196
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
197
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
198
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
199
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
206
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
207
208
209
210
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
    int int_max = std::numeric_limits<short>::max();
Guolin Ke's avatar
Guolin Ke committed
211
212
213
214
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
215
216
  }

Guolin Ke's avatar
Guolin Ke committed
217
218
219
220
221
222
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
223

Guolin Ke's avatar
Guolin Ke committed
224
  GetMembersFromString(params);
225

Guolin Ke's avatar
Guolin Ke committed
226
227
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
228

Guolin Ke's avatar
Guolin Ke committed
229
230
231
  if (valid_data_initscores.size() == 0 && valid.size() > 0) {
    valid_data_initscores = std::vector<std::string>(valid.size(), "");
  }
232
  CHECK(valid.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
233
234
235

  // check for conflicts
  CheckParamConflict();
236

Guolin Ke's avatar
Guolin Ke committed
237
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
238
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
239
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
240
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
241
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
242
243
244
245
246
247
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
248
/*!
 * \brief Tell whether a name denotes one of the multiclass objectives.
 *
 * \param objective Name to test (matched exactly)
 * \return true for "multiclass" and "multiclassova", false otherwise
 */
bool CheckMultiClassObjective(const std::string& objective) {
  if (objective == "multiclass") {
    return true;
  }
  return objective == "multiclassova";
}

Guolin Ke's avatar
Guolin Ke committed
252
253
254
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
255
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
256

257
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
258
259
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
260
261
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
262
    if (task == TaskType::kTrain && num_class_check != 1) {
263
264
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
265
  }
Guolin Ke's avatar
Guolin Ke committed
266
  for (std::string metric_type : metric) {
267
268
269
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Guolin Ke's avatar
Guolin Ke committed
270
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
271
    if ((objective_type_multiclass && !metric_type_multiclass)
272
273
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
274
    }
275
  }
276

Guolin Ke's avatar
Guolin Ke committed
277
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
278
279
280
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
281
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
282
283
  }

Guolin Ke's avatar
Guolin Ke committed
284
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
285
286

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
287
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
288
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
289
290
  }

Guolin Ke's avatar
Guolin Ke committed
291
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
292
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
293
294
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
295
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
296
297
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
298
299
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
300
                   histogram_pool_size);
tks's avatar
tks committed
301
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
302
      histogram_pool_size = -1;
303
    }
Guolin Ke's avatar
Guolin Ke committed
304
  }
305
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
306
307
  if (max_depth > 0) {
    int full_num_leaves = static_cast<int>(std::pow(2, max_depth));
308
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
309
        && num_leaves == kDefaultNumLeaves) {
310
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
311
    }
312
    num_leaves = std::min(num_leaves, 2 << max_depth);
313
  }
Guolin Ke's avatar
Guolin Ke committed
314
315
}

Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
320
321
322
323
324
/*!
 * \brief Render the configuration as text.
 *
 * Emits one "[name: value]" line each for boosting, objective, metric
 * (comma-joined), tree_learner and device_type, followed by the output of
 * SaveMembersToString.
 *
 * \return The textual form of this configuration
 */
std::string Config::ToString() const {
  std::stringstream out;
  out << "[boosting: " << boosting << "]\n"
      << "[objective: " << objective << "]\n"
      << "[metric: " << Common::Join(metric, ",") << "]\n"
      << "[tree_learner: " << tree_learner << "]\n"
      << "[device_type: " << device_type << "]\n"
      << SaveMembersToString();
  return out.str();
}

}  // namespace LightGBM