predictor.hpp 9.74 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
#ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h>
#include <LightGBM/dataset.h>

#include <LightGBM/utils/openmp_wrapper.h>

#include <map>
#include <unordered_map>
#include <cstring>
#include <cstdio>
#include <vector>
#include <utility>
#include <functional>
#include <string>
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
20
21
22

namespace LightGBM {

/*!
zhangyafeikimi's avatar
zhangyafeikimi committed
23
* \brief Used to predict data with input model
Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
29
*/
class Predictor {
public:
  /*!
  * \brief Constructor
  * \param boosting Input boosting model
Guolin Ke's avatar
Guolin Ke committed
30
  * \param num_iteration Number of boosting round
31
  * \param is_raw_score True if need to predict result with raw score
Guolin Ke's avatar
Guolin Ke committed
32
33
  * \param predict_leaf_index True to output leaf index instead of prediction score
  * \param predict_contrib True to output feature contributions instead of prediction score
Guolin Ke's avatar
Guolin Ke committed
34
  */
Guolin Ke's avatar
Guolin Ke committed
35
  Predictor(Boosting* boosting, int num_iteration,
Guolin Ke's avatar
Guolin Ke committed
36
            bool is_raw_score, bool predict_leaf_index, bool predict_contrib,
37
38
39
40
            bool early_stop, int early_stop_freq, double early_stop_margin) {
    early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
    if (early_stop && !boosting->NeedAccuratePrediction()) {
      PredictionEarlyStopConfig pred_early_stop_config;
41
42
      CHECK(early_stop_freq > 0);
      CHECK(early_stop_margin >= 0);
43
44
45
46
47
48
49
50
51
      pred_early_stop_config.margin_threshold = early_stop_margin;
      pred_early_stop_config.round_period = early_stop_freq;
      if (boosting->NumberOfClasses() == 1) {
        early_stop_ = CreatePredictionEarlyStopInstance("binary", pred_early_stop_config);
      } else {
        early_stop_ = CreatePredictionEarlyStopInstance("multiclass", pred_early_stop_config);
      }
    }

Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
    #pragma omp parallel
    #pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
Guolin Ke's avatar
Guolin Ke committed
57
    boosting->InitPredict(num_iteration, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
58
    boosting_ = boosting;
Guolin Ke's avatar
Guolin Ke committed
59
    num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, predict_leaf_index, predict_contrib);
60
61
    num_feature_ = boosting_->MaxFeatureIdx() + 1;
    predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
62
63
    const int kFeatureThreshold = 100000;
    const size_t KSparseThreshold = static_cast<size_t>(0.01 * num_feature_);
Guolin Ke's avatar
Guolin Ke committed
64
    if (predict_leaf_index) {
65
      predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
66
        int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
67
68
69
        if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
          auto buf = CopyToPredictMap(features);
          boosting_->PredictLeafIndexByMap(buf, output);
70
71
72
73
74
75
        } else {
          CopyToPredictBuffer(predict_buf_[tid].data(), features);
          // get result for leaf index
          boosting_->PredictLeafIndex(predict_buf_[tid].data(), output);
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
        }
Guolin Ke's avatar
Guolin Ke committed
76
      };
Guolin Ke's avatar
Guolin Ke committed
77
    } else if (predict_contrib) {
78
79
80
        predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
          int tid = omp_get_thread_num();
          CopyToPredictBuffer(predict_buf_[tid].data(), features);
81
82
83
84
          // get result for leaf index
          boosting_->PredictContrib(predict_buf_[tid].data(), output, &early_stop_);
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
        };
Guolin Ke's avatar
Guolin Ke committed
85
    } else {
Guolin Ke's avatar
Guolin Ke committed
86
      if (is_raw_score) {
87
        predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
88
          int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
89
90
91
          if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
            auto buf = CopyToPredictMap(features);
            boosting_->PredictRawByMap(buf, output, &early_stop_);
92
93
94
95
96
          } else {
            CopyToPredictBuffer(predict_buf_[tid].data(), features);
            boosting_->PredictRaw(predict_buf_[tid].data(), output, &early_stop_);
            ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
          }
Guolin Ke's avatar
Guolin Ke committed
97
98
        };
      } else {
99
        predict_fun_ = [=](const std::vector<std::pair<int, double>>& features, double* output) {
Guolin Ke's avatar
Guolin Ke committed
100
          int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
101
102
103
          if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
            auto buf = CopyToPredictMap(features);
            boosting_->PredictByMap(buf, output, &early_stop_);
104
105
106
107
108
          } else {
            CopyToPredictBuffer(predict_buf_[tid].data(), features);
            boosting_->Predict(predict_buf_[tid].data(), output, &early_stop_);
            ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
          }
Guolin Ke's avatar
Guolin Ke committed
109
110
111
        };
      }
    }
Guolin Ke's avatar
Guolin Ke committed
112
  }
113

Guolin Ke's avatar
Guolin Ke committed
114
115
116
117
118
119
  /*!
  * \brief Destructor
  */
  ~Predictor() {
  }

zhangyafeikimi's avatar
zhangyafeikimi committed
120
  inline const PredictFunction& GetPredictFunction() const {
Guolin Ke's avatar
Guolin Ke committed
121
    return predict_fun_;
122
  }
123

Guolin Ke's avatar
Guolin Ke committed
124
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
125
  * \brief predicting on data, then saving result to disk
Guolin Ke's avatar
Guolin Ke committed
126
127
128
  * \param data_filename Filename of data
  * \param result_filename Filename of output result
  */
Guolin Ke's avatar
Guolin Ke committed
129
  void Predict(const char* data_filename, const char* result_filename, bool header) {
130
131
    auto writer = VirtualFileWriter::Make(result_filename);
    if (!writer->Init()) {
132
      Log::Fatal("Prediction results file %s cannot be found", result_filename);
Guolin Ke's avatar
Guolin Ke committed
133
    }
Guolin Ke's avatar
Guolin Ke committed
134
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
Guolin Ke's avatar
Guolin Ke committed
135
136

    if (parser == nullptr) {
137
      Log::Fatal("Could not recognize the data format of data file %s", data_filename);
Guolin Ke's avatar
Guolin Ke committed
138
139
    }

Guolin Ke's avatar
Guolin Ke committed
140
    TextReader<data_size_t> predict_data_reader(data_filename, header);
ww's avatar
ww committed
141
142
    std::unordered_map<int, int> feature_names_map_;
    bool need_adjust = false;
Guolin Ke's avatar
Guolin Ke committed
143
    if (header) {
ww's avatar
ww committed
144
      std::string first_line = predict_data_reader.first_line();
Guolin Ke's avatar
Guolin Ke committed
145
146
147
      std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
      header_words.erase(header_words.begin() + boosting_->LabelIdx());
      for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
148
        for (int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) {
Guolin Ke's avatar
Guolin Ke committed
149
          if (header_words[i] == boosting_->FeatureNames()[j]) {
ww's avatar
ww committed
150
151
152
153
154
            feature_names_map_[i] = j;
            break;
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
155
156
      for (auto s : feature_names_map_) {
        if (s.first != s.second) {
ww's avatar
ww committed
157
158
159
160
161
          need_adjust = true;
          break;
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
162
    // function for parse data
163
164
    std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
    double tmp_label;
165
    parser_fun = [&]
166
    (const char* buffer, std::vector<std::pair<int, double>>* feature) {
Guolin Ke's avatar
Guolin Ke committed
167
      parser->ParseOneLine(buffer, feature, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
168
      if (need_adjust) {
ww's avatar
ww committed
169
        int i = 0, j = static_cast<int>(feature->size());
Guolin Ke's avatar
Guolin Ke committed
170
171
        while (i < j) {
          if (feature_names_map_.find((*feature)[i].first) != feature_names_map_.end()) {
ww's avatar
ww committed
172
173
            (*feature)[i].first = feature_names_map_[(*feature)[i].first];
            ++i;
Guolin Ke's avatar
Guolin Ke committed
174
          } else {
175
            // move the non-used features to the end of the feature vector
ww's avatar
ww committed
176
177
178
179
180
            std::swap((*feature)[i], (*feature)[--j]);
          }
        }
        feature->resize(i);
      }
Guolin Ke's avatar
Guolin Ke committed
181
182
    };

183
184
    std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = [&]
    (data_size_t, const std::vector<std::string>& lines) {
185
      std::vector<std::pair<int, double>> oneline_features;
186
187
188
      std::vector<std::string> result_to_write(lines.size());
      OMP_INIT_EX();
      #pragma omp parallel for schedule(static) firstprivate(oneline_features)
189
      for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
190
        OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
        oneline_features.clear();
        // parser
        parser_fun(lines[i].c_str(), &oneline_features);
        // predict
Guolin Ke's avatar
Guolin Ke committed
195
196
197
        std::vector<double> result(num_pred_one_row_);
        predict_fun_(oneline_features, result.data());
        auto str_result = Common::Join<double>(result, "\t");
198
199
200
201
202
        result_to_write[i] = str_result;
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
      for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
203
204
        writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
        writer->Write("\n", 1);
Guolin Ke's avatar
Guolin Ke committed
205
206
207
208
209
210
      }
    };
    predict_data_reader.ReadAllAndProcessParallel(process_fun);
  }

private:
Guolin Ke's avatar
Guolin Ke committed
211
  void CopyToPredictBuffer(double* pred_buf, const std::vector<std::pair<int, double>>& features) {
Guolin Ke's avatar
Guolin Ke committed
212
213
    int loop_size = static_cast<int>(features.size());
    for (int i = 0; i < loop_size; ++i) {
214
215
216
      if (features[i].first < num_feature_) {
        pred_buf[features[i].first] = features[i].second;
      }
217
218
219
    }
  }

Guolin Ke's avatar
Guolin Ke committed
220
  void ClearPredictBuffer(double* pred_buf, size_t buf_size, const std::vector<std::pair<int, double>>& features) {
221
    if (features.size() > static_cast<size_t>(buf_size / 2)) {
Guolin Ke's avatar
Guolin Ke committed
222
      std::memset(pred_buf, 0, sizeof(double)*(buf_size));
223
224
225
    } else {
      int loop_size = static_cast<int>(features.size());
      for (int i = 0; i < loop_size; ++i) {
Guolin Ke's avatar
Guolin Ke committed
226
227
228
        if (features[i].first < num_feature_) {
          pred_buf[features[i].first] = 0.0f;
        }
Guolin Ke's avatar
Guolin Ke committed
229
230
231
      }
    }
  }
232

Guolin Ke's avatar
Guolin Ke committed
233
234
  std::unordered_map<int, double> CopyToPredictMap(const std::vector<std::pair<int, double>>& features) {
    std::unordered_map<int, double> buf;
235
236
237
    int loop_size = static_cast<int>(features.size());
    for (int i = 0; i < loop_size; ++i) {
      if (features[i].first < num_feature_) {
Guolin Ke's avatar
Guolin Ke committed
238
        buf[features[i].first] = features[i].second;
239
240
      }
    }
241
    return buf;
242
243
  }

Guolin Ke's avatar
Guolin Ke committed
244
245
  /*! \brief Boosting model */
  const Boosting* boosting_;
Guolin Ke's avatar
Guolin Ke committed
246
247
  /*! \brief function for prediction */
  PredictFunction predict_fun_;
248
  PredictionEarlyStopInstance early_stop_;
249
  int num_feature_;
Guolin Ke's avatar
Guolin Ke committed
250
  int num_pred_one_row_;
Guolin Ke's avatar
Guolin Ke committed
251
252
  int num_threads_;
  std::vector<std::vector<double>> predict_buf_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
256
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
257
#endif   // LIGHTGBM_PREDICTOR_HPP_