/*!
 * gbdt_model_text.cpp
 * Text, JSON and if-else serialization of GBDT models, plus loading from text.
 */
#include "gbdt.h"

#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

namespace LightGBM {

/*! \brief Model format version tag; written as "version" by DumpModel and as "version=" by SaveModelToString */
const std::string kModelVersion = "v2";

std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
16
17
18
  std::stringstream str_buf;

  str_buf << "{";
19
20
21
22
23
24
  str_buf << "\"name\":\"" << SubModelName() << "\"," << '\n';
  str_buf << "\"version\":\"" << kModelVersion << "\"," << '\n';
  str_buf << "\"num_class\":" << num_class_ << "," << '\n';
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n';
  str_buf << "\"label_index\":" << label_idx_ << "," << '\n';
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n';
Guolin Ke's avatar
Guolin Ke committed
25
26
27

  str_buf << "\"feature_names\":[\""
    << Common::Join(feature_names_, "\",\"") << "\"],"
28
    << '\n';
Guolin Ke's avatar
Guolin Ke committed
29
30
31

  str_buf << "\"tree_info\":[";
  int num_used_model = static_cast<int>(models_.size());
32
33
34
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
35
  if (num_iteration > 0) {
36
37
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_ , num_used_model);
Guolin Ke's avatar
Guolin Ke committed
38
  }
39
40
41
  int start_model = start_iteration * num_tree_per_iteration_;
  for (int i = start_model; i < num_used_model; ++i) {
    if (i > start_model) {
Guolin Ke's avatar
Guolin Ke committed
42
43
44
45
46
47
48
      str_buf << ",";
    }
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
  }
49
  str_buf << "]" << '\n';
Guolin Ke's avatar
Guolin Ke committed
50

51
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
57
58

  return str_buf.str();
}

std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

59
60
61
62
63
64
65
66
67
68
69
70
  str_buf << "#include \"gbdt.h\"" << '\n';
  str_buf << "#include <LightGBM/utils/common.h>" << '\n';
  str_buf << "#include <LightGBM/objective_function.h>" << '\n';
  str_buf << "#include <LightGBM/metric.h>" << '\n';
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << '\n';
  str_buf << "#include <ctime>" << '\n';
  str_buf << "#include <sstream>" << '\n';
  str_buf << "#include <chrono>" << '\n';
  str_buf << "#include <string>" << '\n';
  str_buf << "#include <vector>" << '\n';
  str_buf << "#include <utility>" << '\n';
  str_buf << "namespace LightGBM {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
75
76
77
78

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
79
    str_buf << models_[i]->ToIfElse(i, false) << '\n';
Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
85
86
87
88
  }

  str_buf << "double (*PredictTreePtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i;
  }
89
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
90
91
92

  std::stringstream pred_str_buf;

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
108
  str_buf << pred_str_buf.str();
109
110
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
111

112
113
114
  // PredictRawByMap
  str_buf << "double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "ByMap";
119
  }
120
  str_buf << " };" << '\n' << '\n';
121
122
123

  std::stringstream pred_str_buf_map;

124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
  pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf_map << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
139
  str_buf << pred_str_buf_map.str();
140
141
  str_buf << "}" << '\n';
  str_buf << '\n';
142

Guolin Ke's avatar
Guolin Ke committed
143
  // Predict
144
145
146
147
148
149
150
151
152
153
154
155
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "\t" << "else if (objective_function_ != nullptr) {" << '\n';
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
156

157
  // PredictByMap
158
159
160
161
162
163
164
165
166
167
168
169
  str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "\t" << "else if (objective_function_ != nullptr) {" << '\n';
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
170
171


Guolin Ke's avatar
Guolin Ke committed
172
173
  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
174
    str_buf << models_[i]->ToIfElse(i, true) << '\n';
Guolin Ke's avatar
Guolin Ke committed
175
176
177
178
179
180
181
182
183
  }

  str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "Leaf";
  }
184
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
185

186
187
188
189
190
191
  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
192

193
194
195
  //PredictLeafIndexByMap
  str_buf << "double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
196
197
198
199
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "LeafByMap";
200
  }
201
  str_buf << " };" << '\n' << '\n';
202

203
204
205
206
207
208
  str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
209

210
  str_buf << "}  // namespace LightGBM" << '\n';
Guolin Ke's avatar
Guolin Ke committed
211
212
213
214
215
216
217
218
219
220
221
222

  return str_buf.str();
}

/*!
 * \brief Write the if-else source generated by ModelToIfElse into a file.
 *
 * If the file can already be opened for reading, its current content is kept
 * inside an "#ifndef USE_HARD_CODE" branch and the generated code goes into the
 * "#else" branch. Because "#define USE_HARD_CODE 0" is written first, the
 * preprocessor selects the generated branch. Otherwise the file receives only
 * the generated code.
 *
 * \param num_iteration Forwarded to ModelToIfElse.
 * \param filename Path of the file to (re)write.
 * \return State of the output stream after it has been closed.
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  std::string origin;
  bool has_origin = false;
  {
    // Read the existing content inside its own scope so the reader is closed
    // before the very same path is truncated and rewritten below.
    std::ifstream ifs(filename);
    if (ifs.good()) {
      origin.assign(std::istreambuf_iterator<char>(ifs),
                    std::istreambuf_iterator<char>());
      has_origin = true;
    }
  }

  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
  if (has_origin) {
    output_file << "#define USE_HARD_CODE 0" << '\n';
    output_file << "#ifndef USE_HARD_CODE" << '\n';
    output_file << origin << '\n';
    output_file << "#else" << '\n';
    output_file << ModelToIfElse(num_iteration);
    output_file << "#endif" << '\n';
  } else {
    output_file << ModelToIfElse(num_iteration);
  }

  output_file.close();

  return static_cast<bool>(output_file);
}

std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
241
242
243
  std::stringstream ss;

  // output model type
244
245
  ss << SubModelName() << '\n';
  ss << "version=" << kModelVersion << '\n';
Guolin Ke's avatar
Guolin Ke committed
246
  // output number of class
247
248
  ss << "num_class=" << num_class_ << '\n';
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
249
  // output label index
250
  ss << "label_index=" << label_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
251
  // output max_feature_idx
252
  ss << "max_feature_idx=" << max_feature_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
253
254
  // output objective
  if (objective_function_ != nullptr) {
255
    ss << "objective=" << objective_function_->ToString() << '\n';
Guolin Ke's avatar
Guolin Ke committed
256
257
258
  }

  if (average_output_) {
259
    ss << "average_output" << '\n';
Guolin Ke's avatar
Guolin Ke committed
260
261
  }

262
  ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
263

264
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
265
266

  int num_used_model = static_cast<int>(models_.size());
267
268
269
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
270
  if (num_iteration > 0) {
271
272
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_, num_used_model);
Guolin Ke's avatar
Guolin Ke committed
273
  }
274

275
276
277
278
  int start_model = start_iteration * num_tree_per_iteration_;

  std::vector<std::string> tree_strs(num_used_model - start_model);
  std::vector<size_t> tree_sizes(num_used_model - start_model);
Guolin Ke's avatar
Guolin Ke committed
279
  // output tree models
280
  #pragma omp parallel for schedule(static)
281
282
283
284
285
  for (int i = start_model; i < num_used_model; ++i) {
    const int idx = i - start_model;
    tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
    tree_strs[idx] += models_[i]->ToString() + '\n';
    tree_sizes[idx] = tree_strs[idx].size();
286
287
288
289
290
  }

  ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
  ss << '\n';

291
  for (int i = 0; i < num_used_model - start_model; ++i) {
292
293
    ss << tree_strs[i];
    tree_strs[i].clear();
Guolin Ke's avatar
Guolin Ke committed
294
  }
Guolin Ke's avatar
Guolin Ke committed
295
  ss << "end of trees" << "\n";
Guolin Ke's avatar
Guolin Ke committed
296

297
  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
Guolin Ke's avatar
Guolin Ke committed
298
299
300
301
302
303
304
305
306
307
308
309
310
311
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  // sort the importance
  std::sort(pairs.begin(), pairs.end(),
            [](const std::pair<size_t, std::string>& lhs,
               const std::pair<size_t, std::string>& rhs) {
    return lhs.first > rhs.first;
  });
312
  ss << '\n' << "feature importances:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
313
  for (size_t i = 0; i < pairs.size(); ++i) {
314
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
Guolin Ke's avatar
Guolin Ke committed
315
  }
Guolin Ke's avatar
Guolin Ke committed
316
  if (config_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
317
    ss << "\nparameters:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
318
    ss << config_->ToString() << "\n";
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
323
    ss << "end of parameters" << '\n';
  } else if (!loaded_parameter_.empty()) {
    ss << "\nparameters:" << '\n';
    ss << loaded_parameter_ << "\n";
    ss << "end of parameters" << '\n';
Guolin Ke's avatar
Guolin Ke committed
324
  }
Guolin Ke's avatar
Guolin Ke committed
325
326
327
  return ss.str();
}

bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
Guolin Ke's avatar
Guolin Ke committed
329
330
  /*! \brief File to write models */
  std::ofstream output_file;
331
  output_file.open(filename, std::ios::out | std::ios::binary);
332
  std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
333
  output_file.write(str_to_write.c_str(), str_to_write.size());
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
338
  output_file.close();

  return (bool)output_file;
}

bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
Guolin Ke's avatar
Guolin Ke committed
340
341
  // use serialized string to restore this object
  models_.clear();
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
  auto c_str = buffer;
  auto p = c_str;
  auto end = p + len;
  std::unordered_map<std::string, std::string> key_vals;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    std::string cur_line(p, line_len);
    if (line_len > 0) {
      if (!Common::StartsWith(cur_line, "Tree=")) {
        auto strs = Common::Split(cur_line.c_str(), '=');
        if (strs.size() == 1) {
          key_vals[strs[0]] = "";
        }
        else if (strs.size() == 2) {
          key_vals[strs[0]] = strs[1];
        }
        else if (strs.size() > 2) {
Guolin Ke's avatar
Guolin Ke committed
359
          if (strs[0] == "feature_names") {
Guolin Ke's avatar
Guolin Ke committed
360
            key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
361
          } else {
Guolin Ke's avatar
Guolin Ke committed
362
363
            // Use first 128 chars to avoid exceed the message buffer.
            Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
364
          }
365
366
367
368
369
370
371
372
373
        }
      }
      else {
        break;
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
Guolin Ke's avatar
Guolin Ke committed
374
375

  // get number of classes
376
377
  if (key_vals.count("num_class")) {
    Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
Guolin Ke's avatar
Guolin Ke committed
378
379
380
381
382
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }

383
384
  if (key_vals.count("num_tree_per_iteration")) {
    Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
389
  } else {
    num_tree_per_iteration_ = num_class_;
  }

  // get index of label
390
391
  if (key_vals.count("label_index")) {
    Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
Guolin Ke's avatar
Guolin Ke committed
392
393
394
395
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
396

Guolin Ke's avatar
Guolin Ke committed
397
  // get max_feature_idx first
398
399
  if (key_vals.count("max_feature_idx")) {
    Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
404

Guolin Ke's avatar
Guolin Ke committed
405
  // get average_output
406
  if (key_vals.count("average_output")) {
Guolin Ke's avatar
Guolin Ke committed
407
408
    average_output_ = true;
  }
409

Guolin Ke's avatar
Guolin Ke committed
410
  // get feature names
411
412
  if (key_vals.count("feature_names")) {
    feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
413
414
415
416
417
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
418
    Log::Fatal("Model file doesn't contain feature_names");
Guolin Ke's avatar
Guolin Ke committed
419
420
421
    return false;
  }

422
423
  if (key_vals.count("feature_infos")) {
    feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
424
425
426
427
428
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
429
    Log::Fatal("Model file doesn't contain feature_infos");
Guolin Ke's avatar
Guolin Ke committed
430
431
432
    return false;
  }

433
434
  if (key_vals.count("objective")) {
    auto str = key_vals["objective"];
Guolin Ke's avatar
Guolin Ke committed
435
436
437
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
    objective_function_ = loaded_objective_.get();
  }
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
  if (!key_vals.count("tree_sizes")) {
    while (p < end) {
      auto line_len = Common::GetLine(p);
      std::string cur_line(p, line_len);
      if (line_len > 0) {
        if (Common::StartsWith(cur_line, "Tree=")) {
          p += line_len;
          p = Common::SkipNewLine(p);
          size_t used_len = 0;
          models_.emplace_back(new Tree(p, &used_len));
          p += used_len;
        }
        else {
          break;
        }
      }
      p = Common::SkipNewLine(p);
    }
  } else {
    std::vector<size_t> tree_sizes = Common::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
    std::vector<size_t> tree_boundries(tree_sizes.size() + 1, 0);
    int num_trees = static_cast<int>(tree_sizes.size());
    for (int i = 0; i < num_trees; ++i) {
      tree_boundries[i + 1] = tree_boundries[i] + tree_sizes[i];
      models_.emplace_back(nullptr);
Guolin Ke's avatar
Guolin Ke committed
463
    }
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_trees; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto cur_p = p + tree_boundries[i];
      auto line_len = Common::GetLine(cur_p);
      std::string cur_line(cur_p, line_len);
      if (Common::StartsWith(cur_line, "Tree=")) {
        cur_p += line_len;
        cur_p = Common::SkipNewLine(cur_p);
        size_t used_len = 0;
        models_[i].reset(new Tree(cur_p, &used_len));
      } else {
        Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
482
483
484
485
  }
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
Guolin Ke's avatar
Guolin Ke committed
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
  bool is_inparameter = false;
  std::stringstream ss;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    std::string cur_line(p, line_len);
    if (line_len > 0) {
      if (cur_line == std::string("parameters:")) {
        is_inparameter = true;
      } else if (cur_line == std::string("end of parameters")) {
        break;
      } else if (is_inparameter) {
        ss << cur_line << "\n";
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  if (!ss.str().empty()) {
    loaded_parameter_ = ss.str();
  }
Guolin Ke's avatar
Guolin Ke committed
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
  return true;
}

std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
  if (importance_type == 0) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
          feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
        }
      }
    }
  } else if (importance_type == 1) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
          feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
        }
      }
    }
  } else {
535
    Log::Fatal("Unknown importance type: only support split=0 and gain=1");
Guolin Ke's avatar
Guolin Ke committed
536
537
538
539
540
  }
  return feature_importances;
}

}  // namespace LightGBM