config.cpp 12.3 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

Guolin Ke's avatar
Guolin Ke committed
11
#include <algorithm>
#include <cmath>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
23
    if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) {
24
      Log::Fatal("Do not support non-ASCII characters in config.");
25
    }
wxchan's avatar
wxchan committed
26
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
27
28
29
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
30
      } else {
31
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
32
33
34
35
36
37
38
39
40
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
41
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
42
  std::unordered_map<std::string, std::string> params;
43
  auto args = Common::Split(parameters, " \t\n\r");
44
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
45
    KV2Map(&params, Common::Trim(arg).c_str());
46
47
  }
  ParameterAlias::KeyAliasTransform(&params);
48
  return params;
49
50
}

Guolin Ke's avatar
Guolin Ke committed
51
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
52
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
53
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
54
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
55
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
56
      *boosting = "gbdt";
57
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
58
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
59
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
60
      *boosting = "goss";
61
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
62
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
63
    } else {
64
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
74
75
76
77
78
79
80
81
/*!
 * \brief Split a comma-separated metric list, resolve aliases, and deduplicate.
 *
 * *out_metric is cleared first. Each name is resolved through ParseMetricAlias,
 * and only the first occurrence of each resolved name is kept, in input order.
 */
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  out_metric->clear();
  std::unordered_set<std::string> seen;
  const std::vector<std::string> names = Common::Split(value.c_str(), ',');
  for (const auto& name : names) {
    const std::string resolved = ParseMetricAlias(name);
    // insert(...).second is true only the first time a resolved name is seen.
    if (seen.insert(resolved).second) {
      out_metric->push_back(resolved);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
82
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
83
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
84
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
85
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
86
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
87
88
89
  }
}

Guolin Ke's avatar
Guolin Ke committed
90
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
91
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
92
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
93
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
94
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
95
  }
96
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
97
98
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
99
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
100
      ParseMetrics(value, metric);
101
102
    }
  }
Guolin Ke's avatar
Guolin Ke committed
103
104
}

Guolin Ke's avatar
Guolin Ke committed
105
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
106
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
107
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
108
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
109
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
110
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
111
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
112
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
113
      *task = TaskType::kPredict;
114
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
115
      *task = TaskType::kConvertModel;
116
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
117
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
118
    } else {
119
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
    }
  }
}

wxchan's avatar
wxchan committed
124
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
125
  std::string value;
126
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
127
128
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
129
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
130
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
131
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
136
137
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
138
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
139
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
140
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
141
142
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
143
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
144
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
145
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
146
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
147
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
148
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
149
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
150
151
152
153
154
155
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Belinda Trotta's avatar
Belinda Trotta committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
/*!
 * \brief Build auc_mu_weights_matrix (num_class x num_class) from auc_mu_weights.
 *
 * If auc_mu_weights is empty, every off-diagonal weight is 1 and the diagonal is 0.
 * Otherwise auc_mu_weights must hold exactly num_class * num_class values, read so
 * that entry (i, j) comes from index i * num_class + j (row-major):
 *   - Diagonal entries are forced to 0. An info message is logged if a value with
 *     |w| > kZeroThreshold was supplied there.
 *   - Off-diagonal entries with |w| < kZeroThreshold are fatal.
 */
void Config::GetAucMuWeights() {
  const size_t n = static_cast<size_t>(num_class);
  if (auc_mu_weights.empty()) {
    // equal weights for all classes
    auc_mu_weights_matrix = std::vector<std::vector<double>>(n, std::vector<double>(n, 1));
    for (size_t i = 0; i < n; ++i) {
      auc_mu_weights_matrix[i][i] = 0;
    }
    return;
  }

  auc_mu_weights_matrix = std::vector<std::vector<double>>(n, std::vector<double>(n, 0));
  if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
    // size() is size_t; passing it for %d is undefined behavior, so cast explicitly.
    Log::Fatal("auc_mu_weights must have %d elements, but found %d",
               num_class * num_class, static_cast<int>(auc_mu_weights.size()));
  }
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < n; ++j) {
      // Flat row-major position; cast to int below when formatted with %d.
      const size_t pos = i * n + j;
      if (i == j) {
        auc_mu_weights_matrix[i][j] = 0;
        if (std::fabs(auc_mu_weights[pos]) > kZeroThreshold) {
          Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.",
                    static_cast<int>(pos));
        }
      } else {
        if (std::fabs(auc_mu_weights[pos]) < kZeroThreshold) {
          Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.",
                     static_cast<int>(pos));
        }
        auc_mu_weights_matrix[i][j] = auc_mu_weights[pos];
      }
    }
  }
}
Belinda Trotta's avatar
Belinda Trotta committed
185

Guolin Ke's avatar
Guolin Ke committed
186
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
187
188
189
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
190
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
195
    objective_seed = static_cast<int>(rand.NextShort(0, int_max));
196
    extra_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
197
198
  }

Guolin Ke's avatar
Guolin Ke committed
199
200
201
202
203
204
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
205

Guolin Ke's avatar
Guolin Ke committed
206
  GetMembersFromString(params);
207

Belinda Trotta's avatar
Belinda Trotta committed
208
209
  GetAucMuWeights();

Guolin Ke's avatar
Guolin Ke committed
210
211
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
212

213
214
215
216
217
218
219
  std::vector<std::string> new_valid;
  for (size_t i = 0; i < valid.size(); ++i) {
    if (valid[i] != data) {
      // Only push the non-training data
      new_valid.push_back(valid[i]);
    } else {
      is_provide_training_metric = true;
220
221
    }
  }
222
  valid = new_valid;
223

Guolin Ke's avatar
Guolin Ke committed
224
225
  // check for conflicts
  CheckParamConflict();
226

Guolin Ke's avatar
Guolin Ke committed
227
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
228
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
229
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
230
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
231
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
232
233
234
235
236
237
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
238
/*!
 * \brief Return true iff the given name is exactly "multiclass" or "multiclassova".
 */
bool CheckMultiClassObjective(const std::string& objective) {
  static const char* const kMulticlassNames[] = {"multiclass", "multiclassova"};
  for (const char* candidate : kMulticlassNames) {
    if (objective == candidate) {
      return true;
    }
  }
  return false;
}

Guolin Ke's avatar
Guolin Ke committed
242
243
244
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
245
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
246

247
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
248
249
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
250
251
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
252
    if (task == TaskType::kTrain && num_class_check != 1) {
253
254
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
255
  }
Guolin Ke's avatar
Guolin Ke committed
256
  for (std::string metric_type : metric) {
257
258
259
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Belinda Trotta's avatar
Belinda Trotta committed
260
                                   || metric_type == std::string("auc_mu")
Guolin Ke's avatar
Guolin Ke committed
261
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
262
    if ((objective_type_multiclass && !metric_type_multiclass)
263
264
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
265
    }
266
  }
267

Guolin Ke's avatar
Guolin Ke committed
268
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
269
270
271
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
272
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
273
274
  }

Guolin Ke's avatar
Guolin Ke committed
275
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
276
277

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
278
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
279
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
280
281
  }

Guolin Ke's avatar
Guolin Ke committed
282
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
283
    is_data_based_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
284
285
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
286
    is_data_based_parallel = true;
Guolin Ke's avatar
Guolin Ke committed
287
288
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
289
290
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
291
                   histogram_pool_size);
tks's avatar
tks committed
292
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
293
      histogram_pool_size = -1;
294
    }
Guolin Ke's avatar
Guolin Ke committed
295
  }
296
297
298
299
300
301
  if (is_data_based_parallel) {
    if (!forcedsplits_filename.empty()) {
      Log::Fatal("Don't support forcedsplits in %s tree learner",
                 tree_learner.c_str());
    }
  }
302
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
303
  if (max_depth > 0) {
304
    double full_num_leaves = std::pow(2, max_depth);
305
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
306
        && num_leaves == kDefaultNumLeaves) {
307
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
308
    }
309
310
311
312
313

    if (full_num_leaves < num_leaves) {
      // Fits in an int, and is more restrictive than the current num_leaves
      num_leaves = static_cast<int>(full_num_leaves);
    }
314
  }
315
316
317
318
319
  // force col-wise for gpu
  if (device_type == std::string("gpu")) {
    force_col_wise = true;
    force_row_wise = false;
  }
Guolin Ke's avatar
Guolin Ke committed
320
321
}

Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
326
327
328
329
330
/*!
 * \brief Render the configuration as text.
 *
 * Output is one "[name: value]" line for each of boosting, objective, metric
 * (comma-joined), tree_learner and device_type, followed by SaveMembersToString().
 */
std::string Config::ToString() const {
  std::stringstream out;
  out << "[boosting: " << boosting << "]\n"
      << "[objective: " << objective << "]\n"
      << "[metric: " << Common::Join(metric, ",") << "]\n"
      << "[tree_learner: " << tree_learner << "]\n"
      << "[device_type: " << device_type << "]\n"
      << SaveMembersToString();
  return out.str();
}

}  // namespace LightGBM