boosting.cpp 2.06 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/boosting.h>
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
3
#include "dart.hpp"
Guolin Ke's avatar
Guolin Ke committed
4
#include "goss.hpp"
Guolin Ke's avatar
Guolin Ke committed
5
6
7

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
8
std::string GetBoostingTypeFromModelFile(const char* filename) {
9
10
  TextReader<size_t> model_reader(filename, true);
  std::string type = model_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
11
  return type;
12
13
}

14
/*!
 * \brief Read a model file and feed its text into an existing boosting object.
 * \param boosting Object to populate; when nullptr nothing is read.
 * \param filename Path of the model file.
 * \return false if boosting->LoadModelFromString() rejects the text,
 *         true otherwise (including when boosting is nullptr).
 */
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
  if (boosting == nullptr) {
    return true;
  }
  TextReader<size_t> reader(filename, true);
  reader.ReadAllLines();
  // Re-join the lines collected by the reader, '\n'-terminated, into one
  // buffer for the string-based model parser.
  std::stringstream model_text;
  for (const auto& current_line : reader.Lines()) {
    model_text << current_line << '\n';
  }
  return boosting->LoadModelFromString(model_text.str());
}

Guolin Ke's avatar
Guolin Ke committed
29
/*!
 * \brief Allocate an empty boosting object for a boosting type name.
 * \param type One of "gbdt", "dart" or "goss".
 * \return A new object owned by the caller, or nullptr if type is unknown.
 */
static Boosting* NewBoostingOfType(const std::string& type) {
  if (type == "gbdt") {
    return new GBDT();
  }
  if (type == "dart") {
    return new DART();
  }
  if (type == "goss") {
    return new GOSS();
  }
  return nullptr;
}

/*!
 * \brief Create a boosting object, optionally initialised from a model file.
 * \param type Boosting type name ("gbdt", "dart" or "goss").
 * \param filename Model file to load; nullptr or "" means start empty.
 * \return A new object owned by the caller. With no filename, nullptr is
 *         returned for an unknown type. With a filename, an unknown type,
 *         an unknown submodel type in the file, or a model text rejected by
 *         LoadModelFromString() is reported via Log::Fatal.
 */
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
  if (filename == nullptr || filename[0] == '\0') {
    return NewBoostingOfType(type);
  }
  std::unique_ptr<Boosting> ret;
  const std::string type_in_file = GetBoostingTypeFromModelFile(filename);
  if (type_in_file == "tree") {
    ret.reset(NewBoostingOfType(type));
    if (ret == nullptr) {
      Log::Fatal("unknown boosting type %s", type.c_str());
    } else if (!LoadFileToBoosting(ret.get(), filename)) {
      // Previously ignored: a rejected model text left a half-initialised
      // object that was returned to the caller as if it were valid.
      Log::Fatal("failed to load model file %s", filename);
    }
  } else {
    Log::Fatal("unknown submodel type in model file %s", filename);
  }
  return ret.release();
}

/*!
 * \brief Create a boosting object whose kind is taken from a model file.
 * \param filename Model file; its first line must be "tree".
 * \return A new GBDT object loaded from the file, owned by the caller.
 *         An unknown submodel type, or a model text rejected by
 *         LoadModelFromString(), is reported via Log::Fatal.
 */
Boosting* Boosting::CreateBoosting(const char* filename) {
  const std::string type = GetBoostingTypeFromModelFile(filename);
  std::unique_ptr<Boosting> ret;
  if (type == "tree") {
    ret.reset(new GBDT());
    // Previously the load result was discarded, so a rejected model text
    // produced a half-initialised object.
    if (!LoadFileToBoosting(ret.get(), filename)) {
      Log::Fatal("failed to load model file %s", filename);
    }
  } else {
    Log::Fatal("unknown submodel type in model file %s", filename);
  }
  return ret.release();
}

}  // namespace LightGBM