/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/config.h>
#include <LightGBM/metric.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

#include "gbdt.h"

namespace LightGBM {

/*! \brief Format version tag written into the header of every dumped or saved model */
const char* kModelVersion = "v3";

std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
21
22
23
  std::stringstream str_buf;

  str_buf << "{";
24
25
26
27
28
29
  str_buf << "\"name\":\"" << SubModelName() << "\"," << '\n';
  str_buf << "\"version\":\"" << kModelVersion << "\"," << '\n';
  str_buf << "\"num_class\":" << num_class_ << "," << '\n';
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n';
  str_buf << "\"label_index\":" << label_idx_ << "," << '\n';
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n';
30
31
32
33
  str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n";
  if (objective_function_ != nullptr) {
    str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n";
  }
Guolin Ke's avatar
Guolin Ke committed
34

35
36
37
38
39
  str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"")
          << "\"]," << '\n';

  str_buf << "\"monotone_constraints\":["
          << Common::Join(monotone_constraints_, ",") << "]," << '\n';
Guolin Ke's avatar
Guolin Ke committed
40
41
42

  str_buf << "\"tree_info\":[";
  int num_used_model = static_cast<int>(models_.size());
43
44
45
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
46
  if (num_iteration > 0) {
47
48
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_ , num_used_model);
Guolin Ke's avatar
Guolin Ke committed
49
  }
50
51
52
  int start_model = start_iteration * num_tree_per_iteration_;
  for (int i = start_model; i < num_used_model; ++i) {
    if (i > start_model) {
Guolin Ke's avatar
Guolin Ke committed
53
54
55
56
57
58
59
      str_buf << ",";
    }
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
  }
60
  str_buf << "]" << '\n';
Guolin Ke's avatar
Guolin Ke committed
61

62
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
63
64
65
66
67
68
69

  return str_buf.str();
}

std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

70
71
72
73
74
75
76
77
78
79
80
81
  str_buf << "#include \"gbdt.h\"" << '\n';
  str_buf << "#include <LightGBM/utils/common.h>" << '\n';
  str_buf << "#include <LightGBM/objective_function.h>" << '\n';
  str_buf << "#include <LightGBM/metric.h>" << '\n';
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << '\n';
  str_buf << "#include <ctime>" << '\n';
  str_buf << "#include <sstream>" << '\n';
  str_buf << "#include <chrono>" << '\n';
  str_buf << "#include <string>" << '\n';
  str_buf << "#include <vector>" << '\n';
  str_buf << "#include <utility>" << '\n';
  str_buf << "namespace LightGBM {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
90
    str_buf << models_[i]->ToIfElse(i, false) << '\n';
Guolin Ke's avatar
Guolin Ke committed
91
92
93
94
95
96
97
98
99
  }

  str_buf << "double (*PredictTreePtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i;
  }
100
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
101
102
103

  std::stringstream pred_str_buf;

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
119
  str_buf << pred_str_buf.str();
120
121
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
122

123
124
125
  // PredictRawByMap
  str_buf << "double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "ByMap";
130
  }
131
  str_buf << " };" << '\n' << '\n';
132
133
134

  std::stringstream pred_str_buf_map;

135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
  pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf_map << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
150
  str_buf << pred_str_buf_map.str();
151
152
  str_buf << "}" << '\n';
  str_buf << '\n';
153

Guolin Ke's avatar
Guolin Ke committed
154
  // Predict
155
156
157
158
159
160
161
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
162
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
163
164
165
166
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
167

168
  // PredictByMap
169
170
171
172
173
174
175
  str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
176
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
177
178
179
180
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
181
182


Guolin Ke's avatar
Guolin Ke committed
183
184
  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
185
    str_buf << models_[i]->ToIfElse(i, true) << '\n';
Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
190
191
192
193
194
  }

  str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "Leaf";
  }
195
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
196

197
198
199
200
201
202
  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
203

204
  // PredictLeafIndexByMap
205
206
  str_buf << "double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
207
208
209
210
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "LeafByMap";
211
  }
212
  str_buf << " };" << '\n' << '\n';
213

214
215
216
217
218
219
  str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
220

221
  str_buf << "}  // namespace LightGBM" << '\n';
Guolin Ke's avatar
Guolin Ke committed
222
223
224
225
226
227
228
229
230
231
232
233

  return str_buf.str();
}

/*!
 * \brief Write the if-else source generated by ModelToIfElse to a file.
 *        If the file is already readable, its previous content is kept in the
 *        output but wrapped in a preprocessor branch that is compiled out.
 * \param num_iteration Number of iterations to convert; forwarded to ModelToIfElse
 * \param filename Path of the file to (over)write
 * \return true if the output stream is still in a good state after closing
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  // Read the current content first and close the input stream before the same
  // path is reopened (and truncated) for writing.
  std::string origin;
  bool has_origin = false;
  {
    std::ifstream ifs(filename);
    if (ifs.good()) {
      has_origin = true;
      origin.assign(std::istreambuf_iterator<char>(ifs),
                    std::istreambuf_iterator<char>());
    }
  }

  /*! \brief File to write models */
  std::ofstream output_file;
  output_file.open(filename);
  if (has_origin) {
    // USE_HARD_CODE is defined right above the #ifndef, so the preprocessor
    // skips the original code and takes the generated #else branch
    output_file << "#define USE_HARD_CODE 0" << '\n';
    output_file << "#ifndef USE_HARD_CODE" << '\n';
    output_file << origin << '\n';
    output_file << "#else" << '\n';
    output_file << ModelToIfElse(num_iteration);
    output_file << "#endif" << '\n';
  } else {
    output_file << ModelToIfElse(num_iteration);
  }

  output_file.close();

  return static_cast<bool>(output_file);
}

std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
252
253
254
  std::stringstream ss;

  // output model type
255
256
  ss << SubModelName() << '\n';
  ss << "version=" << kModelVersion << '\n';
Guolin Ke's avatar
Guolin Ke committed
257
  // output number of class
258
259
  ss << "num_class=" << num_class_ << '\n';
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
260
  // output label index
261
  ss << "label_index=" << label_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
262
  // output max_feature_idx
263
  ss << "max_feature_idx=" << max_feature_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
264
265
  // output objective
  if (objective_function_ != nullptr) {
266
    ss << "objective=" << objective_function_->ToString() << '\n';
Guolin Ke's avatar
Guolin Ke committed
267
268
269
  }

  if (average_output_) {
270
    ss << "average_output" << '\n';
Guolin Ke's avatar
Guolin Ke committed
271
272
  }

273
  ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
274

275
276
277
278
279
  if (monotone_constraints_.size() != 0) {
    ss << "monotone_constraints=" << Common::Join(monotone_constraints_, " ")
       << '\n';
  }

280
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
281
282

  int num_used_model = static_cast<int>(models_.size());
283
284
285
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
286
  if (num_iteration > 0) {
287
288
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_, num_used_model);
Guolin Ke's avatar
Guolin Ke committed
289
  }
290

291
292
293
294
  int start_model = start_iteration * num_tree_per_iteration_;

  std::vector<std::string> tree_strs(num_used_model - start_model);
  std::vector<size_t> tree_sizes(num_used_model - start_model);
Guolin Ke's avatar
Guolin Ke committed
295
  // output tree models
296
  #pragma omp parallel for schedule(static)
297
298
299
300
301
  for (int i = start_model; i < num_used_model; ++i) {
    const int idx = i - start_model;
    tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
    tree_strs[idx] += models_[i]->ToString() + '\n';
    tree_sizes[idx] = tree_strs[idx].size();
302
303
304
305
306
  }

  ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
  ss << '\n';

307
  for (int i = 0; i < num_used_model - start_model; ++i) {
308
309
    ss << tree_strs[i];
    tree_strs[i].clear();
Guolin Ke's avatar
Guolin Ke committed
310
  }
Guolin Ke's avatar
Guolin Ke committed
311
  ss << "end of trees" << "\n";
Guolin Ke's avatar
Guolin Ke committed
312

313
  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
318
319
320
321
322
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  // sort the importance
323
324
325
  std::stable_sort(pairs.begin(), pairs.end(),
                   [](const std::pair<size_t, std::string>& lhs,
                      const std::pair<size_t, std::string>& rhs) {
Guolin Ke's avatar
Guolin Ke committed
326
327
    return lhs.first > rhs.first;
  });
328
  ss << '\n' << "feature importances:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
329
  for (size_t i = 0; i < pairs.size(); ++i) {
330
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
Guolin Ke's avatar
Guolin Ke committed
331
  }
Guolin Ke's avatar
Guolin Ke committed
332
  if (config_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
333
    ss << "\nparameters:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
334
    ss << config_->ToString() << "\n";
Guolin Ke's avatar
Guolin Ke committed
335
336
337
338
339
    ss << "end of parameters" << '\n';
  } else if (!loaded_parameter_.empty()) {
    ss << "\nparameters:" << '\n';
    ss << loaded_parameter_ << "\n";
    ss << "end of parameters" << '\n';
Guolin Ke's avatar
Guolin Ke committed
340
  }
Guolin Ke's avatar
Guolin Ke committed
341
342
343
  return ss.str();
}

344
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
Guolin Ke's avatar
Guolin Ke committed
345
346
  /*! \brief File to write models */
  std::ofstream output_file;
347
  output_file.open(filename, std::ios::out | std::ios::binary);
348
  std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
349
  output_file.write(str_to_write.c_str(), str_to_write.size());
Guolin Ke's avatar
Guolin Ke committed
350
351
  output_file.close();

352
  return static_cast<bool>(output_file);
Guolin Ke's avatar
Guolin Ke committed
353
354
}

355
bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
Guolin Ke's avatar
Guolin Ke committed
356
357
  // use serialized string to restore this object
  models_.clear();
358
359
360
361
362
363
364
  auto c_str = buffer;
  auto p = c_str;
  auto end = p + len;
  std::unordered_map<std::string, std::string> key_vals;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
365
      std::string cur_line(p, line_len);
366
367
368
369
      if (!Common::StartsWith(cur_line, "Tree=")) {
        auto strs = Common::Split(cur_line.c_str(), '=');
        if (strs.size() == 1) {
          key_vals[strs[0]] = "";
370
        } else if (strs.size() == 2) {
371
          key_vals[strs[0]] = strs[1];
372
        } else if (strs.size() > 2) {
Guolin Ke's avatar
Guolin Ke committed
373
          if (strs[0] == "feature_names") {
Guolin Ke's avatar
Guolin Ke committed
374
            key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
375
376
          } else if (strs[0] == "monotone_constraints") {
            key_vals[strs[0]] = cur_line.substr(std::strlen("monotone_constraints="));
377
          } else {
Guolin Ke's avatar
Guolin Ke committed
378
379
            // Use first 128 chars to avoid exceed the message buffer.
            Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
380
          }
381
        }
382
      } else {
383
384
385
386
387
388
        break;
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
Guolin Ke's avatar
Guolin Ke committed
389
390

  // get number of classes
391
392
  if (key_vals.count("num_class")) {
    Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
Guolin Ke's avatar
Guolin Ke committed
393
394
395
396
397
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }

398
399
  if (key_vals.count("num_tree_per_iteration")) {
    Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
404
  } else {
    num_tree_per_iteration_ = num_class_;
  }

  // get index of label
405
406
  if (key_vals.count("label_index")) {
    Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
Guolin Ke's avatar
Guolin Ke committed
407
408
409
410
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
411

Guolin Ke's avatar
Guolin Ke committed
412
  // get max_feature_idx first
413
414
  if (key_vals.count("max_feature_idx")) {
    Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
Guolin Ke's avatar
Guolin Ke committed
415
416
417
418
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
419

Guolin Ke's avatar
Guolin Ke committed
420
  // get average_output
421
  if (key_vals.count("average_output")) {
Guolin Ke's avatar
Guolin Ke committed
422
423
    average_output_ = true;
  }
424

Guolin Ke's avatar
Guolin Ke committed
425
  // get feature names
426
427
  if (key_vals.count("feature_names")) {
    feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
428
429
430
431
432
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
433
    Log::Fatal("Model file doesn't contain feature_names");
Guolin Ke's avatar
Guolin Ke committed
434
435
436
    return false;
  }

437
438
  // get monotone_constraints
  if (key_vals.count("monotone_constraints")) {
Guolin Ke's avatar
Guolin Ke committed
439
    monotone_constraints_ = Common::StringToArray<int8_t>(key_vals["monotone_constraints"].c_str(), ' ');
440
441
442
443
444
445
    if (monotone_constraints_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of monotone_constraints");
      return false;
    }
  }

446
447
  if (key_vals.count("feature_infos")) {
    feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
448
449
450
451
452
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
453
    Log::Fatal("Model file doesn't contain feature_infos");
Guolin Ke's avatar
Guolin Ke committed
454
455
456
    return false;
  }

457
458
  if (key_vals.count("objective")) {
    auto str = key_vals["objective"];
459
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(ParseObjectiveAlias(str)));
Guolin Ke's avatar
Guolin Ke committed
460
461
    objective_function_ = loaded_objective_.get();
  }
462

463
464
465
466
  if (!key_vals.count("tree_sizes")) {
    while (p < end) {
      auto line_len = Common::GetLine(p);
      if (line_len > 0) {
467
        std::string cur_line(p, line_len);
468
469
470
471
472
473
        if (Common::StartsWith(cur_line, "Tree=")) {
          p += line_len;
          p = Common::SkipNewLine(p);
          size_t used_len = 0;
          models_.emplace_back(new Tree(p, &used_len));
          p += used_len;
474
        } else {
475
476
477
478
479
480
481
482
483
484
485
486
          break;
        }
      }
      p = Common::SkipNewLine(p);
    }
  } else {
    std::vector<size_t> tree_sizes = Common::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
    std::vector<size_t> tree_boundries(tree_sizes.size() + 1, 0);
    int num_trees = static_cast<int>(tree_sizes.size());
    for (int i = 0; i < num_trees; ++i) {
      tree_boundries[i + 1] = tree_boundries[i] + tree_sizes[i];
      models_.emplace_back(nullptr);
Guolin Ke's avatar
Guolin Ke committed
487
    }
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_trees; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto cur_p = p + tree_boundries[i];
      auto line_len = Common::GetLine(cur_p);
      std::string cur_line(cur_p, line_len);
      if (Common::StartsWith(cur_line, "Tree=")) {
        cur_p += line_len;
        cur_p = Common::SkipNewLine(cur_p);
        size_t used_len = 0;
        models_[i].reset(new Tree(cur_p, &used_len));
      } else {
        Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
506
507
508
509
  }
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
514
  bool is_inparameter = false;
  std::stringstream ss;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
515
      std::string cur_line(p, line_len);
Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
521
522
523
524
525
526
527
528
529
      if (cur_line == std::string("parameters:")) {
        is_inparameter = true;
      } else if (cur_line == std::string("end of parameters")) {
        break;
      } else if (is_inparameter) {
        ss << cur_line << "\n";
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  if (!ss.str().empty()) {
    loaded_parameter_ = ss.str();
  }
Guolin Ke's avatar
Guolin Ke committed
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
  return true;
}

std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
  if (importance_type == 0) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
          feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
        }
      }
    }
  } else if (importance_type == 1) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
          feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
        }
      }
    }
  } else {
558
    Log::Fatal("Unknown importance type: only support split=0 and gain=1");
Guolin Ke's avatar
Guolin Ke committed
559
560
561
562
563
  }
  return feature_importances;
}

}  // namespace LightGBM