#ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h>
#include <LightGBM/dataset.h>

#include <LightGBM/utils/openmp_wrapper.h>

#include <cstring>
#include <cstdio>
#include <vector>
#include <utility>
#include <functional>
#include <string>
#include <memory>

namespace LightGBM {

/*!
* \brief Used to predict data with input model
*/
class Predictor {
public:
  /*!
  * \brief Constructor
  * \param boosting Input boosting model
Guolin Ke's avatar
Guolin Ke committed
29
  * \param num_iteration Number of boosting round
30
  * \param is_raw_score True if need to predict result with raw score
zhangyafeikimi's avatar
zhangyafeikimi committed
31
  * \param is_predict_leaf_index True if output leaf index instead of prediction score
Guolin Ke's avatar
Guolin Ke committed
32
  */
Guolin Ke's avatar
Guolin Ke committed
33
34
35
  Predictor(Boosting* boosting, int num_iteration,
            bool is_raw_score, bool is_predict_leaf_index) {

36
    boosting->InitPredict(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
37
    boosting_ = boosting;
Guolin Ke's avatar
Guolin Ke committed
38
    num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
39
    predict_buf_ = std::vector<double>(boosting_->MaxFeatureIdx() + 1, 0.0f);
Guolin Ke's avatar
Guolin Ke committed
40

Guolin Ke's avatar
Guolin Ke committed
41
    if (is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
42
      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
43
        CopyToPredictBuffer(features);
Guolin Ke's avatar
Guolin Ke committed
44
        // get result for leaf index
45
46
        boosting_->PredictLeafIndex(predict_buf_.data(), output);
        ClearPredictBuffer(features);
Guolin Ke's avatar
Guolin Ke committed
47
      };
Guolin Ke's avatar
Guolin Ke committed
48

Guolin Ke's avatar
Guolin Ke committed
49
    } else {
Guolin Ke's avatar
Guolin Ke committed
50
      if (is_raw_score) {
Guolin Ke's avatar
Guolin Ke committed
51
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
52
53
54
          CopyToPredictBuffer(features);
          boosting_->PredictRaw(predict_buf_.data(), output);
          ClearPredictBuffer(features);
Guolin Ke's avatar
Guolin Ke committed
55
56
        };
      } else {
Guolin Ke's avatar
Guolin Ke committed
57
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
58
59
60
          CopyToPredictBuffer(features);
          boosting_->Predict(predict_buf_.data(), output);
          ClearPredictBuffer(features);
Guolin Ke's avatar
Guolin Ke committed
61
62
63
        };
      }
    }
Guolin Ke's avatar
Guolin Ke committed
64
  }
65

Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
71
  /*!
  * \brief Destructor
  */
  ~Predictor() {
  }

zhangyafeikimi's avatar
zhangyafeikimi committed
72
  inline const PredictFunction& GetPredictFunction() const {
Guolin Ke's avatar
Guolin Ke committed
73
    return predict_fun_;
74
  }
75

Guolin Ke's avatar
Guolin Ke committed
76
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
77
  * \brief predicting on data, then saving result to disk
Guolin Ke's avatar
Guolin Ke committed
78
79
80
  * \param data_filename Filename of data
  * \param result_filename Filename of output result
  */
Guolin Ke's avatar
Guolin Ke committed
81
  void Predict(const char* data_filename, const char* result_filename, bool has_header) {
Guolin Ke's avatar
Guolin Ke committed
82
83
    FILE* result_file;

Guolin Ke's avatar
Guolin Ke committed
84
    #ifdef _MSC_VER
Guolin Ke's avatar
Guolin Ke committed
85
    fopen_s(&result_file, result_filename, "w");
Guolin Ke's avatar
Guolin Ke committed
86
    #else
Guolin Ke's avatar
Guolin Ke committed
87
    result_file = fopen(result_filename, "w");
Guolin Ke's avatar
Guolin Ke committed
88
    #endif
Guolin Ke's avatar
Guolin Ke committed
89
90

    if (result_file == NULL) {
91
      Log::Fatal("Prediction results file %s doesn't exist", data_filename);
Guolin Ke's avatar
Guolin Ke committed
92
    }
93
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
Guolin Ke's avatar
Guolin Ke committed
94
95

    if (parser == nullptr) {
96
      Log::Fatal("Could not recognize the data format of data file %s", data_filename);
Guolin Ke's avatar
Guolin Ke committed
97
98
99
    }

    // function for parse data
100
101
    std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
    double tmp_label;
Guolin Ke's avatar
Guolin Ke committed
102
    parser_fun = [this, &parser, &tmp_label]
103
    (const char* buffer, std::vector<std::pair<int, double>>* feature) {
Guolin Ke's avatar
Guolin Ke committed
104
105
106
      parser->ParseOneLine(buffer, feature, &tmp_label);
    };

Guolin Ke's avatar
Guolin Ke committed
107
    std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
Guolin Ke's avatar
Guolin Ke committed
108
      [this, &parser_fun, &result_file]
Guolin Ke's avatar
Guolin Ke committed
109
    (data_size_t, const std::vector<std::string>& lines) {
110
      std::vector<std::pair<int, double>> oneline_features;
111
      for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
112
113
114
115
        oneline_features.clear();
        // parser
        parser_fun(lines[i].c_str(), &oneline_features);
        // predict
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
        std::vector<double> result(num_pred_one_row_);
        predict_fun_(oneline_features, result.data());
        auto str_result = Common::Join<double>(result, "\t");
        fprintf(result_file, "%s\n", str_result.c_str());
Guolin Ke's avatar
Guolin Ke committed
120
121
      }
    };
Guolin Ke's avatar
Guolin Ke committed
122
    TextReader<data_size_t> predict_data_reader(data_filename, has_header);
Guolin Ke's avatar
Guolin Ke committed
123
124
125
126
127
    predict_data_reader.ReadAllAndProcessParallel(process_fun);
    fclose(result_file);
  }

private:
128
129

  void CopyToPredictBuffer(const std::vector<std::pair<int, double>>& features) {
Guolin Ke's avatar
Guolin Ke committed
130
    int loop_size = static_cast<int>(features.size());
131
    #pragma omp parallel for schedule(static,128) if (loop_size >= 256)
Guolin Ke's avatar
Guolin Ke committed
132
    for (int i = 0; i < loop_size; ++i) {
133
134
135
136
137
138
139
140
141
142
143
144
      predict_buf_[features[i].first] = features[i].second;
    }
  }

  void ClearPredictBuffer(const std::vector<std::pair<int, double>>& features) {
    if (features.size() < static_cast<size_t>(predict_buf_.size() / 2)) {
      std::memset(predict_buf_.data(), 0, sizeof(double)*(predict_buf_.size()));
    } else {
      int loop_size = static_cast<int>(features.size());
      #pragma omp parallel for schedule(static,128) if (loop_size >= 256)
      for (int i = 0; i < loop_size; ++i) {
        predict_buf_[features[i].first] = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
145
146
147
      }
    }
  }
148

Guolin Ke's avatar
Guolin Ke committed
149
150
  /*! \brief Boosting model */
  const Boosting* boosting_;
Guolin Ke's avatar
Guolin Ke committed
151
152
  /*! \brief function for prediction */
  PredictFunction predict_fun_;
Guolin Ke's avatar
Guolin Ke committed
153
  int num_pred_one_row_;
154
  std::vector<double> predict_buf_;
Guolin Ke's avatar
Guolin Ke committed
155
156
157
158
};

}  // namespace LightGBM

#endif   // LIGHTGBM_PREDICTOR_HPP_