#ifndef LIGHTGBM_PREDICTOR_HPP_ #define LIGHTGBM_PREDICTOR_HPP_ #include #include #include #include #include #include #include #include #include #include #include namespace LightGBM { /*! * \brief Used to prediction data with input model */ class Predictor { public: /*! * \brief Constructor * \param boosting Input boosting model * \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification) */ Predictor(const Boosting* boosting, bool is_simgoid) : is_simgoid_(is_simgoid) { boosting_ = boosting; num_features_ = boosting_->MaxFeatureIdx() + 1; #pragma omp parallel #pragma omp master { num_threads_ = omp_get_num_threads(); } features_ = new double*[num_threads_]; for (int i = 0; i < num_threads_; ++i) { features_[i] = new double[num_features_]; } } /*! * \brief Destructor */ ~Predictor() { if (features_ != nullptr) { for (int i = 0; i < num_threads_; ++i) { delete[] features_[i]; } delete[] features_; } } /*! * \brief prediction for one record, only raw result(without sigmoid transformation) * \param features Feature for this record * \return Prediction result */ double PredictRawOneLine(const std::vector>& features) { const int tid = omp_get_thread_num(); // init feature value std::memset(features_[tid], 0, sizeof(double)*num_features_); // put feature value for (const auto& p : features) { if (p.first < num_features_) { features_[tid][p.first] = p.second; } } // get result without sigmoid transformation return boosting_->PredictRaw(features_[tid]); } /*! * \brief prediction for one record, will use sigmoid transformation if needed(only enabled for binary classification noe) * \param features Feature of this record * \return Prediction result */ double PredictOneLine(const std::vector>& features) { const int tid = omp_get_thread_num(); // init feature value std::memset(features_[tid], 0, sizeof(double)*num_features_); // put feature value for (const auto& p : features) { if (p.first < num_features_) { features_[tid][p.first] = p.second; } } // get result with sigmoid transform return boosting_->Predict(features_[tid]); } /*! * \brief predicting on data, then saving result to disk * \param data_filename Filename of data * \param has_label True if this data contains label * \param result_filename Filename of output result */ void Predict(const char* data_filename, const char* result_filename) { FILE* result_file; #ifdef _MSC_VER fopen_s(&result_file, result_filename, "w"); #else result_file = fopen(result_filename, "w"); #endif if (result_file == NULL) { Log::Error("predition result file %s doesn't exists", data_filename); } bool has_label = false; Parser* parser = Parser::CreateParser(data_filename, num_features_, &has_label); if (parser == nullptr) { Log::Error("recongnizing input data format failed, filename %s", data_filename); } // function for parse data std::function>*)> parser_fun; double tmp_label; if (has_label) { // parse function with label parser_fun = [this, &parser, &tmp_label] (const char* buffer, std::vector>* feature) { parser->ParseOneLine(buffer, feature, &tmp_label); }; Log::Info("start prediction for data %s, and data has label", data_filename); } else { // parse function without label parser_fun = [this, &parser] (const char* buffer, std::vector>* feature) { parser->ParseOneLine(buffer, feature); }; Log::Info("start prediction for data %s, and data doesn't has label", data_filename); } std::function>&)> predict_fun; if (is_simgoid_) { predict_fun = [this](const std::vector>& features) { return PredictOneLine(features); }; } else { predict_fun = [this](const std::vector>& features) { return PredictRawOneLine(features); }; } std::function&)> process_fun = [this, &parser_fun, &predict_fun, &result_file] (data_size_t, const std::vector& lines) { std::vector> oneline_features; std::vector pred_result(lines.size(), 0.0f); #pragma omp parallel for schedule(static) private(oneline_features) for (data_size_t i = 0; i < static_cast(lines.size()); i++) { oneline_features.clear(); // parser parser_fun(lines[i].c_str(), &oneline_features); // predict pred_result[i] = predict_fun(oneline_features); } for (size_t i = 0; i < pred_result.size(); ++i) { fprintf(result_file, "%f\n", pred_result[i]); } }; TextReader predict_data_reader(data_filename); predict_data_reader.ReadAllAndProcessParallel(process_fun); fclose(result_file); delete parser; } private: /*! \brief Boosting model */ const Boosting* boosting_; /*! \brief Buffer for feature values */ double** features_; /*! \brief Number of features */ int num_features_; /*! \brief True if need to predict result with sigmoid transform */ bool is_simgoid_; /*! \brief Number of threads */ int num_threads_; }; } // namespace LightGBM #endif // LightGBM_PREDICTOR_HPP_