".github/vscode:/vscode.git/clone" did not exist on "696b904874a6c91f0390401bdb1085e3be55a274"
predictor.hpp 4.97 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#ifndef LIGHTGBM_PREDICTOR_HPP_
#define LIGHTGBM_PREDICTOR_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/boosting.h>
#include <LightGBM/utils/text_reader.h>
#include <LightGBM/dataset.h>

#include <omp.h>

#include <cstring>
#include <cstdio>
#include <vector>
#include <utility>
#include <functional>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
22
23
24
25
26
27
28

namespace LightGBM {

/*!
* \brief Used to prediction data with input model
*/
class Predictor {
public:
  /*!
  * \brief Constructor
  * \param boosting Input boosting model
29
  * \param is_raw_score True if need to predict result with raw score
wxchan's avatar
wxchan committed
30
  * \param predict_leaf_index True if output leaf index instead of prediction score
Guolin Ke's avatar
Guolin Ke committed
31
  */
Guolin Ke's avatar
Guolin Ke committed
32
  Predictor(const Boosting* boosting, bool is_raw_score, bool is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
    boosting_ = boosting;
    num_features_ = boosting_->MaxFeatureIdx() + 1;
#pragma omp parallel
#pragma omp master
    {
      num_threads_ = omp_get_num_threads();
    }
    for (int i = 0; i < num_threads_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
41
      features_.push_back(std::vector<double>(num_features_));
Guolin Ke's avatar
Guolin Ke committed
42
    }
Guolin Ke's avatar
Guolin Ke committed
43
    features_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
44
    if (is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
45
46
47
      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
        const int tid = PutFeatureValuesToBuffer(features);
        // get result for leaf index
Guolin Ke's avatar
Guolin Ke committed
48
        auto result = boosting_->PredictLeafIndex(features_[tid].data());
Guolin Ke's avatar
Guolin Ke committed
49
50
51
        return std::vector<double>(result.begin(), result.end());
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
52
      if (is_raw_score) {
Guolin Ke's avatar
Guolin Ke committed
53
54
55
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
          const int tid = PutFeatureValuesToBuffer(features);
          // get result without sigmoid transformation
Guolin Ke's avatar
Guolin Ke committed
56
          return boosting_->PredictRaw(features_[tid].data());
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
        };
      } else {
        predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
          const int tid = PutFeatureValuesToBuffer(features);
Guolin Ke's avatar
Guolin Ke committed
61
          return boosting_->Predict(features_[tid].data());
Guolin Ke's avatar
Guolin Ke committed
62
63
64
        };
      }
    }
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
69
70
71
  }
  /*!
  * \brief Destructor
  */
  ~Predictor() {
  }

Guolin Ke's avatar
Guolin Ke committed
72
73
  inline const PredictFunction& GetPredictFunction() {
    return predict_fun_;
74
  }
75

Guolin Ke's avatar
Guolin Ke committed
76
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
77
  * \brief predicting on data, then saving result to disk
Guolin Ke's avatar
Guolin Ke committed
78
79
80
81
  * \param data_filename Filename of data
  * \param has_label True if this data contains label
  * \param result_filename Filename of output result
  */
Guolin Ke's avatar
Guolin Ke committed
82
  void Predict(const char* data_filename, const char* result_filename, bool has_header) {
Guolin Ke's avatar
Guolin Ke committed
83
84
85
86
87
88
89
90
91
    FILE* result_file;

#ifdef _MSC_VER
    fopen_s(&result_file, result_filename, "w");
#else
    result_file = fopen(result_filename, "w");
#endif

    if (result_file == NULL) {
92
      Log::Fatal("Prediction results file %s doesn't exist", data_filename);
Guolin Ke's avatar
Guolin Ke committed
93
    }
Guolin Ke's avatar
Guolin Ke committed
94
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx()));
Guolin Ke's avatar
Guolin Ke committed
95
96

    if (parser == nullptr) {
97
      Log::Fatal("Could not recognize the data format of data file %s", data_filename);
Guolin Ke's avatar
Guolin Ke committed
98
99
100
    }

    // function for parse data
101
102
    std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun;
    double tmp_label;
Guolin Ke's avatar
Guolin Ke committed
103
    parser_fun = [this, &parser, &tmp_label]
104
    (const char* buffer, std::vector<std::pair<int, double>>* feature) {
Guolin Ke's avatar
Guolin Ke committed
105
106
107
      parser->ParseOneLine(buffer, feature, &tmp_label);
    };

Guolin Ke's avatar
Guolin Ke committed
108
    std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
Guolin Ke's avatar
Guolin Ke committed
109
      [this, &parser_fun, &result_file]
Guolin Ke's avatar
Guolin Ke committed
110
    (data_size_t, const std::vector<std::string>& lines) {
111
      std::vector<std::pair<int, double>> oneline_features;
wxchan's avatar
wxchan committed
112
      std::vector<std::string> pred_result(lines.size(), "");
Guolin Ke's avatar
Guolin Ke committed
113
#pragma omp parallel for schedule(static) private(oneline_features)
114
      for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
        oneline_features.clear();
        // parser
        parser_fun(lines[i].c_str(), &oneline_features);
        // predict
Guolin Ke's avatar
Guolin Ke committed
119
        pred_result[i] = Common::Join<double>(predict_fun_(oneline_features), "\t");
Guolin Ke's avatar
Guolin Ke committed
120
121
122
      }

      for (size_t i = 0; i < pred_result.size(); ++i) {
wxchan's avatar
wxchan committed
123
        fprintf(result_file, "%s\n", pred_result[i].c_str());
Guolin Ke's avatar
Guolin Ke committed
124
125
      }
    };
Guolin Ke's avatar
Guolin Ke committed
126
    TextReader<data_size_t> predict_data_reader(data_filename, has_header);
Guolin Ke's avatar
Guolin Ke committed
127
128
129
130
131
132
    predict_data_reader.ReadAllAndProcessParallel(process_fun);

    fclose(result_file);
  }

private:
133
  int PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) {
Guolin Ke's avatar
Guolin Ke committed
134
135
    int tid = omp_get_thread_num();
    // init feature value
Guolin Ke's avatar
Guolin Ke committed
136
    std::memset(features_[tid].data(), 0, sizeof(double)*num_features_);
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
142
143
144
    // put feature value
    for (const auto& p : features) {
      if (p.first < num_features_) {
        features_[tid][p.first] = p.second;
      }
    }
    return tid;
  }
Guolin Ke's avatar
Guolin Ke committed
145
146
147
  /*! \brief Boosting model */
  const Boosting* boosting_;
  /*! \brief Buffer for feature values */
Guolin Ke's avatar
Guolin Ke committed
148
  std::vector<std::vector<double>> features_;
Guolin Ke's avatar
Guolin Ke committed
149
150
151
152
  /*! \brief Number of features */
  int num_features_;
  /*! \brief Number of threads */
  int num_threads_;
Guolin Ke's avatar
Guolin Ke committed
153
154
  /*! \brief function for prediction */
  PredictFunction predict_fun_;
Guolin Ke's avatar
Guolin Ke committed
155
156
157
158
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
159
#endif   // LightGBM_PREDICTOR_HPP_