gbdt_model_text.cpp 21.5 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
#include <LightGBM/config.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/metric.h>
7
8
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
9
10

#include <string>
11
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
12
13
#include <vector>

14
15
#include "gbdt.h"

Guolin Ke's avatar
Guolin Ke committed
16
17
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
18
// Format version tag written by DumpModel ("version" field) and SaveModelToString ("version=" line).
const char* kModelVersion = "v3";
19

20
std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
21
22
23
  std::stringstream str_buf;

  str_buf << "{";
24
25
26
27
28
29
  str_buf << "\"name\":\"" << SubModelName() << "\"," << '\n';
  str_buf << "\"version\":\"" << kModelVersion << "\"," << '\n';
  str_buf << "\"num_class\":" << num_class_ << "," << '\n';
  str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n';
  str_buf << "\"label_index\":" << label_idx_ << "," << '\n';
  str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n';
30
31
32
  if (objective_function_ != nullptr) {
    str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n";
  }
Guolin Ke's avatar
Guolin Ke committed
33

34
35
  str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n";

36
37
38
39
40
  str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"")
          << "\"]," << '\n';

  str_buf << "\"monotone_constraints\":["
          << Common::Join(monotone_constraints_, ",") << "]," << '\n';
Guolin Ke's avatar
Guolin Ke committed
41
42
43

  str_buf << "\"tree_info\":[";
  int num_used_model = static_cast<int>(models_.size());
44
45
46
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
47
  if (num_iteration > 0) {
48
49
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_ , num_used_model);
Guolin Ke's avatar
Guolin Ke committed
50
  }
51
52
53
  int start_model = start_iteration * num_tree_per_iteration_;
  for (int i = start_model; i < num_used_model; ++i) {
    if (i > start_model) {
Guolin Ke's avatar
Guolin Ke committed
54
55
56
57
58
59
60
      str_buf << ",";
    }
    str_buf << "{";
    str_buf << "\"tree_index\":" << i << ",";
    str_buf << models_[i]->ToJSON();
    str_buf << "}";
  }
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  str_buf << "]," << '\n';

  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  str_buf << '\n' << "\"feature_importances\":" << "{";
  if (!pairs.empty()) {
    str_buf << "\"" << pairs[0].second << "\":" << std::to_string(pairs[0].first);
    for (size_t i = 1; i < pairs.size(); ++i) {
      str_buf << ",";
      str_buf << "\"" << pairs[i].second << "\":" << std::to_string(pairs[i].first);
    }
  }
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
81

82
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
83
84
85
86
87
88
89

  return str_buf.str();
}

std::string GBDT::ModelToIfElse(int num_iteration) const {
  std::stringstream str_buf;

90
91
92
93
94
95
96
97
98
99
100
101
  str_buf << "#include \"gbdt.h\"" << '\n';
  str_buf << "#include <LightGBM/utils/common.h>" << '\n';
  str_buf << "#include <LightGBM/objective_function.h>" << '\n';
  str_buf << "#include <LightGBM/metric.h>" << '\n';
  str_buf << "#include <LightGBM/prediction_early_stop.h>" << '\n';
  str_buf << "#include <ctime>" << '\n';
  str_buf << "#include <sstream>" << '\n';
  str_buf << "#include <chrono>" << '\n';
  str_buf << "#include <string>" << '\n';
  str_buf << "#include <vector>" << '\n';
  str_buf << "#include <utility>" << '\n';
  str_buf << "namespace LightGBM {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
102
103
104
105
106
107
108
109

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  // PredictRaw
  for (int i = 0; i < num_used_model; ++i) {
110
    str_buf << models_[i]->ToIfElse(i, false) << '\n';
Guolin Ke's avatar
Guolin Ke committed
111
112
113
114
115
116
117
118
119
  }

  str_buf << "double (*PredictTreePtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i;
  }
120
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
121
122
123

  std::stringstream pred_str_buf;

124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
  pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf << "\t\t" << "}" << '\n';
  pred_str_buf << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
Guolin Ke's avatar
Guolin Ke committed
139
  str_buf << pred_str_buf.str();
140
141
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
142

143
144
145
  // PredictRawByMap
  str_buf << "double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "ByMap";
150
  }
151
  str_buf << " };" << '\n' << '\n';
152
153
154

  std::stringstream pred_str_buf_map;

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
  pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
  pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
  pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << '\n';
  pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
  pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
  pred_str_buf_map << "\t\t\t\t" << "return;" << '\n';
  pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
  pred_str_buf_map << "\t\t" << "}" << '\n';
  pred_str_buf_map << "\t" << "}" << '\n';

  str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
170
  str_buf << pred_str_buf_map.str();
171
172
  str_buf << "}" << '\n';
  str_buf << '\n';
173

Guolin Ke's avatar
Guolin Ke committed
174
  // Predict
175
176
177
178
179
180
181
  str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRaw(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
182
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
183
184
185
186
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
Guolin Ke's avatar
Guolin Ke committed
187

188
  // PredictByMap
189
190
191
192
193
194
195
  str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
  str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << '\n';
  str_buf << "\t" << "if (average_output_) {" << '\n';
  str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
  str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
  str_buf << "\t\t" << "}" << '\n';
  str_buf << "\t" << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
196
  str_buf << "\t" << "if (objective_function_ != nullptr) {" << '\n';
197
198
199
200
  str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
  str_buf << '\n';
201
202


Guolin Ke's avatar
Guolin Ke committed
203
204
  // PredictLeafIndex
  for (int i = 0; i < num_used_model; ++i) {
205
    str_buf << models_[i]->ToIfElse(i, true) << '\n';
Guolin Ke's avatar
Guolin Ke committed
206
207
208
209
210
211
212
213
214
  }

  str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
  for (int i = 0; i < num_used_model; ++i) {
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "Leaf";
  }
215
  str_buf << " };" << '\n' << '\n';
Guolin Ke's avatar
Guolin Ke committed
216

217
218
219
220
221
222
  str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
Guolin Ke's avatar
Guolin Ke committed
223

224
  // PredictLeafIndexByMap
225
226
  str_buf << "double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { ";
  for (int i = 0; i < num_used_model; ++i) {
Guolin Ke's avatar
Guolin Ke committed
227
228
229
230
    if (i > 0) {
      str_buf << " , ";
    }
    str_buf << "PredictTree" << i << "LeafByMap";
231
  }
232
  str_buf << " };" << '\n' << '\n';
233

234
235
236
237
238
239
  str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << '\n';
  str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
  str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
  str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << '\n';
  str_buf << "\t" << "}" << '\n';
  str_buf << "}" << '\n';
240

241
  str_buf << "}  // namespace LightGBM" << '\n';
Guolin Ke's avatar
Guolin Ke committed
242
243
244
245
246
247
248
249
250
251
252
253

  return str_buf.str();
}

/*!
 * \brief Write the source text produced by ModelToIfElse to a file.
 *        If the file can already be read, its current content is kept inside an
 *        "#ifndef USE_HARD_CODE" branch that is preceded by "#define USE_HARD_CODE 0",
 *        and the generated code is placed in the "#else" branch.
 * \param num_iteration Forwarded to ModelToIfElse
 * \param filename Path of the file to (over)write
 * \return State of the output stream after it has been closed
 */
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  /*! \brief File to write models */
  std::ofstream out;
  std::ifstream existing(filename);

  if (!existing.good()) {
    // Nothing to preserve: the file only gets the generated code.
    out.open(filename);
    out << ModelToIfElse(num_iteration);
  } else {
    // Read the whole current content before the file is reopened for writing.
    std::istreambuf_iterator<char> content_begin(existing);
    std::istreambuf_iterator<char> content_end;
    const std::string previous_content(content_begin, content_end);

    out.open(filename);
    out << "#define USE_HARD_CODE 0" << '\n';
    out << "#ifndef USE_HARD_CODE" << '\n';
    out << previous_content << '\n';
    out << "#else" << '\n';
    out << ModelToIfElse(num_iteration);
    out << "#endif" << '\n';
  }

  existing.close();
  out.close();

  return static_cast<bool>(out);
}

271
std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
Guolin Ke's avatar
Guolin Ke committed
272
273
274
  std::stringstream ss;

  // output model type
275
276
  ss << SubModelName() << '\n';
  ss << "version=" << kModelVersion << '\n';
Guolin Ke's avatar
Guolin Ke committed
277
  // output number of class
278
279
  ss << "num_class=" << num_class_ << '\n';
  ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
280
  // output label index
281
  ss << "label_index=" << label_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
282
  // output max_feature_idx
283
  ss << "max_feature_idx=" << max_feature_idx_ << '\n';
Guolin Ke's avatar
Guolin Ke committed
284
285
  // output objective
  if (objective_function_ != nullptr) {
286
    ss << "objective=" << objective_function_->ToString() << '\n';
Guolin Ke's avatar
Guolin Ke committed
287
288
289
  }

  if (average_output_) {
290
    ss << "average_output" << '\n';
Guolin Ke's avatar
Guolin Ke committed
291
292
  }

293
  ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
294

295
296
297
298
299
  if (monotone_constraints_.size() != 0) {
    ss << "monotone_constraints=" << Common::Join(monotone_constraints_, " ")
       << '\n';
  }

300
  ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
Guolin Ke's avatar
Guolin Ke committed
301
302

  int num_used_model = static_cast<int>(models_.size());
303
304
305
  int total_iteration = num_used_model / num_tree_per_iteration_;
  start_iteration = std::max(start_iteration, 0);
  start_iteration = std::min(start_iteration, total_iteration);
Guolin Ke's avatar
Guolin Ke committed
306
  if (num_iteration > 0) {
307
308
    int end_iteration = start_iteration + num_iteration;
    num_used_model = std::min(end_iteration * num_tree_per_iteration_, num_used_model);
Guolin Ke's avatar
Guolin Ke committed
309
  }
310

311
312
313
314
  int start_model = start_iteration * num_tree_per_iteration_;

  std::vector<std::string> tree_strs(num_used_model - start_model);
  std::vector<size_t> tree_sizes(num_used_model - start_model);
Guolin Ke's avatar
Guolin Ke committed
315
  // output tree models
316
  #pragma omp parallel for schedule(static)
317
318
319
320
321
  for (int i = start_model; i < num_used_model; ++i) {
    const int idx = i - start_model;
    tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
    tree_strs[idx] += models_[i]->ToString() + '\n';
    tree_sizes[idx] = tree_strs[idx].size();
322
323
324
325
326
  }

  ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
  ss << '\n';

327
  for (int i = 0; i < num_used_model - start_model; ++i) {
328
329
    ss << tree_strs[i];
    tree_strs[i].clear();
Guolin Ke's avatar
Guolin Ke committed
330
  }
Guolin Ke's avatar
Guolin Ke committed
331
  ss << "end of trees" << "\n";
Guolin Ke's avatar
Guolin Ke committed
332

333
  std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
338
339
340
341
342
  // store the importance first
  std::vector<std::pair<size_t, std::string>> pairs;
  for (size_t i = 0; i < feature_importances.size(); ++i) {
    size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
    if (feature_importances_int > 0) {
      pairs.emplace_back(feature_importances_int, feature_names_[i]);
    }
  }
  // sort the importance
343
344
345
  std::stable_sort(pairs.begin(), pairs.end(),
                   [](const std::pair<size_t, std::string>& lhs,
                      const std::pair<size_t, std::string>& rhs) {
Guolin Ke's avatar
Guolin Ke committed
346
347
    return lhs.first > rhs.first;
  });
348
  ss << '\n' << "feature_importances:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
349
  for (size_t i = 0; i < pairs.size(); ++i) {
350
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
Guolin Ke's avatar
Guolin Ke committed
351
  }
Guolin Ke's avatar
Guolin Ke committed
352
  if (config_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
353
    ss << "\nparameters:" << '\n';
Guolin Ke's avatar
Guolin Ke committed
354
    ss << config_->ToString() << "\n";
Guolin Ke's avatar
Guolin Ke committed
355
356
357
358
359
    ss << "end of parameters" << '\n';
  } else if (!loaded_parameter_.empty()) {
    ss << "\nparameters:" << '\n';
    ss << loaded_parameter_ << "\n";
    ss << "end of parameters" << '\n';
Guolin Ke's avatar
Guolin Ke committed
360
  }
Guolin Ke's avatar
Guolin Ke committed
361
362
363
  return ss.str();
}

364
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
Guolin Ke's avatar
Guolin Ke committed
365
366
  /*! \brief File to write models */
  std::ofstream output_file;
367
  output_file.open(filename, std::ios::out | std::ios::binary);
368
  std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
369
  output_file.write(str_to_write.c_str(), str_to_write.size());
Guolin Ke's avatar
Guolin Ke committed
370
371
  output_file.close();

372
  return static_cast<bool>(output_file);
Guolin Ke's avatar
Guolin Ke committed
373
374
}

375
bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
Guolin Ke's avatar
Guolin Ke committed
376
377
  // use serialized string to restore this object
  models_.clear();
378
379
380
381
382
383
384
  auto c_str = buffer;
  auto p = c_str;
  auto end = p + len;
  std::unordered_map<std::string, std::string> key_vals;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
385
      std::string cur_line(p, line_len);
386
387
388
389
      if (!Common::StartsWith(cur_line, "Tree=")) {
        auto strs = Common::Split(cur_line.c_str(), '=');
        if (strs.size() == 1) {
          key_vals[strs[0]] = "";
390
        } else if (strs.size() == 2) {
391
          key_vals[strs[0]] = strs[1];
392
        } else if (strs.size() > 2) {
Guolin Ke's avatar
Guolin Ke committed
393
          if (strs[0] == "feature_names") {
Guolin Ke's avatar
Guolin Ke committed
394
            key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
395
396
          } else if (strs[0] == "monotone_constraints") {
            key_vals[strs[0]] = cur_line.substr(std::strlen("monotone_constraints="));
397
          } else {
Guolin Ke's avatar
Guolin Ke committed
398
399
            // Use first 128 chars to avoid exceed the message buffer.
            Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
400
          }
401
        }
402
      } else {
403
404
405
406
407
408
        break;
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
Guolin Ke's avatar
Guolin Ke committed
409
410

  // get number of classes
411
412
  if (key_vals.count("num_class")) {
    Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
Guolin Ke's avatar
Guolin Ke committed
413
414
415
416
417
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }

418
419
  if (key_vals.count("num_tree_per_iteration")) {
    Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
Guolin Ke's avatar
Guolin Ke committed
420
421
422
423
424
  } else {
    num_tree_per_iteration_ = num_class_;
  }

  // get index of label
425
426
  if (key_vals.count("label_index")) {
    Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
Guolin Ke's avatar
Guolin Ke committed
427
428
429
430
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
431

Guolin Ke's avatar
Guolin Ke committed
432
  // get max_feature_idx first
433
434
  if (key_vals.count("max_feature_idx")) {
    Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
Guolin Ke's avatar
Guolin Ke committed
435
436
437
438
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
439

Guolin Ke's avatar
Guolin Ke committed
440
  // get average_output
441
  if (key_vals.count("average_output")) {
Guolin Ke's avatar
Guolin Ke committed
442
443
    average_output_ = true;
  }
444

Guolin Ke's avatar
Guolin Ke committed
445
  // get feature names
446
447
  if (key_vals.count("feature_names")) {
    feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
448
449
450
451
452
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
453
    Log::Fatal("Model file doesn't contain feature_names");
Guolin Ke's avatar
Guolin Ke committed
454
455
456
    return false;
  }

457
458
  // get monotone_constraints
  if (key_vals.count("monotone_constraints")) {
Guolin Ke's avatar
Guolin Ke committed
459
    monotone_constraints_ = Common::StringToArray<int8_t>(key_vals["monotone_constraints"].c_str(), ' ');
460
461
462
463
464
465
    if (monotone_constraints_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of monotone_constraints");
      return false;
    }
  }

466
467
  if (key_vals.count("feature_infos")) {
    feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
Guolin Ke's avatar
Guolin Ke committed
468
469
470
471
472
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
473
    Log::Fatal("Model file doesn't contain feature_infos");
Guolin Ke's avatar
Guolin Ke committed
474
475
476
    return false;
  }

477
478
  if (key_vals.count("objective")) {
    auto str = key_vals["objective"];
479
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(ParseObjectiveAlias(str)));
Guolin Ke's avatar
Guolin Ke committed
480
481
    objective_function_ = loaded_objective_.get();
  }
482

483
484
485
486
  if (!key_vals.count("tree_sizes")) {
    while (p < end) {
      auto line_len = Common::GetLine(p);
      if (line_len > 0) {
487
        std::string cur_line(p, line_len);
488
489
490
491
492
493
        if (Common::StartsWith(cur_line, "Tree=")) {
          p += line_len;
          p = Common::SkipNewLine(p);
          size_t used_len = 0;
          models_.emplace_back(new Tree(p, &used_len));
          p += used_len;
494
        } else {
495
496
497
498
499
500
501
502
503
504
505
506
          break;
        }
      }
      p = Common::SkipNewLine(p);
    }
  } else {
    std::vector<size_t> tree_sizes = Common::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
    std::vector<size_t> tree_boundries(tree_sizes.size() + 1, 0);
    int num_trees = static_cast<int>(tree_sizes.size());
    for (int i = 0; i < num_trees; ++i) {
      tree_boundries[i + 1] = tree_boundries[i] + tree_sizes[i];
      models_.emplace_back(nullptr);
Guolin Ke's avatar
Guolin Ke committed
507
    }
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_trees; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto cur_p = p + tree_boundries[i];
      auto line_len = Common::GetLine(cur_p);
      std::string cur_line(cur_p, line_len);
      if (Common::StartsWith(cur_line, "Tree=")) {
        cur_p += line_len;
        cur_p = Common::SkipNewLine(cur_p);
        size_t used_len = 0;
        models_[i].reset(new Tree(cur_p, &used_len));
      } else {
        Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
526
527
528
529
  }
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
Guolin Ke's avatar
Guolin Ke committed
530
531
532
533
534
  bool is_inparameter = false;
  std::stringstream ss;
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
535
      std::string cur_line(p, line_len);
Guolin Ke's avatar
Guolin Ke committed
536
537
538
539
540
541
542
543
544
545
546
547
548
549
      if (cur_line == std::string("parameters:")) {
        is_inparameter = true;
      } else if (cur_line == std::string("end of parameters")) {
        break;
      } else if (is_inparameter) {
        ss << cur_line << "\n";
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  if (!ss.str().empty()) {
    loaded_parameter_ = ss.str();
  }
Guolin Ke's avatar
Guolin Ke committed
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
  return true;
}

std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {
  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_iteration += 0;
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }

  std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
  if (importance_type == 0) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
565
566
567
#ifdef DEBUG
          CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
Guolin Ke's avatar
Guolin Ke committed
568
569
570
571
572
573
574
575
          feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
        }
      }
    }
  } else if (importance_type == 1) {
    for (int iter = 0; iter < num_used_model; ++iter) {
      for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
        if (models_[iter]->split_gain(split_idx) > 0) {
576
577
578
#ifdef DEBUG
          CHECK(models_[iter]->split_feature(split_idx) >= 0);
#endif
Guolin Ke's avatar
Guolin Ke committed
579
580
581
582
583
          feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
        }
      }
    }
  } else {
584
    Log::Fatal("Unknown importance type: only support split=0 and gain=1");
Guolin Ke's avatar
Guolin Ke committed
585
586
587
588
589
  }
  return feature_importances;
}

}  // namespace LightGBM