boosting.cpp 1.82 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/boosting.h>
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
3
#include "dart.hpp"
Guolin Ke's avatar
Guolin Ke committed
4
5
6

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
7
std::string GetBoostingTypeFromModelFile(const char* filename) {
8
9
  TextReader<size_t> model_reader(filename, true);
  std::string type = model_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
10
  return type;
11
12
}

13
// Populates `boosting` from the model file `filename`.
//
// The lines collected by TextReader::ReadAllLines() are re-joined, each
// followed by '\n', and the resulting text is passed to
// Boosting::LoadModelFromString.
// NOTE(review): TextReader is constructed with `true`, presumably so the
// header line (see GetBoostingTypeFromModelFile) is excluded — confirm.
//
// Returns the result of LoadModelFromString. When `boosting` is nullptr
// nothing is read and true is returned.
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
  if (boosting == nullptr) {
    // No target object: nothing to load, reported as success.
    return true;
  }
  TextReader<size_t> model_reader(filename, true);
  model_reader.ReadAllLines();
  std::string model_text;
  for (const auto& text_line : model_reader.Lines()) {
    model_text += text_line;
    model_text += '\n';
  }
  return boosting->LoadModelFromString(model_text);
}

Guolin Ke's avatar
Guolin Ke committed
28
// Creates a boosting object of the requested `type` ("gbdt" or "dart").
//
// - If `filename` is null or empty, a new untrained instance is returned, or
//   nullptr when `type` is not recognized.
// - Otherwise the model file's header (GetBoostingTypeFromModelFile) must be
//   "tree". If it is not, Log::Fatal is called.
// - With a "tree" header, an instance of `type` is created and populated from
//   the file via LoadFileToBoosting. An unrecognized `type` yields nullptr,
//   as in the no-file branch.
// - If the model text cannot be loaded, Log::Fatal is called.
//
// The returned pointer is owned by the caller.
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
  if (filename == nullptr || filename[0] == '\0') {
    if (type == std::string("gbdt")) {
      return new GBDT();
    } else if (type == std::string("dart")) {
      return new DART();
    } else {
      return nullptr;
    }
  } else {
    std::unique_ptr<Boosting> ret;
    auto type_in_file = GetBoostingTypeFromModelFile(filename);
    if (type_in_file == std::string("tree")) {
      if (type == std::string("gbdt")) {
        ret.reset(new GBDT());
      } else if (type == std::string("dart")) {
        ret.reset(new DART());
      }
      // An unrecognized `type` leaves `ret` empty. Skip loading and return
      // nullptr, consistent with the no-file branch above.
      // LoadFileToBoosting reports parse failure via its return value. Do not
      // hand back a model whose load failed.
      if (ret != nullptr && !LoadFileToBoosting(ret.get(), filename)) {
        Log::Fatal("Failed to load model from file %s", filename);
      }
    } else {
      Log::Fatal("unknow submodel type in model file %s", filename);
    }
    return ret.release();
  }
}

// Creates a GBDT instance and populates it from the model file `filename`.
//
// The file header (GetBoostingTypeFromModelFile) must be "tree"; otherwise
// Log::Fatal is called. Log::Fatal is also called if LoadFileToBoosting
// reports that the model text could not be loaded.
//
// The returned pointer is owned by the caller.
Boosting* Boosting::CreateBoosting(const char* filename) {
  auto type = GetBoostingTypeFromModelFile(filename);
  std::unique_ptr<Boosting> ret;
  if (type == std::string("tree")) {
    ret.reset(new GBDT());
  } else {
    Log::Fatal("unknow submodel type in model file %s", filename);
  }
  // The load result was previously ignored, so a failed parse silently
  // produced a model. Surface the failure instead.
  if (!LoadFileToBoosting(ret.get(), filename)) {
    Log::Fatal("Failed to load model from file %s", filename);
  }
  return ret.release();
}

}  // namespace LightGBM