"python-package/vscode:/vscode.git/clone" did not exist on "691b842832f511a72b4aaab715b3cb6b00de8f90"
boosting.cpp 2.22 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/boosting.h>
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
3
#include "dart.hpp"
Guolin Ke's avatar
Guolin Ke committed
4
#include "goss.hpp"
Guolin Ke's avatar
Guolin Ke committed
5
#include "rf.hpp"
Guolin Ke's avatar
Guolin Ke committed
6
7
8

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
9
std::string GetBoostingTypeFromModelFile(const char* filename) {
10
11
  TextReader<size_t> model_reader(filename, true);
  std::string type = model_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
12
  return type;
13
14
}

wxchan's avatar
wxchan committed
15
// Populates `boosting` with the model stored in `filename`.
//
//   format == "text"  : the file's lines (as collected by TextReader) are
//                       re-joined with '\n' and handed to LoadModelFromString.
//   format == "proto" : the path is handed to LoadModelFromProto.
//   anything else     : reported through Log::Fatal.
//
// Returns false when the selected loader reports failure, true otherwise.
// A null `boosting` is a no-op that returns true.
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
  if (boosting == nullptr) {
    return true;
  }
  if (format == "text") {
    // NOTE(review): the `true` argument presumably makes TextReader treat the
    // header line separately from Lines() — confirm against TextReader.
    TextReader<size_t> reader(filename, true);
    reader.ReadAllLines();
    std::stringstream model_text;
    for (const auto& text_line : reader.Lines()) {
      model_text << text_line << '\n';
    }
    return boosting->LoadModelFromString(model_text.str());
  }
  if (format == "proto") {
    return boosting->LoadModelFromProto(filename);
  }
  Log::Fatal("Unknown model format during loading: %s", format.c_str());
  return true;
}

wxchan's avatar
wxchan committed
38
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
39
  if (filename == nullptr || filename[0] == '\0') {
Guolin Ke's avatar
Guolin Ke committed
40
    if (type == std::string("gbdt")) {
41
      return new GBDT();
Guolin Ke's avatar
Guolin Ke committed
42
    } else if (type == std::string("dart")) {
43
      return new DART();
Guolin Ke's avatar
Guolin Ke committed
44
45
    } else if (type == std::string("goss")) {
      return new GOSS();
Guolin Ke's avatar
Guolin Ke committed
46
47
    } else if (type == std::string("rf")) {
      return new RF();
48
49
50
    } else {
      return nullptr;
    }
Guolin Ke's avatar
Guolin Ke committed
51
  } else {
Guolin Ke's avatar
Guolin Ke committed
52
    std::unique_ptr<Boosting> ret;
wxchan's avatar
wxchan committed
53
    if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
Guolin Ke's avatar
Guolin Ke committed
54
      if (type == std::string("gbdt")) {
Guolin Ke's avatar
Guolin Ke committed
55
        ret.reset(new GBDT());
Guolin Ke's avatar
Guolin Ke committed
56
      } else if (type == std::string("dart")) {
Guolin Ke's avatar
Guolin Ke committed
57
        ret.reset(new DART());
Guolin Ke's avatar
Guolin Ke committed
58
59
      } else if (type == std::string("goss")) {
        ret.reset(new GOSS());
Guolin Ke's avatar
Guolin Ke committed
60
61
      } else if (type == std::string("rf")) {
        return new RF();
Guolin Ke's avatar
Guolin Ke committed
62
      } else {
zhangyafeikimi's avatar
zhangyafeikimi committed
63
        Log::Fatal("unknown boosting type %s", type.c_str());
64
      }
wxchan's avatar
wxchan committed
65
      LoadFileToBoosting(ret.get(), format, filename);
66
    } else {
wxchan's avatar
wxchan committed
67
      Log::Fatal("unknown model format or submodel type in model file %s", filename);
68
    }
Guolin Ke's avatar
Guolin Ke committed
69
    return ret.release();
70
71
72
  }
}

Guolin Ke's avatar
Guolin Ke committed
73
}  // namespace LightGBM