"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "c84900753c0332b3ba3e931fb2e4af54bccc67d7"
config.cpp 12.1 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
23
    if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) {
24
      Log::Fatal("Do not support non-ASCII characters in config.");
25
    }
wxchan's avatar
wxchan committed
26
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
27
28
29
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
30
      } else {
31
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
32
33
34
35
36
37
38
39
40
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
41
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
42
  std::unordered_map<std::string, std::string> params;
43
  auto args = Common::Split(parameters, " \t\n\r");
44
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
45
    KV2Map(&params, Common::Trim(arg).c_str());
46
47
  }
  ParameterAlias::KeyAliasTransform(&params);
48
  return params;
49
50
}

Guolin Ke's avatar
Guolin Ke committed
51
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
52
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
53
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
54
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
55
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
56
      *boosting = "gbdt";
57
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
58
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
59
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
60
      *boosting = "goss";
61
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
62
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
63
    } else {
64
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
74
75
76
77
78
79
80
81
// Parse a comma-separated metric list into canonical metric names.
// Each entry is mapped through ParseMetricAlias; duplicates (including different
// aliases of the same metric) are kept once, in first-seen order.
// out_metric is cleared before filling.
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  out_metric->clear();
  std::unordered_set<std::string> seen;
  const std::vector<std::string> metrics = Common::Split(value.c_str(), ',');
  for (const auto& met : metrics) {
    auto type = ParseMetricAlias(met);
    // insert().second is true only for a name not seen before: one hash lookup
    // instead of count()+insert(), and no unsigned `<= 0` comparison.
    if (seen.insert(type).second) {
      out_metric->push_back(type);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
82
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
83
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
84
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
85
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
86
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
87
88
89
  }
}

Guolin Ke's avatar
Guolin Ke committed
90
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
91
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
92
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
93
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
94
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
95
  }
96
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
97
98
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
99
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
100
      ParseMetrics(value, metric);
101
102
    }
  }
Guolin Ke's avatar
Guolin Ke committed
103
104
}

Guolin Ke's avatar
Guolin Ke committed
105
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
106
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
107
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
108
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
109
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
110
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
111
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
112
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
113
      *task = TaskType::kPredict;
114
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
115
      *task = TaskType::kConvertModel;
116
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
117
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
118
    } else {
119
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
    }
  }
}

wxchan's avatar
wxchan committed
124
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
125
  std::string value;
126
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
127
128
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
129
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
130
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
131
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
136
137
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
138
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
139
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
140
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
141
142
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
143
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
144
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
145
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
146
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
147
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
148
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
149
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
150
151
152
153
154
155
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Belinda Trotta's avatar
Belinda Trotta committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
/*!
 * Build auc_mu_weights_matrix (num_class x num_class) from the flat auc_mu_weights list.
 * - Empty list: every off-diagonal entry is 1 and the diagonal is 0.
 * - Otherwise the list must hold num_class*num_class values in row-major order.
 *   Diagonal entries are forced to 0 (an info message is logged if they were non-zero).
 *   Off-diagonal entries with magnitude below kZeroThreshold are fatal.
 * NOTE(review): assumes num_class >= 1 — TODO confirm callers guarantee this.
 */
void Config::GetAucMuWeights() {
  const size_t n = static_cast<size_t>(num_class);
  if (auc_mu_weights.empty()) {
    // equal weights for all classes
    auc_mu_weights_matrix = std::vector<std::vector<double>>(n, std::vector<double>(n, 1));
    for (size_t i = 0; i < n; ++i) {
      auc_mu_weights_matrix[i][i] = 0;
    }
    return;
  }

  auc_mu_weights_matrix = std::vector<std::vector<double>>(n, std::vector<double>(n, 0));
  if (auc_mu_weights.size() != static_cast<size_t>(num_class * num_class)) {
    // Log::Fatal/Info are printf-style; "%d" needs an int, so size_t values are cast
    // (passing a size_t for "%d" was undefined behaviour on 64-bit targets).
    Log::Fatal("auc_mu_weights must have %d elements, but found %d",
               num_class * num_class, static_cast<int>(auc_mu_weights.size()));
  }
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < n; ++j) {
      const size_t idx = i * n + j;  // row-major position in the flat list
      const auto weight = auc_mu_weights[idx];
      if (i == j) {
        auc_mu_weights_matrix[i][j] = 0;
        if (std::fabs(weight) > kZeroThreshold) {
          Log::Info("AUC-mu matrix must have zeros on diagonal. Overwriting value in position %d of auc_mu_weights with 0.",
                    static_cast<int>(idx));
        }
      } else {
        if (std::fabs(weight) < kZeroThreshold) {
          Log::Fatal("AUC-mu matrix must have non-zero values for non-diagonal entries. Found zero value in position %d of auc_mu_weights.",
                     static_cast<int>(idx));
        }
        auc_mu_weights_matrix[i][j] = weight;
      }
    }
  }
}
Belinda Trotta's avatar
Belinda Trotta committed
185

Guolin Ke's avatar
Guolin Ke committed
186
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
187
188
189
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
190
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
195
196
  }

Guolin Ke's avatar
Guolin Ke committed
197
198
199
200
201
202
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
203

Guolin Ke's avatar
Guolin Ke committed
204
  GetMembersFromString(params);
205

Belinda Trotta's avatar
Belinda Trotta committed
206
207
  GetAucMuWeights();

Guolin Ke's avatar
Guolin Ke committed
208
209
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
210

Guolin Ke's avatar
Guolin Ke committed
211
212
213
  if (valid_data_initscores.size() == 0 && valid.size() > 0) {
    valid_data_initscores = std::vector<std::string>(valid.size(), "");
  }
214
  CHECK(valid.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
215

216
217
218
219
220
221
222
223
224
225
226
227
228
  if (valid_data_initscores.empty()) {
    std::vector<std::string> new_valid;
    for (size_t i = 0; i < valid.size(); ++i) {
      if (valid[i] != data) {
        // Only push the non-training data
        new_valid.push_back(valid[i]);
      } else {
        is_provide_training_metric = true;
      }
    }
    valid = new_valid;
  }

Guolin Ke's avatar
Guolin Ke committed
229
230
  // check for conflicts
  CheckParamConflict();
231

Guolin Ke's avatar
Guolin Ke committed
232
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
233
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
234
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
235
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
236
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
237
238
239
240
241
242
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
243
// True iff `objective` is exactly one of the multiclass objective names
// ("multiclass" or "multiclassova"); the comparison is case-sensitive.
bool CheckMultiClassObjective(const std::string& objective) {
  static const char* const kMulticlassNames[] = {"multiclass", "multiclassova"};
  for (const char* candidate : kMulticlassNames) {
    if (objective == candidate) {
      return true;
    }
  }
  return false;
}

Guolin Ke's avatar
Guolin Ke committed
247
248
249
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
250
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
251

252
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
253
254
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
255
256
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
257
    if (task == TaskType::kTrain && num_class_check != 1) {
258
259
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
260
  }
Guolin Ke's avatar
Guolin Ke committed
261
  for (std::string metric_type : metric) {
262
263
264
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Belinda Trotta's avatar
Belinda Trotta committed
265
                                   || metric_type == std::string("auc_mu")
Guolin Ke's avatar
Guolin Ke committed
266
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
267
    if ((objective_type_multiclass && !metric_type_multiclass)
268
269
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
270
    }
271
  }
272

Guolin Ke's avatar
Guolin Ke committed
273
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
274
275
276
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
277
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
278
279
  }

Guolin Ke's avatar
Guolin Ke committed
280
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
281
282

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
283
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
284
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
285
286
  }

Guolin Ke's avatar
Guolin Ke committed
287
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
288
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
289
290
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
291
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
292
293
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
294
295
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
296
                   histogram_pool_size);
tks's avatar
tks committed
297
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
298
      histogram_pool_size = -1;
299
    }
Guolin Ke's avatar
Guolin Ke committed
300
  }
301
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
302
  if (max_depth > 0) {
303
    double full_num_leaves = std::pow(2, max_depth);
304
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
305
        && num_leaves == kDefaultNumLeaves) {
306
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
307
    }
308
309
310
311
312

    if (full_num_leaves < num_leaves) {
      // Fits in an int, and is more restrictive than the current num_leaves
      num_leaves = static_cast<int>(full_num_leaves);
    }
313
  }
Guolin Ke's avatar
Guolin Ke committed
314
315
}

Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
320
321
322
323
324
// Human-readable dump of the configuration: the five core choices as
// "[name: value]" lines, followed by SaveMembersToString().
std::string Config::ToString() const {
  std::stringstream out;
  out << "[boosting: " << boosting << "]\n"
      << "[objective: " << objective << "]\n"
      << "[metric: " << Common::Join(metric, ",") << "]\n"
      << "[tree_learner: " << tree_learner << "]\n"
      << "[device_type: " << device_type << "]\n"
      << SaveMembersToString();
  return out.str();
}

}  // namespace LightGBM