boosting.cpp 1.93 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
#include <LightGBM/boosting.h>
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
3
#include "dart.hpp"
Guolin Ke's avatar
Guolin Ke committed
4
5
6

namespace LightGBM {

7
8
9
10
11
// Determines which boosting algorithm a saved model file was written by.
//
// The reader's first_line() is compared against the literal tags "gbdt" and
// "dart"; anything else maps to BoostingType::kUnknow.
//
// NOTE(review): the second TextReader argument (`true`) presumably makes the
// reader capture/skip the header line that first_line() returns — confirm
// against TextReader.
BoostingType GetBoostingTypeFromModelFile(const char* filename) {
  TextReader<size_t> reader(filename, true);
  const std::string model_tag = reader.first_line();

  if (model_tag == "gbdt") {
    return BoostingType::kGBDT;
  }
  if (model_tag == "dart") {
    return BoostingType::kDART;
  }
  // Unrecognised header: let the caller decide how to react.
  return BoostingType::kUnknow;
}

// Reads the model file `filename` and hands its text to
// boosting->LoadModelFromString().
//
// Every line produced by TextReader::Lines() is re-joined with a trailing '\n'.
// When `boosting` is null this is a no-op and the file is not opened at all.
//
// NOTE(review): the reader is built with `true` as its second argument, which
// presumably excludes the header (type tag) line from Lines() — confirm.
void LoadFileToBoosting(Boosting* boosting, const char* filename) {
  if (boosting == nullptr) {
    return;
  }

  TextReader<size_t> reader(filename, true);
  reader.ReadAllLines();

  std::stringstream model_text;
  for (const auto& text_line : reader.Lines()) {
    model_text << text_line << '\n';
  }
  boosting->LoadModelFromString(model_text.str());
}

// Factory for a boosting object of the requested `type`.
//
// - `filename` null or empty: returns a freshly constructed GBDT or DART object,
//   or nullptr for any other `type`.
// - `filename` given: the type recorded in the model file must equal `type`.
//   On a mismatch Log::Fatal is called (nullptr is returned should control come
//   back). On a match, the object is constructed and filled from the file via
//   LoadFileToBoosting. If `type` names neither GBDT nor DART, nothing is
//   constructed and nullptr is returned.
//
// The caller takes ownership of the returned pointer.
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
  const bool has_model_file = (filename != nullptr) && (filename[0] != '\0');

  if (!has_model_file) {
    switch (type) {
      case BoostingType::kGBDT:
        return new GBDT();
      case BoostingType::kDART:
        return new DART();
      default:
        return nullptr;
    }
  }

  const auto type_in_file = GetBoostingTypeFromModelFile(filename);
  if (type_in_file != type) {
    Log::Fatal("Boosting type in parameter is not the same as the type in the model file");
    return nullptr;
  }

  // Hold the object in a unique_ptr so it is released if loading throws.
  std::unique_ptr<Boosting> model;
  switch (type) {
    case BoostingType::kGBDT:
      model.reset(new GBDT());
      break;
    case BoostingType::kDART:
      model.reset(new DART());
      break;
    default:
      break;
  }
  LoadFileToBoosting(model.get(), filename);
  return model.release();
}

// Builds a boosting object whose type is taken from the model file itself, then
// loads the file's contents into it via LoadFileToBoosting.
//
// If the file's type tag is neither GBDT nor DART, no object is constructed and
// nullptr is returned. The caller takes ownership of the returned pointer.
Boosting* Boosting::CreateBoosting(const char* filename) {
  const auto detected_type = GetBoostingTypeFromModelFile(filename);

  std::unique_ptr<Boosting> model;
  switch (detected_type) {
    case BoostingType::kGBDT:
      model.reset(new GBDT());
      break;
    case BoostingType::kDART:
      model.reset(new DART());
      break;
    default:
      // Unknown type: leave `model` empty; LoadFileToBoosting ignores null.
      break;
  }

  LoadFileToBoosting(model.get(), filename);
  return model.release();
}

}  // namespace LightGBM