"vscode:/vscode.git/clone" did not exist on "6e0b0a8be44b14ade10737288a26aa361a00a18e"
config.cpp 12.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
23
    if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) {
24
      Log::Fatal("Do not support non-ASCII characters in config.");
25
    }
wxchan's avatar
wxchan committed
26
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
27
28
29
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
30
      } else {
31
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
32
33
34
35
36
37
38
39
40
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
41
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
42
  std::unordered_map<std::string, std::string> params;
43
  auto args = Common::Split(parameters, " \t\n\r");
44
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
45
    KV2Map(&params, Common::Trim(arg).c_str());
46
47
  }
  ParameterAlias::KeyAliasTransform(&params);
48
  return params;
49
50
}

Guolin Ke's avatar
Guolin Ke committed
51
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
52
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
53
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
54
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
55
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
56
      *boosting = "gbdt";
57
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
58
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
59
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
60
      *boosting = "goss";
61
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
62
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
63
    } else {
64
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
74
75
76
77
78
79
80
81
/*!
 * Split a comma-separated metric list into canonical, de-duplicated names.
 *
 * Each entry is mapped through ParseMetricAlias. The first occurrence of a
 * canonical name is kept and later repeats are skipped, so the output follows
 * the order of first appearance.
 *
 * \param value Comma-separated metric names
 * \param out_metric Output vector; cleared before being filled
 */
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  out_metric->clear();
  std::unordered_set<std::string> seen;
  const std::vector<std::string> raw_names = Common::Split(value.c_str(), ',');
  for (const auto& raw_name : raw_names) {
    const auto canonical = ParseMetricAlias(raw_name);
    // insert() reports whether the name was new, so one lookup is enough
    if (seen.insert(canonical).second) {
      out_metric->push_back(canonical);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
82
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
83
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
84
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
85
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
86
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
87
88
89
  }
}

Guolin Ke's avatar
Guolin Ke committed
90
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
91
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
92
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
93
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
94
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
95
  }
96
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
97
98
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
99
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
100
      ParseMetrics(value, metric);
101
102
    }
  }
Guolin Ke's avatar
Guolin Ke committed
103
104
}

Guolin Ke's avatar
Guolin Ke committed
105
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
106
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
107
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
108
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
109
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
110
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
111
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
112
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
113
      *task = TaskType::kPredict;
114
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
115
      *task = TaskType::kConvertModel;
116
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
117
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
118
    } else {
119
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
    }
  }
}

wxchan's avatar
wxchan committed
124
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
125
  std::string value;
126
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
127
128
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
129
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
130
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
131
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
136
137
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
138
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
139
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
140
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
141
142
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
143
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
144
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
145
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
146
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
147
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
148
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
149
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
150
151
152
153
154
155
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Belinda Trotta's avatar
Belinda Trotta committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
/*!
 * Build auc_mu_weights_matrix (num_class x num_class) from auc_mu_weights.
 *
 * - When auc_mu_weights is empty, every off-diagonal entry is 1 and the
 *   diagonal is 0.
 * - Otherwise auc_mu_weights is read as a row-major num_class * num_class
 *   array. A size mismatch is reported through Log::Fatal. Diagonal entries
 *   are forced to 0 (an Info message is logged when the supplied value
 *   exceeds kZeroThreshold in magnitude). An off-diagonal entry whose
 *   magnitude is below kZeroThreshold is reported through Log::Fatal.
 */
void Config::GetAucMuWeights() {
  const size_t class_count = static_cast<size_t>(num_class);
  if (auc_mu_weights.empty()) {
    // equal weights for all classes
    auc_mu_weights_matrix = std::vector<std::vector<double>>(class_count, std::vector<double>(class_count, 1.0));
    for (size_t i = 0; i < class_count; ++i) {
      auc_mu_weights_matrix[i][i] = 0;
    }
    return;
  }

  auc_mu_weights_matrix = std::vector<std::vector<double>>(class_count, std::vector<double>(class_count, 0.0));
  if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
    // size() is a size_t, so it must be printed with %zu rather than %d
    Log::Fatal("auc_mu_weights must have %d elements, but found %zu", num_class * num_class, auc_mu_weights.size());
  }
  for (size_t i = 0; i < class_count; ++i) {
    for (size_t j = 0; j < class_count; ++j) {
      // row-major position of (i, j) in the flat weight list
      const size_t flat_index = i * class_count + j;
      const double weight = auc_mu_weights[flat_index];
      if (i == j) {
        auc_mu_weights_matrix[i][j] = 0;
        if (std::fabs(weight) > kZeroThreshold) {
          Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %zu of auc_mu_weights with 0.", flat_index);
        }
      } else {
        if (std::fabs(weight) < kZeroThreshold) {
          Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %zu of auc_mu_weights.", flat_index);
        }
        auc_mu_weights_matrix[i][j] = weight;
      }
    }
  }
}
Belinda Trotta's avatar
Belinda Trotta committed
185

Guolin Ke's avatar
Guolin Ke committed
186
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
187
188
189
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
190
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
195
    objective_seed = static_cast<int>(rand.NextShort(0, int_max));
196
    extra_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
197
198
  }

Guolin Ke's avatar
Guolin Ke committed
199
200
201
202
203
204
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
205

Guolin Ke's avatar
Guolin Ke committed
206
  GetMembersFromString(params);
207

Belinda Trotta's avatar
Belinda Trotta committed
208
209
  GetAucMuWeights();

Guolin Ke's avatar
Guolin Ke committed
210
211
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
212

Guolin Ke's avatar
Guolin Ke committed
213
214
215
  if (valid_data_initscores.size() == 0 && valid.size() > 0) {
    valid_data_initscores = std::vector<std::string>(valid.size(), "");
  }
216
  CHECK(valid.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
217

218
219
220
221
222
223
224
225
226
227
228
229
230
  if (valid_data_initscores.empty()) {
    std::vector<std::string> new_valid;
    for (size_t i = 0; i < valid.size(); ++i) {
      if (valid[i] != data) {
        // Only push the non-training data
        new_valid.push_back(valid[i]);
      } else {
        is_provide_training_metric = true;
      }
    }
    valid = new_valid;
  }

Guolin Ke's avatar
Guolin Ke committed
231
232
  // check for conflicts
  CheckParamConflict();
233

Guolin Ke's avatar
Guolin Ke committed
234
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
235
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
236
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
237
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
238
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
239
240
241
242
243
244
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
245
/*!
 * Tell whether an objective name denotes a multiclass objective.
 *
 * The comparison is exact and case-sensitive; the string is compared against
 * the literals directly so no temporary std::string is built per call.
 *
 * \param objective Objective name
 * \return true for "multiclass" or "multiclassova", false otherwise
 */
bool CheckMultiClassObjective(const std::string& objective) {
  return objective == "multiclass" || objective == "multiclassova";
}

Guolin Ke's avatar
Guolin Ke committed
249
250
251
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
252
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
253

254
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
255
256
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
257
258
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
259
    if (task == TaskType::kTrain && num_class_check != 1) {
260
261
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
262
  }
Guolin Ke's avatar
Guolin Ke committed
263
  for (std::string metric_type : metric) {
264
265
266
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Belinda Trotta's avatar
Belinda Trotta committed
267
                                   || metric_type == std::string("auc_mu")
Guolin Ke's avatar
Guolin Ke committed
268
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
269
    if ((objective_type_multiclass && !metric_type_multiclass)
270
271
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
272
    }
273
  }
274

Guolin Ke's avatar
Guolin Ke committed
275
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
276
277
278
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
279
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
280
281
  }

Guolin Ke's avatar
Guolin Ke committed
282
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
283
284

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
285
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
286
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
287
288
  }

Guolin Ke's avatar
Guolin Ke committed
289
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
290
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
291
292
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
293
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
294
295
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
296
297
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
298
                   histogram_pool_size);
tks's avatar
tks committed
299
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
300
      histogram_pool_size = -1;
301
    }
Guolin Ke's avatar
Guolin Ke committed
302
  }
303
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
304
  if (max_depth > 0) {
305
    double full_num_leaves = std::pow(2, max_depth);
306
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
307
        && num_leaves == kDefaultNumLeaves) {
308
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
309
    }
310
311
312
313
314

    if (full_num_leaves < num_leaves) {
      // Fits in an int, and is more restrictive than the current num_leaves
      num_leaves = static_cast<int>(full_num_leaves);
    }
315
  }
316
317
318
319
320
  // force col-wise for gpu
  if (device_type == std::string("gpu")) {
    force_col_wise = true;
    force_row_wise = false;
  }
Guolin Ke's avatar
Guolin Ke committed
321
322
}

Guolin Ke's avatar
Guolin Ke committed
323
324
325
326
327
328
329
330
331
/*!
 * Render the configuration as text.
 *
 * Emits one "[name: value]" line each for boosting, objective, metric
 * (comma-joined), tree_learner and device_type, followed by the output of
 * SaveMembersToString().
 *
 * \return The rendered configuration
 */
std::string Config::ToString() const {
  std::stringstream summary;
  summary << "[boosting: " << boosting << "]\n"
          << "[objective: " << objective << "]\n"
          << "[metric: " << Common::Join(metric, ",") << "]\n"
          << "[tree_learner: " << tree_learner << "]\n"
          << "[device_type: " << device_type << "]\n"
          << SaveMembersToString();
  return summary.str();
}

}  // namespace LightGBM