boosting.cpp 2.01 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/boosting.h>
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
3
#include "dart.hpp"
Guolin Ke's avatar
Guolin Ke committed
4
#include "goss.hpp"
Guolin Ke's avatar
Guolin Ke committed
5
#include "rf.hpp"
Guolin Ke's avatar
Guolin Ke committed
6
7
8

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
9
std::string GetBoostingTypeFromModelFile(const char* filename) {
10
11
  TextReader<size_t> model_reader(filename, true);
  std::string type = model_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
12
  return type;
13
14
}

15
16
/*!
 * \brief Read a model file into memory and deserialize it into an existing booster.
 *
 * The whole file content is read with TextReader::ReadContent and passed to
 * boosting->LoadModelFromString. The elapsed wall time is logged at debug level.
 *
 * \param boosting Booster to populate. When nullptr, no file is read and the
 *        call reports success.
 * \param filename Path of the model file.
 * \return false if LoadModelFromString rejects the content, true otherwise.
 */
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
  const auto start_time = std::chrono::steady_clock::now();
  if (boosting != nullptr) {
    TextReader<size_t> model_reader(filename, true);
    size_t buffer_len = 0;
    auto buffer = model_reader.ReadContent(&buffer_len);
    if (!boosting->LoadModelFromString(buffer.data(), buffer_len)) {
      return false;
    }
  }
  const std::chrono::duration<double, std::milli> delta =
      std::chrono::steady_clock::now() - start_time;
  // Log::Debug takes a printf-style format, so "%f" must receive a plain
  // double. Passing the chrono::duration object itself through varargs is
  // undefined behavior. Convert milliseconds to seconds via count().
  Log::Debug("time for loading model: %f seconds", 1e-3 * delta.count());
  return true;
}

30
/*!
 * \brief Allocate a fresh, untrained booster for a boosting type name.
 * \param type One of "gbdt", "dart", "goss", "rf".
 * \return A new booster owned by the caller, or nullptr for an unknown name.
 */
static Boosting* NewBoostingByType(const std::string& type) {
  if (type == "gbdt") {
    return new GBDT();
  }
  if (type == "dart") {
    return new DART();
  }
  if (type == "goss") {
    return new GOSS();
  }
  if (type == "rf") {
    return new RF();
  }
  return nullptr;
}

/*!
 * \brief Factory for boosting objects, optionally initialized from a model file.
 *
 * \param type Boosting algorithm name: "gbdt", "dart", "goss" or "rf".
 * \param filename Optional model file.
 *        - When null or empty, a fresh untrained booster is returned, or
 *          nullptr if `type` is unknown.
 *        - Otherwise the file's first line must be "tree" and `type` must be
 *          known. The model is then loaded into the new booster.
 *        Any failure in this path (bad format, unknown type, or a model the
 *        booster rejects) is reported via Log::Fatal.
 * \return Newly allocated booster owned by the caller, or nullptr as described above.
 */
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
  if (filename == nullptr || filename[0] == '\0') {
    return NewBoostingByType(type);
  }

  if (GetBoostingTypeFromModelFile(filename) != "tree") {
    Log::Fatal("unknown model format or submodel type in model file %s", filename);
    return nullptr;  // only reached if Log::Fatal returns; keeps the old nullptr result
  }

  std::unique_ptr<Boosting> ret(NewBoostingByType(type));
  if (ret == nullptr) {
    Log::Fatal("unknown boosting type %s", type.c_str());
    return nullptr;  // only reached if Log::Fatal returns; keeps the old nullptr result
  }

  // Every type, including "rf", must have its model loaded. Previously "rf"
  // returned early and came back with no model. A rejected model file is
  // also reported instead of silently handing back an unloaded booster.
  if (!LoadFileToBoosting(ret.get(), filename)) {
    Log::Fatal("failed to load model from file %s", filename);
  }
  return ret.release();
}

Guolin Ke's avatar
Guolin Ke committed
65
}  // namespace LightGBM