boosting.cpp 2.35 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/boosting.h>
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
3
#include "dart.hpp"
Guolin Ke's avatar
Guolin Ke committed
4
#include "goss.hpp"
Guolin Ke's avatar
Guolin Ke committed
5
#include "rf.hpp"
Guolin Ke's avatar
Guolin Ke committed
6
7
8

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
9
std::string GetBoostingTypeFromModelFile(const char* filename) {
10
11
  TextReader<size_t> model_reader(filename, true);
  std::string type = model_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
12
  return type;
13
14
}

wxchan's avatar
wxchan committed
15
// Populates `boosting` from the model stored in `filename`.
//
// boosting : target instance; when nullptr nothing is loaded and the call
//            reports success.
// format   : "text" — the file is read line by line, re-joined with '\n'
//            and handed to LoadModelFromString;
//            "proto" — delegated to LoadModelFromProto (only when built
//            with USE_PROTO, otherwise Log::Fatal);
//            anything else — Log::Fatal.
// Returns false only when the underlying Load* call reports failure.
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
  if (boosting == nullptr) {
    return true;
  }
  bool loaded = true;
  if (format == "text") {
    TextReader<size_t> reader(filename, true);
    reader.ReadAllLines();
    // Re-assemble the file contents, terminating every line with '\n'.
    std::string model_text;
    for (const auto& text_line : reader.Lines()) {
      model_text += text_line;
      model_text += '\n';
    }
    loaded = boosting->LoadModelFromString(model_text);
  } else if (format == "proto") {
    #ifdef USE_PROTO
    loaded = boosting->LoadModelFromProto(filename);
    #else
    Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
    #endif // USE_PROTO
  } else {
    Log::Fatal("Unknown model format during loading: %s", format.c_str());
  }
  return loaded;
}

wxchan's avatar
wxchan committed
42
// Allocates a fresh, untrained Boosting instance for the algorithm named
// `type` ("gbdt", "dart", "goss" or "rf"). Returns nullptr for any other
// name. The caller owns the returned pointer.
static Boosting* NewBoostingByType(const std::string& type) {
  if (type == std::string("gbdt")) {
    return new GBDT();
  } else if (type == std::string("dart")) {
    return new DART();
  } else if (type == std::string("goss")) {
    return new GOSS();
  } else if (type == std::string("rf")) {
    return new RF();
  }
  return nullptr;
}

// Factory for Boosting objects.
//
// type     : boosting algorithm name ("gbdt", "dart", "goss" or "rf").
// format   : serialization format of `filename` ("text" or "proto"); only
//            used when a file is given.
// filename : optional model file. When null or empty, a fresh instance is
//            returned, or nullptr if `type` is unknown. Otherwise the
//            instance is populated from the file via LoadFileToBoosting.
// Returns an owning raw pointer; the caller must delete it. On the
// file-loading path every error (bad model file, unknown type, failed
// load) is reported through Log::Fatal.
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
  if (filename == nullptr || filename[0] == '\0') {
    return NewBoostingByType(type);
  }
  // Text models must declare the "tree" submodel type on their first line;
  // proto models skip this check (short-circuit avoids the text read).
  if (format != std::string("proto") && GetBoostingTypeFromModelFile(filename) != std::string("tree")) {
    Log::Fatal("unknown model format or submodel type in model file %s", filename);
  }
  // Held in a unique_ptr so the instance is released if Log::Fatal unwinds.
  std::unique_ptr<Boosting> ret(NewBoostingByType(type));
  if (ret == nullptr) {
    Log::Fatal("unknown boosting type %s", type.c_str());
  }
  // Every type — including "rf" — must actually be loaded from the file,
  // and a failed load must not hand back an empty model.
  if (!LoadFileToBoosting(ret.get(), format, filename)) {
    Log::Fatal("Failed to load model from file %s", filename);
  }
  return ret.release();
}

Guolin Ke's avatar
Guolin Ke committed
77
}  // namespace LightGBM