gbdt_model_text.cpp 22.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6
7
8
9

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

10
#include <LightGBM/config.h>
Guolin Ke's avatar
Guolin Ke committed
11
#include <LightGBM/metric.h>
12
#include <LightGBM/objective_function.h>
13
#include <LightGBM/utils/array_args.h>
14
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
15

16
17
#include "gbdt.h"

Guolin Ke's avatar
Guolin Ke committed
18
19
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
const char* kModelVersion = "v3";
21

22
std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
23
24
25
  std::stringstream str_buf;

  str_buf << "{";
26
27
28
29
30
31
  str_buf << "\"name\":\"" << SubModelName() << "\"," << '\n';
  str_buf << "\"version\":\"" << kModelVersion << "\"," << '\n';
  str_buf << "\"num_class\":" << num_class_ << "," << '\n';
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n';
  str_buf << "\"label_index\":" << label_idx_ << "," << '\n';
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n';
32
33
34
  if (objective_function_ != nullptr) {
    str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n";
  }
Guolin Ke's avatar
Guolin Ke committed
35

36
37
  str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n";

38
39
40
41
42
  str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"")
          << "\"]," << '\n';

  str_buf << "\"monotone_constraints\":["
          << Common::Join(monotone_constraints_, ",") << "]," << '\n';
Guolin Ke's avatar
Guolin Ke committed
43

44
45
46
47
48
49
50
51
  str_buf << "\"feature_infos\":" << "{";
  bool first_obj = true;
  for (size_t i = 0; i < feature_infos_.size(); ++i) {
    std::stringstream json_str_buf;
    auto strs = Common::Split(feature_infos_[i].c_str(), ":");
    if (strs[0][0] == '[') {
      strs[0].erase(0, 1);  // remove '['
      strs[1].erase(strs[1].size() - 1);  // remove ']'
52
53
54
55
56
57
      double max_, min_;
      Common::Atof(strs[0].c_str(), &min_);
      Common::Atof(strs[1].c_str(), &max_);
      json_str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
      json_str_buf << "{\"min_value\":" << Common::AvoidInf(min_) << ",";
      json_str_buf << "\"max_value\":" << Common::AvoidInf(max_) << ",";
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
      json_str_buf << "\"values\":[]}";
    } else if (strs[0] != "none") {  // categorical feature
      auto vals = Common::StringToArray<int>(feature_infos_[i], ':');
      auto max_idx = ArrayArgs<int>::ArgMax(vals);
      auto min_idx = ArrayArgs<int>::ArgMin(vals);
      json_str_buf << "{\"min_value\":" << vals[min_idx] << ",";
      json_str_buf << "\"max_value\":" << vals[max_idx] << ",";
      json_str_buf << "\"values\":[" << Common::Join(vals, ",") << "]}";
    } else {  // unused feature
      continue;
    }
    if (!first_obj) {
      str_buf << ",";
    }
    str_buf << "\"" << feature_names_[i] << "\":";
    str_buf << json_str_buf.str();
    first_obj = false;
  }
  str_buf << "}," << '\n';

Guolin Ke's avatar
Guolin Ke committed
78
79
  str_buf << "\"tree_info\":[";
  int num_used_model = static_cast<int>(models_.size());
80
81
82
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
83
  if (num_iteration > 0) {
84
85
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_ , num_used_model);
Guolin Ke's avatar
Guolin Ke committed
86
  }
87
88
89
  int start_model = start_iteration * num_tree_per_iteration_;
  for (int i = start_model; i < num_used_model; ++i) {
    if (i > start_model) {
Guolin Ke's avatar
Guolin Ke committed
90
91
92
93
94
95
96
      str_buf << ",";
    }
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
  }
97
98
99
100
101
102
103
104
105
106
107
108
  str_buf << "]," << '\n';

  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  str_buf << '\n' << "\"feature_importances\":" << "{";
109
110
  for (size_t i = 0; i < pairs.size(); ++i) {
    if (i > 0) {
111
112
      str_buf << ",";
    }
113
    str_buf << "\"" << pairs[i].second << "\":" << std::to_string(pairs[i].first);
114
115
  }
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
116

117
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
118
119
120
121
122
123
124

  return str_buf.str();
}

std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

125
126
127
128
129
130
131
132
133
134
135
136
  str_buf << "#include \"gbdt.h\"" << '\n';
  str_buf << "#include <LightGBM/utils/common.h>" << '\n';
  str_buf << "#include <LightGBM/objective_function.h>" << '\n';
  str_buf << "#include <LightGBM/metric.h>" << '\n';
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << '\n';
  str_buf << "#include <ctime>" << '\n';
  str_buf << "#include <sstream>" << '\n';
  str_buf << "#include <chrono>" << '\n';
  str_buf << "#include <string>" << '\n';
  str_buf << "#include <vector>" << '\n';
  str_buf << "#include <utility>" << '\n';
  str_buf << "namespace LightGBM {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
142
143
144

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
145
    str_buf << models_[i]->ToIfElse(i, false) << '\n';
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
153
154
  }

  str_buf << "double (*PredictTreePtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i;
  }
155
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
156
157
158

  std::stringstream pred_str_buf;

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
174
  str_buf << pred_str_buf.str();
175
176
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
177

178
179
180
  // PredictRawByMap
  str_buf << "double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
181
182
183
184
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "ByMap";
185
  }
186
  str_buf << " };" << '\n' << '\n';
187
188
189

  std::stringstream pred_str_buf_map;

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
  pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf_map << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
205
  str_buf << pred_str_buf_map.str();
206
207
  str_buf << "}" << '\n';
  str_buf << '\n';
208

Guolin Ke's avatar
Guolin Ke committed
209
  // Predict
210
211
212
213
214
215
216
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
217
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
218
219
220
221
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
222

223
  // PredictByMap
224
225
226
227
228
229
230
  str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
231
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
232
233
234
235
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
236
237


Guolin Ke's avatar
Guolin Ke committed
238
239
  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
240
    str_buf << models_[i]->ToIfElse(i, true) << '\n';
Guolin Ke's avatar
Guolin Ke committed
241
242
243
244
245
246
247
248
249
  }

  str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "Leaf";
  }
250
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
251

252
253
254
255
256
257
  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
258

259
  // PredictLeafIndexByMap
260
261
  str_buf << "double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
262
263
264
265
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "LeafByMap";
266
  }
267
  str_buf << " };" << '\n' << '\n';
268

269
270
271
272
273
274
  str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
275

276
  str_buf << "}  // namespace LightGBM" << '\n';
Guolin Ke's avatar
Guolin Ke committed
277
278
279
280
281
282
283
284
285
286
287
288

  return str_buf.str();
}

/*!
 * \brief Write the if-else source produced by ModelToIfElse to a file.
 *
 * If the file can already be read, its current content is kept inside the "#ifndef
 * USE_HARD_CODE" branch and the generated code goes into the "#else" branch. Because the
 * first emitted line defines USE_HARD_CODE, the "#else" branch is the one the preprocessor
 * selects. Otherwise the file just receives the generated code.
 *
 * \param num_iteration Forwarded to ModelToIfElse
 * \param filename Path of the file to (re)write
 * \return true if the output stream is in a good state after closing
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  std::ifstream existing(filename);
  const bool wrap_original = existing.good();

  // Read the previous content before the file is truncated by opening it for output.
  std::string previous_content;
  if (wrap_original) {
    previous_content.assign(std::istreambuf_iterator<char>(existing),
                            std::istreambuf_iterator<char>());
  }

  /*! \brief File to write models */
  std::ofstream sink(filename);
  if (wrap_original) {
    sink << "#define USE_HARD_CODE 0" << '\n';
    sink << "#ifndef USE_HARD_CODE" << '\n';
    sink << previous_content << '\n';
    sink << "#else" << '\n';
  }
  sink << ModelToIfElse(num_iteration);
  if (wrap_original) {
    sink << "#endif" << '\n';
  }

  existing.close();
  sink.close();

  return !sink.fail();
}

306
std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
307
308
309
  std::stringstream ss;

  // output model type
310
311
  ss << SubModelName() << '\n';
  ss << "version=" << kModelVersion << '\n';
Guolin Ke's avatar
Guolin Ke committed
312
  // output number of class
313
314
  ss << "num_class=" << num_class_ << '\n';
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
315
  // output label index
316
  ss << "label_index=" << label_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
317
  // output max_feature_idx
318
  ss << "max_feature_idx=" << max_feature_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
319
320
  // output objective
  if (objective_function_ != nullptr) {
321
    ss << "objective=" << objective_function_->ToString() << '\n';
Guolin Ke's avatar
Guolin Ke committed
322
323
324
  }

  if (average_output_) {
325
    ss << "average_output" << '\n';
Guolin Ke's avatar
Guolin Ke committed
326
327
  }

328
  ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
329

330
331
332
333
334
  if (monotone_constraints_.size() != 0) {
    ss << "monotone_constraints=" << Common::Join(monotone_constraints_, " ")
       << '\n';
  }

335
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
336
337

  int num_used_model = static_cast<int>(models_.size());
338
339
340
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
341
  if (num_iteration > 0) {
342
343
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_, num_used_model);
Guolin Ke's avatar
Guolin Ke committed
344
  }
345

346
347
348
349
  int start_model = start_iteration * num_tree_per_iteration_;

  std::vector<std::string> tree_strs(num_used_model - start_model);
  std::vector<size_t> tree_sizes(num_used_model - start_model);
Guolin Ke's avatar
Guolin Ke committed
350
  // output tree models
351
  #pragma omp parallel for schedule(static)
352
353
354
355
356
  for (int i = start_model; i < num_used_model; ++i) {
    const int idx = i - start_model;
    tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
    tree_strs[idx] += models_[i]->ToString() + '\n';
    tree_sizes[idx] = tree_strs[idx].size();
357
358
359
360
361
  }

  ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
  ss << '\n';

362
  for (int i = 0; i < num_used_model - start_model; ++i) {
363
364
    ss << tree_strs[i];
    tree_strs[i].clear();
Guolin Ke's avatar
Guolin Ke committed
365
  }
Guolin Ke's avatar
Guolin Ke committed
366
  ss << "end of trees" << "\n";
Guolin Ke's avatar
Guolin Ke committed
367

368
  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
Guolin Ke's avatar
Guolin Ke committed
369
370
371
372
373
374
375
376
377
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  // sort the importance
378
379
380
  std::stable_sort(pairs.begin(), pairs.end(),
                   [](const std::pair<size_t, std::string>& lhs,
                      const std::pair<size_t, std::string>& rhs) {
Guolin Ke's avatar
Guolin Ke committed
381
382
    return lhs.first > rhs.first;
  });
383
  ss << '\n' << "feature_importances:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
384
  for (size_t i = 0; i < pairs.size(); ++i) {
385
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
Guolin Ke's avatar
Guolin Ke committed
386
  }
Guolin Ke's avatar
Guolin Ke committed
387
  if (config_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
388
    ss << "\nparameters:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
389
    ss << config_->ToString() << "\n";
Guolin Ke's avatar
Guolin Ke committed
390
391
392
393
394
    ss << "end of parameters" << '\n';
  } else if (!loaded_parameter_.empty()) {
    ss << "\nparameters:" << '\n';
    ss << loaded_parameter_ << "\n";
    ss << "end of parameters" << '\n';
Guolin Ke's avatar
Guolin Ke committed
395
  }
Guolin Ke's avatar
Guolin Ke committed
396
397
398
  return ss.str();
}

399
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
Guolin Ke's avatar
Guolin Ke committed
400
401
  /*! \brief File to write models */
  std::ofstream output_file;
402
  output_file.open(filename, std::ios::out | std::ios::binary);
403
  std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
404
  output_file.write(str_to_write.c_str(), str_to_write.size());
Guolin Ke's avatar
Guolin Ke committed
405
406
  output_file.close();

407
  return static_cast<bool>(output_file);
Guolin Ke's avatar
Guolin Ke committed
408
409
}

410
bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
Guolin Ke's avatar
Guolin Ke committed
411
412
  // use serialized string to restore this object
  models_.clear();
413
414
415
416
417
418
419
  auto c_str = buffer;
  auto p = c_str;
  auto end = p + len;
  std::unordered_map<std::string, std::string> key_vals;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
420
      std::string cur_line(p, line_len);
421
422
423
424
      if (!Common::StartsWith(cur_line, "Tree=")) {
        auto strs = Common::Split(cur_line.c_str(), '=');
        if (strs.size() == 1) {
          key_vals[strs[0]] = "";
425
        } else if (strs.size() == 2) {
426
          key_vals[strs[0]] = strs[1];
427
        } else if (strs.size() > 2) {
Guolin Ke's avatar
Guolin Ke committed
428
          if (strs[0] == "feature_names") {
Guolin Ke's avatar
Guolin Ke committed
429
            key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
430
431
          } else if (strs[0] == "monotone_constraints") {
            key_vals[strs[0]] = cur_line.substr(std::strlen("monotone_constraints="));
432
          } else {
Guolin Ke's avatar
Guolin Ke committed
433
434
            // Use first 128 chars to avoid exceed the message buffer.
            Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
435
          }
436
        }
437
      } else {
438
439
440
441
442
443
        break;
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
Guolin Ke's avatar
Guolin Ke committed
444
445

  // get number of classes
446
447
  if (key_vals.count("num_class")) {
    Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
Guolin Ke's avatar
Guolin Ke committed
448
449
450
451
452
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }

453
454
  if (key_vals.count("num_tree_per_iteration")) {
    Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
Guolin Ke's avatar
Guolin Ke committed
455
456
457
458
459
  } else {
    num_tree_per_iteration_ = num_class_;
  }

  // get index of label
460
461
  if (key_vals.count("label_index")) {
    Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
Guolin Ke's avatar
Guolin Ke committed
462
463
464
465
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
466

Guolin Ke's avatar
Guolin Ke committed
467
  // get max_feature_idx first
468
469
  if (key_vals.count("max_feature_idx")) {
    Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
Guolin Ke's avatar
Guolin Ke committed
470
471
472
473
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
474

Guolin Ke's avatar
Guolin Ke committed
475
  // get average_output
476
  if (key_vals.count("average_output")) {
Guolin Ke's avatar
Guolin Ke committed
477
478
    average_output_ = true;
  }
479

Guolin Ke's avatar
Guolin Ke committed
480
  // get feature names
481
482
  if (key_vals.count("feature_names")) {
    feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
483
484
485
486
487
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
488
    Log::Fatal("Model file doesn't contain feature_names");
Guolin Ke's avatar
Guolin Ke committed
489
490
491
    return false;
  }

492
493
  // get monotone_constraints
  if (key_vals.count("monotone_constraints")) {
Guolin Ke's avatar
Guolin Ke committed
494
    monotone_constraints_ = Common::StringToArray<int8_t>(key_vals["monotone_constraints"].c_str(), ' ');
495
496
497
498
499
500
    if (monotone_constraints_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of monotone_constraints");
      return false;
    }
  }

501
502
  if (key_vals.count("feature_infos")) {
    feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
503
504
505
506
507
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
508
    Log::Fatal("Model file doesn't contain feature_infos");
Guolin Ke's avatar
Guolin Ke committed
509
510
511
    return false;
  }

512
513
  if (key_vals.count("objective")) {
    auto str = key_vals["objective"];
514
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(ParseObjectiveAlias(str)));
Guolin Ke's avatar
Guolin Ke committed
515
516
    objective_function_ = loaded_objective_.get();
  }
517

518
519
520
521
  if (!key_vals.count("tree_sizes")) {
    while (p < end) {
      auto line_len = Common::GetLine(p);
      if (line_len > 0) {
522
        std::string cur_line(p, line_len);
523
524
525
526
527
528
        if (Common::StartsWith(cur_line, "Tree=")) {
          p += line_len;
          p = Common::SkipNewLine(p);
          size_t used_len = 0;
          models_.emplace_back(new Tree(p, &used_len));
          p += used_len;
529
        } else {
530
531
532
533
534
535
536
537
538
539
540
541
          break;
        }
      }
      p = Common::SkipNewLine(p);
    }
  } else {
    std::vector<size_t> tree_sizes = Common::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
    std::vector<size_t> tree_boundries(tree_sizes.size() + 1, 0);
    int num_trees = static_cast<int>(tree_sizes.size());
    for (int i = 0; i < num_trees; ++i) {
      tree_boundries[i + 1] = tree_boundries[i] + tree_sizes[i];
      models_.emplace_back(nullptr);
Guolin Ke's avatar
Guolin Ke committed
542
    }
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_trees; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto cur_p = p + tree_boundries[i];
      auto line_len = Common::GetLine(cur_p);
      std::string cur_line(cur_p, line_len);
      if (Common::StartsWith(cur_line, "Tree=")) {
        cur_p += line_len;
        cur_p = Common::SkipNewLine(cur_p);
        size_t used_len = 0;
        models_[i].reset(new Tree(cur_p, &used_len));
      } else {
        Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
561
562
563
564
  }
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
Guolin Ke's avatar
Guolin Ke committed
565
566
567
568
569
  bool is_inparameter = false;
  std::stringstream ss;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
570
      std::string cur_line(p, line_len);
Guolin Ke's avatar
Guolin Ke committed
571
572
573
574
575
576
577
578
579
580
581
582
583
584
      if (cur_line == std::string("parameters:")) {
        is_inparameter = true;
      } else if (cur_line == std::string("end of parameters")) {
        break;
      } else if (is_inparameter) {
        ss << cur_line << "\n";
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  if (!ss.str().empty()) {
    loaded_parameter_ = ss.str();
  }
Guolin Ke's avatar
Guolin Ke committed
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
  return true;
}

std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
  if (importance_type == 0) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
600
#ifdef DEBUG
601
          CHECK_GE(models_[iter]->split_feature(split_idx), 0);
602
#endif
Guolin Ke's avatar
Guolin Ke committed
603
604
605
606
607
608
609
610
          feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
        }
      }
    }
  } else if (importance_type == 1) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
611
#ifdef DEBUG
612
          CHECK_GE(models_[iter]->split_feature(split_idx), 0);
613
#endif
Guolin Ke's avatar
Guolin Ke committed
614
615
616
617
618
          feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
        }
      }
    }
  } else {
619
    Log::Fatal("Unknown importance type: only support split=0 and gain=1");
Guolin Ke's avatar
Guolin Ke committed
620
621
622
623
624
  }
  return feature_importances;
}

}  // namespace LightGBM