gbdt_model_text.cpp 19.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/metric.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

#include "gbdt.h"

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
/*! \brief Version tag written into the "version" field of dumped (JSON) and saved (text) models */
const std::string kModelVersion = "v3";
18

19
std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
20
21
22
  std::stringstream str_buf;

  str_buf << "{";
23
24
25
26
27
28
  str_buf << "\"name\":\"" << SubModelName() << "\"," << '\n';
  str_buf << "\"version\":\"" << kModelVersion << "\"," << '\n';
  str_buf << "\"num_class\":" << num_class_ << "," << '\n';
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n';
  str_buf << "\"label_index\":" << label_idx_ << "," << '\n';
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n';
29
30
31
32
  str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n";
  if (objective_function_ != nullptr) {
    str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n";
  }
Guolin Ke's avatar
Guolin Ke committed
33
34
35

  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
36
    << '\n';
Guolin Ke's avatar
Guolin Ke committed
37
38
39

  str_buf << "\"tree_info\":[";
  int num_used_model = static_cast<int>(models_.size());
40
41
42
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
43
  if (num_iteration > 0) {
44
45
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_ , num_used_model);
Guolin Ke's avatar
Guolin Ke committed
46
  }
47
48
49
  int start_model = start_iteration * num_tree_per_iteration_;
  for (int i = start_model; i < num_used_model; ++i) {
    if (i > start_model) {
Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
55
56
      str_buf << ",";
    }
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
  }
57
  str_buf << "]" << '\n';
Guolin Ke's avatar
Guolin Ke committed
58

59
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
60
61
62
63
64
65
66

  return str_buf.str();
}

std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

67
68
69
70
71
72
73
74
75
76
77
78
  str_buf << "#include \"gbdt.h\"" << '\n';
  str_buf << "#include <LightGBM/utils/common.h>" << '\n';
  str_buf << "#include <LightGBM/objective_function.h>" << '\n';
  str_buf << "#include <LightGBM/metric.h>" << '\n';
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << '\n';
  str_buf << "#include <ctime>" << '\n';
  str_buf << "#include <sstream>" << '\n';
  str_buf << "#include <chrono>" << '\n';
  str_buf << "#include <string>" << '\n';
  str_buf << "#include <vector>" << '\n';
  str_buf << "#include <utility>" << '\n';
  str_buf << "namespace LightGBM {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
79
80
81
82
83
84
85
86

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
87
    str_buf << models_[i]->ToIfElse(i, false) << '\n';
Guolin Ke's avatar
Guolin Ke committed
88
89
90
91
92
93
94
95
96
  }

  str_buf << "double (*PredictTreePtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i;
  }
97
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
98
99
100

  std::stringstream pred_str_buf;

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
116
  str_buf << pred_str_buf.str();
117
118
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
119

120
121
122
  // PredictRawByMap
  str_buf << "double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
123
124
125
126
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "ByMap";
127
  }
128
  str_buf << " };" << '\n' << '\n';
129
130
131

  std::stringstream pred_str_buf_map;

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf_map << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
147
  str_buf << pred_str_buf_map.str();
148
149
  str_buf << "}" << '\n';
  str_buf << '\n';
150

Guolin Ke's avatar
Guolin Ke committed
151
  // Predict
152
153
154
155
156
157
158
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
159
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
160
161
162
163
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
164

165
  // PredictByMap
166
167
168
169
170
171
172
  str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
173
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
174
175
176
177
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
178
179


Guolin Ke's avatar
Guolin Ke committed
180
181
  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
182
    str_buf << models_[i]->ToIfElse(i, true) << '\n';
Guolin Ke's avatar
Guolin Ke committed
183
184
185
186
187
188
189
190
191
  }

  str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "Leaf";
  }
192
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
193

194
195
196
197
198
199
  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
200

201
  // PredictLeafIndexByMap
202
203
  str_buf << "double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
204
205
206
207
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "LeafByMap";
208
  }
209
  str_buf << " };" << '\n' << '\n';
210

211
212
213
214
215
216
  str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
217

218
  str_buf << "}  // namespace LightGBM" << '\n';
Guolin Ke's avatar
Guolin Ke committed
219
220
221
222
223
224
225
226
227
228
229
230

  return str_buf.str();
}

/*!
 * \brief Write the if-else (hard-coded) form of the model to a source file.
 *
 * If the file already exists and is readable, its previous content is kept inside an
 * "#ifndef USE_HARD_CODE" section and the generated code goes into the "#else" section;
 * since USE_HARD_CODE is defined on the first line, the generated code is the active branch.
 * Otherwise the file only receives the generated code.
 *
 * \param num_iteration Number of iterations to convert, forwarded to ModelToIfElse
 * \param filename Path of the file to (re)write
 * \return true if the output stream is still in a good state after closing
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  // Read the existing content first and release the read handle before the same
  // path is reopened (and truncated) for writing.
  bool has_origin = false;
  std::string origin;
  {
    std::ifstream ifs(filename);
    if (ifs.good()) {
      has_origin = true;
      origin.assign(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
    }
  }

  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
  if (has_origin) {
    output_file << "#define USE_HARD_CODE 0" << '\n';
    output_file << "#ifndef USE_HARD_CODE" << '\n';
    output_file << origin << '\n';
    output_file << "#else" << '\n';
  }
  output_file << ModelToIfElse(num_iteration);
  if (has_origin) {
    output_file << "#endif" << '\n';
  }

  output_file.close();

  return static_cast<bool>(output_file);
}

248
std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
249
250
251
  std::stringstream ss;

  // output model type
252
253
  ss << SubModelName() << '\n';
  ss << "version=" << kModelVersion << '\n';
Guolin Ke's avatar
Guolin Ke committed
254
  // output number of class
255
256
  ss << "num_class=" << num_class_ << '\n';
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
257
  // output label index
258
  ss << "label_index=" << label_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
259
  // output max_feature_idx
260
  ss << "max_feature_idx=" << max_feature_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
261
262
  // output objective
  if (objective_function_ != nullptr) {
263
    ss << "objective=" << objective_function_->ToString() << '\n';
Guolin Ke's avatar
Guolin Ke committed
264
265
266
  }

  if (average_output_) {
267
    ss << "average_output" << '\n';
Guolin Ke's avatar
Guolin Ke committed
268
269
  }

270
  ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
271

272
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
273
274

  int num_used_model = static_cast<int>(models_.size());
275
276
277
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
278
  if (num_iteration > 0) {
279
280
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_, num_used_model);
Guolin Ke's avatar
Guolin Ke committed
281
  }
282

283
284
285
286
  int start_model = start_iteration * num_tree_per_iteration_;

  std::vector<std::string> tree_strs(num_used_model - start_model);
  std::vector<size_t> tree_sizes(num_used_model - start_model);
Guolin Ke's avatar
Guolin Ke committed
287
  // output tree models
288
  #pragma omp parallel for schedule(static)
289
290
291
292
293
  for (int i = start_model; i < num_used_model; ++i) {
    const int idx = i - start_model;
    tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
    tree_strs[idx] += models_[i]->ToString() + '\n';
    tree_sizes[idx] = tree_strs[idx].size();
294
295
296
297
298
  }

  ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
  ss << '\n';

299
  for (int i = 0; i < num_used_model - start_model; ++i) {
300
301
    ss << tree_strs[i];
    tree_strs[i].clear();
Guolin Ke's avatar
Guolin Ke committed
302
  }
Guolin Ke's avatar
Guolin Ke committed
303
  ss << "end of trees" << "\n";
Guolin Ke's avatar
Guolin Ke committed
304

305
  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309
310
311
312
313
314
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  // sort the importance
315
316
317
  std::stable_sort(pairs.begin(), pairs.end(),
                   [](const std::pair<size_t, std::string>& lhs,
                      const std::pair<size_t, std::string>& rhs) {
Guolin Ke's avatar
Guolin Ke committed
318
319
    return lhs.first > rhs.first;
  });
320
  ss << '\n' << "feature importances:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
321
  for (size_t i = 0; i < pairs.size(); ++i) {
322
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
Guolin Ke's avatar
Guolin Ke committed
323
  }
Guolin Ke's avatar
Guolin Ke committed
324
  if (config_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
325
    ss << "\nparameters:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
326
    ss << config_->ToString() << "\n";
Guolin Ke's avatar
Guolin Ke committed
327
328
329
330
331
    ss << "end of parameters" << '\n';
  } else if (!loaded_parameter_.empty()) {
    ss << "\nparameters:" << '\n';
    ss << loaded_parameter_ << "\n";
    ss << "end of parameters" << '\n';
Guolin Ke's avatar
Guolin Ke committed
332
  }
Guolin Ke's avatar
Guolin Ke committed
333
334
335
  return ss.str();
}

336
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
Guolin Ke's avatar
Guolin Ke committed
337
338
  /*! \brief File to write models */
  std::ofstream output_file;
339
  output_file.open(filename, std::ios::out | std::ios::binary);
340
  std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
341
  output_file.write(str_to_write.c_str(), str_to_write.size());
Guolin Ke's avatar
Guolin Ke committed
342
343
  output_file.close();

344
  return static_cast<bool>(output_file);
Guolin Ke's avatar
Guolin Ke committed
345
346
}

347
bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
Guolin Ke's avatar
Guolin Ke committed
348
349
  // use serialized string to restore this object
  models_.clear();
350
351
352
353
354
355
356
  auto c_str = buffer;
  auto p = c_str;
  auto end = p + len;
  std::unordered_map<std::string, std::string> key_vals;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
357
      std::string cur_line(p, line_len);
358
359
360
361
      if (!Common::StartsWith(cur_line, "Tree=")) {
        auto strs = Common::Split(cur_line.c_str(), '=');
        if (strs.size() == 1) {
          key_vals[strs[0]] = "";
362
        } else if (strs.size() == 2) {
363
          key_vals[strs[0]] = strs[1];
364
        } else if (strs.size() > 2) {
Guolin Ke's avatar
Guolin Ke committed
365
          if (strs[0] == "feature_names") {
Guolin Ke's avatar
Guolin Ke committed
366
            key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
367
          } else {
Guolin Ke's avatar
Guolin Ke committed
368
369
            // Use first 128 chars to avoid exceed the message buffer.
            Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
370
          }
371
        }
372
      } else {
373
374
375
376
377
378
        break;
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
Guolin Ke's avatar
Guolin Ke committed
379
380

  // get number of classes
381
382
  if (key_vals.count("num_class")) {
    Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
Guolin Ke's avatar
Guolin Ke committed
383
384
385
386
387
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }

388
389
  if (key_vals.count("num_tree_per_iteration")) {
    Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
Guolin Ke's avatar
Guolin Ke committed
390
391
392
393
394
  } else {
    num_tree_per_iteration_ = num_class_;
  }

  // get index of label
395
396
  if (key_vals.count("label_index")) {
    Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
Guolin Ke's avatar
Guolin Ke committed
397
398
399
400
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
401

Guolin Ke's avatar
Guolin Ke committed
402
  // get max_feature_idx first
403
404
  if (key_vals.count("max_feature_idx")) {
    Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
Guolin Ke's avatar
Guolin Ke committed
405
406
407
408
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
409

Guolin Ke's avatar
Guolin Ke committed
410
  // get average_output
411
  if (key_vals.count("average_output")) {
Guolin Ke's avatar
Guolin Ke committed
412
413
    average_output_ = true;
  }
414

Guolin Ke's avatar
Guolin Ke committed
415
  // get feature names
416
417
  if (key_vals.count("feature_names")) {
    feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
418
419
420
421
422
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
423
    Log::Fatal("Model file doesn't contain feature_names");
Guolin Ke's avatar
Guolin Ke committed
424
425
426
    return false;
  }

427
428
  if (key_vals.count("feature_infos")) {
    feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
429
430
431
432
433
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
434
    Log::Fatal("Model file doesn't contain feature_infos");
Guolin Ke's avatar
Guolin Ke committed
435
436
437
    return false;
  }

438
439
  if (key_vals.count("objective")) {
    auto str = key_vals["objective"];
Guolin Ke's avatar
Guolin Ke committed
440
441
442
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }
443
444
445
446
  if (!key_vals.count("tree_sizes")) {
    while (p < end) {
      auto line_len = Common::GetLine(p);
      if (line_len > 0) {
447
        std::string cur_line(p, line_len);
448
449
450
451
452
453
        if (Common::StartsWith(cur_line, "Tree=")) {
          p += line_len;
          p = Common::SkipNewLine(p);
          size_t used_len = 0;
          models_.emplace_back(new Tree(p, &used_len));
          p += used_len;
454
        } else {
455
456
457
458
459
460
461
462
463
464
465
466
          break;
        }
      }
      p = Common::SkipNewLine(p);
    }
  } else {
    std::vector<size_t> tree_sizes = Common::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
    std::vector<size_t> tree_boundries(tree_sizes.size() + 1, 0);
    int num_trees = static_cast<int>(tree_sizes.size());
    for (int i = 0; i < num_trees; ++i) {
      tree_boundries[i + 1] = tree_boundries[i] + tree_sizes[i];
      models_.emplace_back(nullptr);
Guolin Ke's avatar
Guolin Ke committed
467
    }
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_trees; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto cur_p = p + tree_boundries[i];
      auto line_len = Common::GetLine(cur_p);
      std::string cur_line(cur_p, line_len);
      if (Common::StartsWith(cur_line, "Tree=")) {
        cur_p += line_len;
        cur_p = Common::SkipNewLine(cur_p);
        size_t used_len = 0;
        models_[i].reset(new Tree(cur_p, &used_len));
      } else {
        Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
486
487
488
489
  }
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
Guolin Ke's avatar
Guolin Ke committed
490
491
492
493
494
  bool is_inparameter = false;
  std::stringstream ss;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
495
      std::string cur_line(p, line_len);
Guolin Ke's avatar
Guolin Ke committed
496
497
498
499
500
501
502
503
504
505
506
507
508
509
      if (cur_line == std::string("parameters:")) {
        is_inparameter = true;
      } else if (cur_line == std::string("end of parameters")) {
        break;
      } else if (is_inparameter) {
        ss << cur_line << "\n";
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  if (!ss.str().empty()) {
    loaded_parameter_ = ss.str();
  }
Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
  return true;
}

std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
  if (importance_type == 0) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
          feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
        }
      }
    }
  } else if (importance_type == 1) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
          feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
        }
      }
    }
  } else {
538
    Log::Fatal("Unknown importance type: only support split=0 and gain=1");
Guolin Ke's avatar
Guolin Ke committed
539
540
541
542
543
  }
  return feature_importances;
}

}  // namespace LightGBM