// predictor.hpp
#ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h>
#include <LightGBM/dataset.h>

#include <omp.h>

#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace LightGBM {

/*!
* \brief Used to prediction data with input model
*/
class Predictor {
public:
  /*!
  * \brief Constructor
  * \param boosting Input boosting model
  * \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification)
wxchan's avatar
wxchan committed
29
  * \param predict_leaf_index True if output leaf index instead of prediction score
Guolin Ke's avatar
Guolin Ke committed
30
  */
31
32
33
  Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index, int num_used_model)
    : is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index),
      num_used_model_(num_used_model) {
Guolin Ke's avatar
Guolin Ke committed
34
35
    boosting_ = boosting;
    num_features_ = boosting_->MaxFeatureIdx() + 1;
36
    num_class_ = boosting_->NumberOfClass();
Guolin Ke's avatar
Guolin Ke committed
37
38
39
40
41
#pragma omp parallel
#pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
42
    features_ = new float*[num_threads_];
Guolin Ke's avatar
Guolin Ke committed
43
    for (int i = 0; i < num_threads_; ++i) {
44
      features_[i] = new float[num_features_];
Guolin Ke's avatar
Guolin Ke committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
    }
  }
  /*!
  * \brief Destructor
  */
  ~Predictor() {
    if (features_ != nullptr) {
      for (int i = 0; i < num_threads_; ++i) {
        delete[] features_[i];
      }
      delete[] features_;
    }
  }

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
60
  * \brief prediction for one record, only raw result(without sigmoid transformation)
Guolin Ke's avatar
Guolin Ke committed
61
62
63
  * \param features Feature for this record
  * \return Prediction result
  */
64
  float PredictRawOneLine(const std::vector<std::pair<int, float>>& features) {
Guolin Ke's avatar
Guolin Ke committed
65
    const int tid = PutFeatureValuesToBuffer(features);
Qiwei Ye's avatar
Qiwei Ye committed
66
    // get result without sigmoid transformation
67
    return boosting_->PredictRaw(features_[tid], num_used_model_);
Guolin Ke's avatar
Guolin Ke committed
68
  }
wxchan's avatar
wxchan committed
69
70
71
72
73
74
  
  /*!
  * \brief prediction for one record, only raw result(without sigmoid transformation)
  * \param features Feature for this record
  * \return Predictied leaf index
  */
75
  std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, float>>& features) {
Guolin Ke's avatar
Guolin Ke committed
76
    const int tid = PutFeatureValuesToBuffer(features);
wxchan's avatar
wxchan committed
77
    // get result for leaf index
78
    return boosting_->PredictLeafIndex(features_[tid], num_used_model_);
wxchan's avatar
wxchan committed
79
  }
Guolin Ke's avatar
Guolin Ke committed
80
81

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
82
83
  * \brief prediction for one record, will use sigmoid transformation if needed(only enabled for binary classification noe)
  * \param features Feature of this record
Guolin Ke's avatar
Guolin Ke committed
84
85
  * \return Prediction result
  */
86
  float PredictOneLine(const std::vector<std::pair<int, float>>& features) {
Guolin Ke's avatar
Guolin Ke committed
87
88
    const int tid = PutFeatureValuesToBuffer(features);
    // get result with sigmoid transform if needed
89
    return boosting_->Predict(features_[tid], num_used_model_);
Guolin Ke's avatar
Guolin Ke committed
90
  }
91
92
93
94
95
96
97
98
99
100
101
102
  
  /*!
  * \brief prediction for multiclass classification
  * \param features Feature of this record
  * \return Prediction result
  */
  std::vector<float> PredictMulticlassOneLine(const std::vector<std::pair<int, float>>& features) {
    const int tid = PutFeatureValuesToBuffer(features);
    // get result with sigmoid transform if needed
    return boosting_->PredictMulticlass(features_[tid], num_used_model_);
  }
  
Guolin Ke's avatar
Guolin Ke committed
103
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
104
  * \brief predicting on data, then saving result to disk
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
  * \param data_filename Filename of data
  * \param has_label True if this data contains label
  * \param result_filename Filename of output result
  */
Guolin Ke's avatar
Guolin Ke committed
109
  void Predict(const char* data_filename, const char* result_filename, bool has_header) {
Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
114
115
116
117
118
    FILE* result_file;

#ifdef _MSC_VER
    fopen_s(&result_file, result_filename, "w");
#else
    result_file = fopen(result_filename, "w");
#endif

    if (result_file == NULL) {
Qiwei Ye's avatar
Qiwei Ye committed
119
      Log::Fatal("Predition result file %s doesn't exists", data_filename);
Guolin Ke's avatar
Guolin Ke committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
    Parser* parser = Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx());
Guolin Ke's avatar
Guolin Ke committed
122
123

    if (parser == nullptr) {
Qiwei Ye's avatar
Qiwei Ye committed
124
      Log::Fatal("Recongnizing input data format failed, filename %s", data_filename);
Guolin Ke's avatar
Guolin Ke committed
125
126
127
    }

    // function for parse data
128
129
    std::function<void(const char*, std::vector<std::pair<int, float>>*)> parser_fun;
    float tmp_label;
Guolin Ke's avatar
Guolin Ke committed
130
    parser_fun = [this, &parser, &tmp_label]
131
    (const char* buffer, std::vector<std::pair<int, float>>* feature) {
Guolin Ke's avatar
Guolin Ke committed
132
133
134
      parser->ParseOneLine(buffer, feature, &tmp_label);
    };

135
    std::function<std::string(const std::vector<std::pair<int, float>>&)> predict_fun;
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    if (num_class_ > 1) {
      predict_fun = [this](const std::vector<std::pair<int, float>>& features){
        std::vector<float> prediction = PredictMulticlassOneLine(features);
        std::stringstream result_stream_buf;
        for (size_t i = 0; i < prediction.size(); ++i){
          if (i > 0) {
            result_stream_buf << '\t';
          }
          result_stream_buf << prediction[i];
        }
        return result_stream_buf.str();  
      };  
    }
    else if (is_predict_leaf_index_) {
150
      predict_fun = [this](const std::vector<std::pair<int, float>>& features){
wxchan's avatar
wxchan committed
151
        std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features);
152
        std::stringstream result_stream_buf;
wxchan's avatar
wxchan committed
153
154
        for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
          if (i > 0) {
155
            result_stream_buf << '\t';
wxchan's avatar
wxchan committed
156
          }
157
          result_stream_buf << predicted_leaf_index[i];
wxchan's avatar
wxchan committed
158
        }
159
        return result_stream_buf.str();  
Guolin Ke's avatar
Guolin Ke committed
160
161
      };
    }
wxchan's avatar
wxchan committed
162
163
    else {
      if (is_simgoid_) {
164
        predict_fun = [this](const std::vector<std::pair<int, float>>& features){
wxchan's avatar
wxchan committed
165
166
167
168
          return std::to_string(PredictOneLine(features));
        };
      } 
      else {
169
        predict_fun = [this](const std::vector<std::pair<int, float>>& features){
wxchan's avatar
wxchan committed
170
171
172
173
          return std::to_string(PredictRawOneLine(features));
        };
      } 
    }
Guolin Ke's avatar
Guolin Ke committed
174
175
176
    std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
      [this, &parser_fun, &predict_fun, &result_file]
    (data_size_t, const std::vector<std::string>& lines) {
177
      std::vector<std::pair<int, float>> oneline_features;
wxchan's avatar
wxchan committed
178
      std::vector<std::string> pred_result(lines.size(), "");
Guolin Ke's avatar
Guolin Ke committed
179
#pragma omp parallel for schedule(static) private(oneline_features)
180
      for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
181
182
183
184
185
186
187
188
        oneline_features.clear();
        // parser
        parser_fun(lines[i].c_str(), &oneline_features);
        // predict
        pred_result[i] = predict_fun(oneline_features);
      }

      for (size_t i = 0; i < pred_result.size(); ++i) {
wxchan's avatar
wxchan committed
189
        fprintf(result_file, "%s\n", pred_result[i].c_str());
Guolin Ke's avatar
Guolin Ke committed
190
191
      }
    };
Guolin Ke's avatar
Guolin Ke committed
192
    TextReader<data_size_t> predict_data_reader(data_filename, has_header);
Guolin Ke's avatar
Guolin Ke committed
193
194
195
196
197
198
199
    predict_data_reader.ReadAllAndProcessParallel(process_fun);

    fclose(result_file);
    delete parser;
  }

private:
200
  int PutFeatureValuesToBuffer(const std::vector<std::pair<int, float>>& features) {
Guolin Ke's avatar
Guolin Ke committed
201
202
    int tid = omp_get_thread_num();
    // init feature value
203
    std::memset(features_[tid], 0, sizeof(float)*num_features_);
Guolin Ke's avatar
Guolin Ke committed
204
205
206
207
208
209
210
211
    // put feature value
    for (const auto& p : features) {
      if (p.first < num_features_) {
        features_[tid][p.first] = p.second;
      }
    }
    return tid;
  }
Guolin Ke's avatar
Guolin Ke committed
212
213
214
  /*! \brief Boosting model */
  const Boosting* boosting_;
  /*! \brief Buffer for feature values */
215
  float** features_;
Guolin Ke's avatar
Guolin Ke committed
216
217
  /*! \brief Number of features */
  int num_features_;
218
219
  /*! \brief Number of classes */
  int num_class_;
Guolin Ke's avatar
Guolin Ke committed
220
221
222
223
  /*! \brief True if need to predict result with sigmoid transform */
  bool is_simgoid_;
  /*! \brief Number of threads */
  int num_threads_;
wxchan's avatar
wxchan committed
224
  /*! \brief True if output leaf index instead of prediction score */
225
226
227
  bool is_predict_leaf_index_;
  /*! \brief Number of used model */
  int num_used_model_;
Guolin Ke's avatar
Guolin Ke committed
228
229
230
231
};

}  // namespace LightGBM

#endif   // LIGHTGBM_PREDICTOR_HPP_