predictor.hpp 5.66 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
#ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h>
#include <LightGBM/dataset.h>

9
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14
15
16

#include <cstring>
#include <cstdio>
#include <vector>
#include <utility>
#include <functional>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21

namespace LightGBM {

/*!
zhangyafeikimi's avatar
zhangyafeikimi committed
22
* \brief Used to predict data with input model
Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
27
28
*/
class Predictor {
public:
  /*!
  * \brief Constructor
  * \param boosting Input boosting model
Guolin Ke's avatar
Guolin Ke committed
29
  * \param num_iteration Number of boosting round
30
  * \param is_raw_score True if need to predict result with raw score
zhangyafeikimi's avatar
zhangyafeikimi committed
31
  * \param is_predict_leaf_index True if output leaf index instead of prediction score
Guolin Ke's avatar
Guolin Ke committed
32
  */
Guolin Ke's avatar
Guolin Ke committed
33
34
  Predictor(Boosting* boosting, int num_iteration,
            bool is_raw_score, bool is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
35
36
37
38
39
    #pragma omp parallel
    #pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
40
    boosting->InitPredict(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
41
    boosting_ = boosting;
Guolin Ke's avatar
Guolin Ke committed
42
    num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
Guolin Ke's avatar
Guolin Ke committed
43
    predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(boosting_->MaxFeatureIdx() + 1, 0.0f));
Guolin Ke's avatar
Guolin Ke committed
44

Guolin Ke's avatar
Guolin Ke committed
45
    if (is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
46
      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
47
48
        int tid = omp_get_thread_num();
        CopyToPredictBuffer(predict_buf_[tid].data(), features);
Guolin Ke's avatar
Guolin Ke committed
49
        // get result for leaf index
Guolin Ke's avatar
Guolin Ke committed
50
51
        boosting_->PredictLeafIndex(predict_buf_[tid].data(), output);
        ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
Guolin Ke's avatar
Guolin Ke committed
52
      };
Guolin Ke's avatar
Guolin Ke committed
53

Guolin Ke's avatar
Guolin Ke committed
54
    } else {
Guolin Ke's avatar
Guolin Ke committed
55
      if (is_raw_score) {
Guolin Ke's avatar
Guolin Ke committed
56
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
          int tid = omp_get_thread_num();
          CopyToPredictBuffer(predict_buf_[tid].data(), features);
          boosting_->PredictRaw(predict_buf_[tid].data(), output);
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
Guolin Ke's avatar
Guolin Ke committed
61
62
        };
      } else {
Guolin Ke's avatar
Guolin Ke committed
63
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
          int tid = omp_get_thread_num();
          CopyToPredictBuffer(predict_buf_[tid].data(), features);
          boosting_->Predict(predict_buf_[tid].data(), output);
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
Guolin Ke's avatar
Guolin Ke committed
68
69
70
        };
      }
    }
Guolin Ke's avatar
Guolin Ke committed
71
  }
72

Guolin Ke's avatar
Guolin Ke committed
73
74
75
76
77
78
  /*!
  * \brief Destructor
  */
  ~Predictor() {
  }

zhangyafeikimi's avatar
zhangyafeikimi committed
79
  inline const PredictFunction& GetPredictFunction() const {
Guolin Ke's avatar
Guolin Ke committed
80
    return predict_fun_;
81
  }
82

Guolin Ke's avatar
Guolin Ke committed
83
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
84
  * \brief predicting on data, then saving result to disk
Guolin Ke's avatar
Guolin Ke committed
85
86
87
  * \param data_filename Filename of data
  * \param result_filename Filename of output result
  */
Guolin Ke's avatar
Guolin Ke committed
88
  void Predict(const char* data_filename, const char* result_filename, bool has_header) {
Guolin Ke's avatar
Guolin Ke committed
89
90
    FILE* result_file;

Guolin Ke's avatar
Guolin Ke committed
91
    #ifdef _MSC_VER
Guolin Ke's avatar
Guolin Ke committed
92
    fopen_s(&result_file, result_filename, "w");
Guolin Ke's avatar
Guolin Ke committed
93
    #else
Guolin Ke's avatar
Guolin Ke committed
94
    result_file = fopen(result_filename, "w");
Guolin Ke's avatar
Guolin Ke committed
95
    #endif
Guolin Ke's avatar
Guolin Ke committed
96
97

    if (result_file == NULL) {
98
      Log::Fatal("Prediction results file %s doesn't exist", data_filename);
Guolin Ke's avatar
Guolin Ke committed
99
    }
100
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
Guolin Ke's avatar
Guolin Ke committed
101
102

    if (parser == nullptr) {
103
      Log::Fatal("Could not recognize the data format of data file %s", data_filename);
Guolin Ke's avatar
Guolin Ke committed
104
105
106
    }

    // function for parse data
107
108
    std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
    double tmp_label;
Guolin Ke's avatar
Guolin Ke committed
109
    parser_fun = [this, &parser, &tmp_label]
110
    (const char* buffer, std::vector<std::pair<int, double>>* feature) {
Guolin Ke's avatar
Guolin Ke committed
111
112
113
      parser->ParseOneLine(buffer, feature, &tmp_label);
    };

Guolin Ke's avatar
Guolin Ke committed
114
    std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
Guolin Ke's avatar
Guolin Ke committed
115
      [this, &parser_fun, &result_file]
Guolin Ke's avatar
Guolin Ke committed
116
    (data_size_t, const std::vector<std::string>& lines) {
117
      std::vector<std::pair<int, double>> oneline_features;
118
      for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
119
120
121
122
        oneline_features.clear();
        // parser
        parser_fun(lines[i].c_str(), &oneline_features);
        // predict
Guolin Ke's avatar
Guolin Ke committed
123
124
125
126
        std::vector<double> result(num_pred_one_row_);
        predict_fun_(oneline_features, result.data());
        auto str_result = Common::Join<double>(result, "\t");
        fprintf(result_file, "%s\n", str_result.c_str());
Guolin Ke's avatar
Guolin Ke committed
127
128
      }
    };
Guolin Ke's avatar
Guolin Ke committed
129
    TextReader<data_size_t> predict_data_reader(data_filename, has_header);
Guolin Ke's avatar
Guolin Ke committed
130
131
132
133
134
    predict_data_reader.ReadAllAndProcessParallel(process_fun);
    fclose(result_file);
  }

private:
135

Guolin Ke's avatar
Guolin Ke committed
136
  void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
Guolin Ke's avatar
Guolin Ke committed
137
    int loop_size = static_cast<int>(features.size());
138
    #pragma omp parallel for schedule(static,128) if (loop_size >= 256)
Guolin Ke's avatar
Guolin Ke committed
139
    for (int i = 0; i < loop_size; ++i) {
Guolin Ke's avatar
Guolin Ke committed
140
      pred_buf[features[i].first] = features[i].second;
141
142
143
    }
  }

Guolin Ke's avatar
Guolin Ke committed
144
145
146
  void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
    if (features.size() < static_cast<size_t>(buf_size / 2)) {
      std::memset(pred_buf, 0, sizeof(double)*(buf_size));
147
148
149
150
    } else {
      int loop_size = static_cast<int>(features.size());
      #pragma omp parallel for schedule(static,128) if (loop_size >= 256)
      for (int i = 0; i < loop_size; ++i) {
Guolin Ke's avatar
Guolin Ke committed
151
        pred_buf[features[i].first] = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
152
153
154
      }
    }
  }
155

Guolin Ke's avatar
Guolin Ke committed
156
157
  /*! \brief Boosting model */
  const Boosting* boosting_;
Guolin Ke's avatar
Guolin Ke committed
158
159
  /*! \brief function for prediction */
  PredictFunction predict_fun_;
Guolin Ke's avatar
Guolin Ke committed
160
  int num_pred_one_row_;
Guolin Ke's avatar
Guolin Ke committed
161
162
  int num_threads_;
  std::vector<std::vector<double>> predict_buf_;
Guolin Ke's avatar
Guolin Ke committed
163
164
165
166
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
167
#endif   // LIGHTGBM_PREDICTOR_HPP_