config.cpp 10.6 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#include <LightGBM/config.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
10

Guolin Ke's avatar
Guolin Ke committed
11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
13
14

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
15
void Config::KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv) {
wxchan's avatar
wxchan committed
16
  std::vector<std::string> tmp_strs = Common::Split(kv, '=');
17
  if (tmp_strs.size() == 2 || tmp_strs.size() == 1) {
wxchan's avatar
wxchan committed
18
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
19
20
21
22
    std::string value = "";
    if (tmp_strs.size() == 2) {
      value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    }
23
    if (!Common::CheckASCII(key) || !Common::CheckASCII(value)) {
24
      Log::Fatal("Do not support non-ASCII characters in config.");
25
    }
wxchan's avatar
wxchan committed
26
    if (key.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
27
28
29
      auto value_search = params->find(key);
      if (value_search == params->end()) {  // not set
        params->emplace(key, value);
wxchan's avatar
wxchan committed
30
      } else {
31
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
wxchan's avatar
wxchan committed
32
33
34
35
36
37
38
39
40
          key.c_str(), value_search->second.c_str(), key.c_str(), value.c_str(),
          key.c_str(), value_search->second.c_str());
      }
    }
  } else {
    Log::Warning("Unknown parameter %s", kv);
  }
}

Guolin Ke's avatar
Guolin Ke committed
41
std::unordered_map<std::string, std::string> Config::Str2Map(const char* parameters) {
42
  std::unordered_map<std::string, std::string> params;
43
  auto args = Common::Split(parameters, " \t\n\r");
44
  for (auto arg : args) {
Guolin Ke's avatar
Guolin Ke committed
45
    KV2Map(&params, Common::Trim(arg).c_str());
46
47
  }
  ParameterAlias::KeyAliasTransform(&params);
48
  return params;
49
50
}

Guolin Ke's avatar
Guolin Ke committed
51
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting) {
Guolin Ke's avatar
Guolin Ke committed
52
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
53
  if (Config::GetString(params, "boosting", &value)) {
Guolin Ke's avatar
Guolin Ke committed
54
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
55
    if (value == std::string("gbdt") || value == std::string("gbrt")) {
Guolin Ke's avatar
Guolin Ke committed
56
      *boosting = "gbdt";
57
    } else if (value == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
58
      *boosting = "dart";
Guolin Ke's avatar
Guolin Ke committed
59
    } else if (value == std::string("goss")) {
Guolin Ke's avatar
Guolin Ke committed
60
      *boosting = "goss";
61
    } else if (value == std::string("rf") || value == std::string("random_forest")) {
Guolin Ke's avatar
Guolin Ke committed
62
      *boosting = "rf";
Guolin Ke's avatar
Guolin Ke committed
63
    } else {
64
      Log::Fatal("Unknown boosting type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
74
75
76
77
78
79
80
81
/*!
 * \brief Split a comma-separated metric list into de-duplicated metric names.
 *
 * Each entry is mapped through ParseMetricAlias; a mapped name is appended to the
 * output only the first time it appears, so the order of first occurrence is kept.
 * The output vector is cleared before anything is appended.
 *
 * \param value Comma-separated metric names
 * \param out_metric Receives the unique mapped names
 */
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
  out_metric->clear();
  std::unordered_set<std::string> seen;
  std::vector<std::string> names = Common::Split(value.c_str(), ',');
  for (auto& name : names) {
    auto resolved = ParseMetricAlias(name);
    // insert() reports whether the name was new, so no separate lookup is needed.
    if (seen.insert(resolved).second) {
      out_metric->push_back(resolved);
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
82
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective) {
Guolin Ke's avatar
Guolin Ke committed
83
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
84
  if (Config::GetString(params, "objective", &value)) {
Guolin Ke's avatar
Guolin Ke committed
85
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
86
    *objective = ParseObjectiveAlias(value);
Guolin Ke's avatar
Guolin Ke committed
87
88
89
  }
}

Guolin Ke's avatar
Guolin Ke committed
90
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric) {
Guolin Ke's avatar
Guolin Ke committed
91
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
92
  if (Config::GetString(params, "metric", &value)) {
Guolin Ke's avatar
Guolin Ke committed
93
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
94
    ParseMetrics(value, metric);
Guolin Ke's avatar
Guolin Ke committed
95
  }
96
  // add names of objective function if not providing metric
Guolin Ke's avatar
Guolin Ke committed
97
98
  if (metric->empty() && value.size() == 0) {
    if (Config::GetString(params, "objective", &value)) {
99
      std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
100
      ParseMetrics(value, metric);
101
102
    }
  }
Guolin Ke's avatar
Guolin Ke committed
103
104
}

Guolin Ke's avatar
Guolin Ke committed
105
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task) {
Guolin Ke's avatar
Guolin Ke committed
106
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
107
  if (Config::GetString(params, "task", &value)) {
Guolin Ke's avatar
Guolin Ke committed
108
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
Guolin Ke's avatar
Guolin Ke committed
109
    if (value == std::string("train") || value == std::string("training")) {
Guolin Ke's avatar
Guolin Ke committed
110
      *task = TaskType::kTrain;
Guolin Ke's avatar
Guolin Ke committed
111
    } else if (value == std::string("predict") || value == std::string("prediction")
Guolin Ke's avatar
Guolin Ke committed
112
               || value == std::string("test")) {
Guolin Ke's avatar
Guolin Ke committed
113
      *task = TaskType::kPredict;
114
    } else if (value == std::string("convert_model")) {
Guolin Ke's avatar
Guolin Ke committed
115
      *task = TaskType::kConvertModel;
116
    } else if (value == std::string("refit") || value == std::string("refit_tree")) {
Guolin Ke's avatar
Guolin Ke committed
117
      *task = TaskType::KRefitTree;
Guolin Ke's avatar
Guolin Ke committed
118
    } else {
119
      Log::Fatal("Unknown task type %s", value.c_str());
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
    }
  }
}

wxchan's avatar
wxchan committed
124
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
Guolin Ke's avatar
Guolin Ke committed
125
  std::string value;
126
  if (Config::GetString(params, "device_type", &value)) {
Guolin Ke's avatar
Guolin Ke committed
127
128
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("cpu")) {
wxchan's avatar
wxchan committed
129
      *device_type = "cpu";
Guolin Ke's avatar
Guolin Ke committed
130
    } else if (value == std::string("gpu")) {
wxchan's avatar
wxchan committed
131
      *device_type = "gpu";
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
136
137
    } else {
      Log::Fatal("Unknown device type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
138
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
139
  std::string value;
Guolin Ke's avatar
Guolin Ke committed
140
  if (Config::GetString(params, "tree_learner", &value)) {
Guolin Ke's avatar
Guolin Ke committed
141
142
    std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
    if (value == std::string("serial")) {
Guolin Ke's avatar
Guolin Ke committed
143
      *tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
144
    } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
145
      *tree_learner = "feature";
Guolin Ke's avatar
Guolin Ke committed
146
    } else if (value == std::string("data") || value == std::string("data_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
147
      *tree_learner = "data";
Guolin Ke's avatar
Guolin Ke committed
148
    } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
Guolin Ke's avatar
Guolin Ke committed
149
      *tree_learner = "voting";
Guolin Ke's avatar
Guolin Ke committed
150
151
152
153
154
155
    } else {
      Log::Fatal("Unknown tree learner type %s", value.c_str());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
156
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
Guolin Ke's avatar
Guolin Ke committed
157
158
159
  // generate seeds by seed.
  if (GetInt(params, "seed", &seed)) {
    Random rand(seed);
Guolin Ke's avatar
Guolin Ke committed
160
    int int_max = std::numeric_limits<int16_t>::max();
Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
    data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
    bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
    drop_seed = static_cast<int>(rand.NextShort(0, int_max));
    feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
Guolin Ke's avatar
Guolin Ke committed
165
166
  }

Guolin Ke's avatar
Guolin Ke committed
167
168
169
170
171
172
  GetTaskType(params, &task);
  GetBoostingType(params, &boosting);
  GetMetricType(params, &metric);
  GetObjectiveType(params, &objective);
  GetDeviceType(params, &device_type);
  GetTreeLearnerType(params, &tree_learner);
Guolin Ke's avatar
Guolin Ke committed
173

Guolin Ke's avatar
Guolin Ke committed
174
  GetMembersFromString(params);
175

Guolin Ke's avatar
Guolin Ke committed
176
177
  // sort eval_at
  std::sort(eval_at.begin(), eval_at.end());
Guolin Ke's avatar
Guolin Ke committed
178

Guolin Ke's avatar
Guolin Ke committed
179
180
181
  if (valid_data_initscores.size() == 0 && valid.size() > 0) {
    valid_data_initscores = std::vector<std::string>(valid.size(), "");
  }
182
  CHECK(valid.size() == valid_data_initscores.size());
Guolin Ke's avatar
Guolin Ke committed
183

184
185
186
187
188
189
190
191
192
193
194
195
196
  if (valid_data_initscores.empty()) {
    std::vector<std::string> new_valid;
    for (size_t i = 0; i < valid.size(); ++i) {
      if (valid[i] != data) {
        // Only push the non-training data
        new_valid.push_back(valid[i]);
      } else {
        is_provide_training_metric = true;
      }
    }
    valid = new_valid;
  }

Guolin Ke's avatar
Guolin Ke committed
197
198
  // check for conflicts
  CheckParamConflict();
199

Guolin Ke's avatar
Guolin Ke committed
200
  if (verbosity == 1) {
Guolin Ke's avatar
Guolin Ke committed
201
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
Guolin Ke's avatar
Guolin Ke committed
202
  } else if (verbosity == 0) {
Guolin Ke's avatar
Guolin Ke committed
203
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
Guolin Ke's avatar
Guolin Ke committed
204
  } else if (verbosity >= 2) {
Guolin Ke's avatar
Guolin Ke committed
205
206
207
208
209
210
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
  } else {
    LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
  }
}

Guolin Ke's avatar
Guolin Ke committed
211
/*!
 * \brief Tell whether a name is one of the multiclass objectives.
 * \param objective Name to test (compared exactly, case-sensitive)
 * \return true for "multiclass" or "multiclassova", false otherwise
 */
bool CheckMultiClassObjective(const std::string& objective) {
  if (objective == "multiclass") {
    return true;
  }
  return objective == "multiclassova";
}

Guolin Ke's avatar
Guolin Ke committed
215
216
217
void Config::CheckParamConflict() {
  // check if objective, metric, and num_class match
  int num_class_check = num_class;
Guolin Ke's avatar
Guolin Ke committed
218
  bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective == std::string("custom") && num_class_check > 1);
219

220
  if (objective_type_multiclass) {
Guolin Ke's avatar
Guolin Ke committed
221
222
    if (num_class_check <= 1) {
      Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
223
224
    }
  } else {
Guolin Ke's avatar
Guolin Ke committed
225
    if (task == TaskType::kTrain && num_class_check != 1) {
226
227
      Log::Fatal("Number of classes must be 1 for non-multiclass training");
    }
228
  }
Guolin Ke's avatar
Guolin Ke committed
229
  for (std::string metric_type : metric) {
230
231
232
    bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
                                   || metric_type == std::string("multi_logloss")
                                   || metric_type == std::string("multi_error")
Guolin Ke's avatar
Guolin Ke committed
233
                                   || (metric_type == std::string("custom") && num_class_check > 1));
Guolin Ke's avatar
Guolin Ke committed
234
    if ((objective_type_multiclass && !metric_type_multiclass)
235
236
        || (!objective_type_multiclass && metric_type_multiclass)) {
      Log::Fatal("Multiclass objective and metrics don't match");
237
    }
238
  }
239

Guolin Ke's avatar
Guolin Ke committed
240
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
241
242
243
    is_parallel = true;
  } else {
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
244
    tree_learner = "serial";
Guolin Ke's avatar
Guolin Ke committed
245
246
  }

Guolin Ke's avatar
Guolin Ke committed
247
  bool is_single_tree_learner = tree_learner == std::string("serial");
Guolin Ke's avatar
Guolin Ke committed
248
249

  if (is_single_tree_learner) {
Guolin Ke's avatar
Guolin Ke committed
250
    is_parallel = false;
Guolin Ke's avatar
Guolin Ke committed
251
    num_machines = 1;
Guolin Ke's avatar
Guolin Ke committed
252
253
  }

Guolin Ke's avatar
Guolin Ke committed
254
  if (is_single_tree_learner || tree_learner == std::string("feature")) {
Guolin Ke's avatar
Guolin Ke committed
255
    is_parallel_find_bin = false;
Guolin Ke's avatar
Guolin Ke committed
256
257
  } else if (tree_learner == std::string("data")
             || tree_learner == std::string("voting")) {
Guolin Ke's avatar
Guolin Ke committed
258
    is_parallel_find_bin = true;
Guolin Ke's avatar
Guolin Ke committed
259
260
    if (histogram_pool_size >= 0
        && tree_learner == std::string("data")) {
261
262
      Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
                   "Will disable this to reduce communication costs",
Guolin Ke's avatar
Guolin Ke committed
263
                   histogram_pool_size);
tks's avatar
tks committed
264
      // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
Guolin Ke's avatar
Guolin Ke committed
265
      histogram_pool_size = -1;
266
    }
Guolin Ke's avatar
Guolin Ke committed
267
  }
268
  // Check max_depth and num_leaves
Guolin Ke's avatar
Guolin Ke committed
269
  if (max_depth > 0) {
270
    double full_num_leaves = std::pow(2, max_depth);
271
    if (full_num_leaves > num_leaves
Guolin Ke's avatar
Guolin Ke committed
272
        && num_leaves == kDefaultNumLeaves) {
273
      Log::Warning("Accuracy may be bad since you didn't set num_leaves and 2^max_depth > num_leaves");
274
    }
275
276
277
278
279

    if (full_num_leaves < num_leaves) {
      // Fits in an int, and is more restrictive than the current num_leaves
      num_leaves = static_cast<int>(full_num_leaves);
    }
280
  }
Guolin Ke's avatar
Guolin Ke committed
281
282
}

Guolin Ke's avatar
Guolin Ke committed
283
284
285
286
287
288
289
290
291
/*!
 * \brief Render the configuration as text.
 *
 * Emits one "[name: value]" line each for boosting, objective, metric
 * (names joined with ','), tree_learner and device_type, followed by the output
 * of SaveMembersToString.
 *
 * \return The textual form of this configuration
 */
std::string Config::ToString() const {
  std::stringstream out;
  out << "[boosting: " << boosting << "]\n"
      << "[objective: " << objective << "]\n"
      << "[metric: " << Common::Join(metric, ",") << "]\n"
      << "[tree_learner: " << tree_learner << "]\n"
      << "[device_type: " << device_type << "]\n"
      << SaveMembersToString();
  return out.str();
}

}  // namespace LightGBM