predictor.hpp 7.39 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
#ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h>
#include <LightGBM/dataset.h>

9
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14
15
16

#include <cstring>
#include <cstdio>
#include <vector>
#include <utility>
#include <functional>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21

namespace LightGBM {

/*!
zhangyafeikimi's avatar
zhangyafeikimi committed
22
* \brief Used to predict data with input model
Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
27
28
*/
class Predictor {
public:
  /*!
  * \brief Constructor
  * \param boosting Input boosting model
Guolin Ke's avatar
Guolin Ke committed
29
  * \param num_iteration Number of boosting round
30
  * \param is_raw_score True if need to predict result with raw score
31
32
  * \param is_predict_leaf_index True to output leaf index instead of prediction score
  * \param is_predict_contrib True to output feature contributions instead of prediction score
Guolin Ke's avatar
Guolin Ke committed
33
  */
Guolin Ke's avatar
Guolin Ke committed
34
  Predictor(Boosting* boosting, int num_iteration,
35
            bool is_raw_score, bool is_predict_leaf_index, bool is_predict_contrib,
36
37
38
39
40
41
42
43
44
45
46
47
48
49
            bool early_stop, int early_stop_freq, double early_stop_margin) {

    early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
    if (early_stop && !boosting->NeedAccuratePrediction()) {
      PredictionEarlyStopConfig pred_early_stop_config;
      pred_early_stop_config.margin_threshold = early_stop_margin;
      pred_early_stop_config.round_period = early_stop_freq;
      if (boosting->NumberOfClasses() == 1) {
        early_stop_ = CreatePredictionEarlyStopInstance("binary", pred_early_stop_config);
      } else {
        early_stop_ = CreatePredictionEarlyStopInstance("multiclass", pred_early_stop_config);
      }
    }

Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
    #pragma omp parallel
    #pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
55
    boosting->InitPredict(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
56
    boosting_ = boosting;
57
    num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
58
59
    num_feature_ = boosting_->MaxFeatureIdx() + 1;
    predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
Guolin Ke's avatar
Guolin Ke committed
60

Guolin Ke's avatar
Guolin Ke committed
61
    if (is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
62
      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
63
64
        int tid = omp_get_thread_num();
        CopyToPredictBuffer(predict_buf_[tid].data(), features);
Guolin Ke's avatar
Guolin Ke committed
65
        // get result for leaf index
Guolin Ke's avatar
Guolin Ke committed
66
67
        boosting_->PredictLeafIndex(predict_buf_[tid].data(), output);
        ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
Guolin Ke's avatar
Guolin Ke committed
68
      };
Guolin Ke's avatar
Guolin Ke committed
69

70
71
72
73
74
75
76
77
78
    } else if (is_predict_contrib) {
      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
        int tid = omp_get_thread_num();
        CopyToPredictBuffer(predict_buf_[tid].data(), features);
        // get result for leaf index
        boosting_->PredictContrib(predict_buf_[tid].data(), output, &early_stop_);
        ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
      };

Guolin Ke's avatar
Guolin Ke committed
79
    } else {
Guolin Ke's avatar
Guolin Ke committed
80
      if (is_raw_score) {
81
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
82
83
          int tid = omp_get_thread_num();
          CopyToPredictBuffer(predict_buf_[tid].data(), features);
84
          boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_);
Guolin Ke's avatar
Guolin Ke committed
85
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
Guolin Ke's avatar
Guolin Ke committed
86
87
        };
      } else {
88
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
89
90
          int tid = omp_get_thread_num();
          CopyToPredictBuffer(predict_buf_[tid].data(), features);
91
          boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_);
Guolin Ke's avatar
Guolin Ke committed
92
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
Guolin Ke's avatar
Guolin Ke committed
93
94
95
        };
      }
    }
Guolin Ke's avatar
Guolin Ke committed
96
  }
97

Guolin Ke's avatar
Guolin Ke committed
98
99
100
101
102
103
  /*!
  * \brief Destructor
  */
  ~Predictor() {
  }

zhangyafeikimi's avatar
zhangyafeikimi committed
104
  inline const PredictFunction& GetPredictFunction() const {
Guolin Ke's avatar
Guolin Ke committed
105
    return predict_fun_;
106
  }
107

Guolin Ke's avatar
Guolin Ke committed
108
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
109
  * \brief predicting on data, then saving result to disk
Guolin Ke's avatar
Guolin Ke committed
110
111
112
  * \param data_filename Filename of data
  * \param result_filename Filename of output result
  */
Guolin Ke's avatar
Guolin Ke committed
113
  void Predict(const char* data_filename, const char* result_filename, bool has_header) {
Guolin Ke's avatar
Guolin Ke committed
114
115
    FILE* result_file;

Guolin Ke's avatar
Guolin Ke committed
116
    #ifdef _MSC_VER
Guolin Ke's avatar
Guolin Ke committed
117
    fopen_s(&result_file, result_filename, "w");
Guolin Ke's avatar
Guolin Ke committed
118
    #else
Guolin Ke's avatar
Guolin Ke committed
119
    result_file = fopen(result_filename, "w");
Guolin Ke's avatar
Guolin Ke committed
120
    #endif
Guolin Ke's avatar
Guolin Ke committed
121
122

    if (result_file == NULL) {
Qiwei Ye's avatar
Qiwei Ye committed
123
      Log::Fatal("Prediction results file %s cannot be found.", result_filename);
Guolin Ke's avatar
Guolin Ke committed
124
    }
125
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
Guolin Ke's avatar
Guolin Ke committed
126
127

    if (parser == nullptr) {
Qiwei Ye's avatar
Qiwei Ye committed
128
      Log::Fatal("Could not recognize the data format of data file %s.", data_filename);
Guolin Ke's avatar
Guolin Ke committed
129
130
131
    }

    // function for parse data
132
133
    std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
    double tmp_label;
Guolin Ke's avatar
Guolin Ke committed
134
    parser_fun = [this, &parser, &tmp_label]
135
    (const char* buffer, std::vector<std::pair<int, double>>* feature) {
Guolin Ke's avatar
Guolin Ke committed
136
137
138
      parser->ParseOneLine(buffer, feature, &tmp_label);
    };

Guolin Ke's avatar
Guolin Ke committed
139
    std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
Guolin Ke's avatar
Guolin Ke committed
140
      [this, &parser_fun, &result_file]
Guolin Ke's avatar
Guolin Ke committed
141
    (data_size_t, const std::vector<std::string>& lines) {
142
      std::vector<std::pair<int, double>> oneline_features;
143
144
145
      std::vector<std::string> result_to_write(lines.size());
      OMP_INIT_EX();
      #pragma omp parallel for schedule(static) firstprivate(oneline_features)
146
      for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
147
        OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
        oneline_features.clear();
        // parser
        parser_fun(lines[i].c_str(), &oneline_features);
        // predict
Guolin Ke's avatar
Guolin Ke committed
152
153
154
        std::vector<double> result(num_pred_one_row_);
        predict_fun_(oneline_features, result.data());
        auto str_result = Common::Join<double>(result, "\t");
155
156
157
158
159
160
        result_to_write[i] = str_result;
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
      for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
        fprintf(result_file, "%s\n", result_to_write[i].c_str());
Guolin Ke's avatar
Guolin Ke committed
161
162
      }
    };
Guolin Ke's avatar
Guolin Ke committed
163
    TextReader<data_size_t> predict_data_reader(data_filename, has_header);
Guolin Ke's avatar
Guolin Ke committed
164
165
166
167
168
    predict_data_reader.ReadAllAndProcessParallel(process_fun);
    fclose(result_file);
  }

private:
169

Guolin Ke's avatar
Guolin Ke committed
170
  void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
Guolin Ke's avatar
Guolin Ke committed
171
172
    int loop_size = static_cast<int>(features.size());
    for (int i = 0; i < loop_size; ++i) {
173
174
175
      if (features[i].first < num_feature_) {
        pred_buf[features[i].first] = features[i].second;
      }
176
177
178
    }
  }

Guolin Ke's avatar
Guolin Ke committed
179
180
181
  void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
    if (features.size() < static_cast<size_t>(buf_size / 2)) {
      std::memset(pred_buf, 0, sizeof(double)*(buf_size));
182
183
184
    } else {
      int loop_size = static_cast<int>(features.size());
      for (int i = 0; i < loop_size; ++i) {
Guolin Ke's avatar
Guolin Ke committed
185
186
187
        if (features[i].first < num_feature_) {
          pred_buf[features[i].first] = 0.0f;
        }
Guolin Ke's avatar
Guolin Ke committed
188
189
190
      }
    }
  }
191

Guolin Ke's avatar
Guolin Ke committed
192
193
  /*! \brief Boosting model */
  const Boosting* boosting_;
Guolin Ke's avatar
Guolin Ke committed
194
195
  /*! \brief function for prediction */
  PredictFunction predict_fun_;
196
  PredictionEarlyStopInstance early_stop_;
197
  int num_feature_;
Guolin Ke's avatar
Guolin Ke committed
198
  int num_pred_one_row_;
Guolin Ke's avatar
Guolin Ke committed
199
200
  int num_threads_;
  std::vector<std::vector<double>> predict_buf_;
Guolin Ke's avatar
Guolin Ke committed
201
202
203
204
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
205
#endif   // LightGBM_PREDICTOR_HPP_