boosting.cpp 2.19 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/boosting.h>
6

7
8
9
#include <chrono>
#include <memory>
#include <string>

Guolin Ke's avatar
Guolin Ke committed
10
#include "dart.hpp"
11
#include "gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
12
#include "rf.hpp"
Guolin Ke's avatar
Guolin Ke committed
13
14
15

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
16
std::string GetBoostingTypeFromModelFile(const char* filename) {
17
18
  TextReader<size_t> model_reader(filename, true);
  std::string type = model_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
19
  return type;
20
21
}

22
23
/*!
 * \brief Read a whole model file and load it into an existing boosting object.
 * \param boosting Target object; if nullptr, nothing is read and the call succeeds.
 * \param filename Path of the model file to read.
 * \return false if boosting->LoadModelFromString rejects the file content,
 *         true otherwise (including when boosting is nullptr).
 */
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
  const auto start_time = std::chrono::steady_clock::now();
  if (boosting != nullptr) {
    TextReader<size_t> model_reader(filename, true);
    size_t buffer_len = 0;
    auto buffer = model_reader.ReadContent(&buffer_len);
    if (!boosting->LoadModelFromString(buffer.data(), buffer_len)) {
      return false;
    }
  }
  // Measure directly in (floating-point) seconds and pass the raw double:
  // a std::chrono::duration object is not a valid argument for "%f" in a
  // printf-style variadic call.
  const std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start_time;
  Log::Debug("Time for loading model: %f seconds", elapsed.count());
  return true;
}

37
/*!
 * \brief Allocate an empty boosting object for a boosting type name.
 * \param type One of "gbdt", "dart", "goss" or "rf" ("goss" maps to GBDT).
 * \return Newly allocated object owned by the caller, or nullptr if type is unknown.
 */
static Boosting* NewBoostingOfType(const std::string& type) {
  if (type == "gbdt" || type == "goss") {
    return new GBDT();
  } else if (type == "dart") {
    return new DART();
  } else if (type == "rf") {
    return new RF();
  }
  return nullptr;
}

/*!
 * \brief Factory for boosting objects, optionally loading a saved model.
 *
 * Without a filename (nullptr or empty string), an empty object of the
 * requested type is returned, or nullptr for an unknown type.
 * With a filename, the file's header line must be "tree"; the object is then
 * created for the requested type and the model is loaded into it. Every type,
 * including "rf", goes through the same loading path.
 * Unknown file format, unknown type, or a failed load are reported via Log::Fatal.
 * \param type Boosting type name ("gbdt", "dart", "goss" or "rf").
 * \param filename Optional model file to load.
 * \return Newly allocated object owned by the caller.
 */
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
  if (filename == nullptr || filename[0] == '\0') {
    return NewBoostingOfType(type);
  }
  if (GetBoostingTypeFromModelFile(filename) != std::string("tree")) {
    Log::Fatal("Unknown model format or submodel type in model file %s", filename);
  }
  std::unique_ptr<Boosting> ret(NewBoostingOfType(type));
  if (ret == nullptr) {
    Log::Fatal("Unknown boosting type %s", type.c_str());
  }
  // Previously the "rf" case returned before this point, so RF models were
  // never actually loaded, and a failed load was silently ignored.
  if (!LoadFileToBoosting(ret.get(), filename)) {
    Log::Fatal("Cannot load model from file %s", filename);
  }
  return ret.release();
}

Guolin Ke's avatar
Guolin Ke committed
72
}  // namespace LightGBM