Commit f6b8ecf6 authored by Guolin Ke, committed by Nikita Titov
Browse files

option to disable the shape checking in prediction (#2669)



* implement

* better documentation

* Update include/LightGBM/config.h
Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Apply suggestions from code review
Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* fix document

* regenerate docs
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 65455588
...@@ -725,6 +725,18 @@ IO Parameters ...@@ -725,6 +725,18 @@ IO Parameters
- the threshold of margin in early-stopping prediction - the threshold of margin in early-stopping prediction
- ``predict_disable_shape_check`` :raw-html:`<a id="predict_disable_shape_check" title="Permalink to this parameter" href="#predict_disable_shape_check">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool
- used only in ``prediction`` task
- control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
- if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
- if ``true``, LightGBM will attempt to predict on whatever data you provide. This is dangerous because you might get incorrect predictions, but you could use it in situations where it is difficult or expensive to generate some features and you are very confident that they were never chosen for splits in the model
- **Note**: be very careful setting this parameter to ``true``
- ``convert_model_language`` :raw-html:`<a id="convert_model_language" title="Permalink to this parameter" href="#convert_model_language">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string - ``convert_model_language`` :raw-html:`<a id="convert_model_language" title="Permalink to this parameter" href="#convert_model_language">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
- used only in ``convert_model`` task - used only in ``convert_model`` task
......
...@@ -658,6 +658,13 @@ struct Config { ...@@ -658,6 +658,13 @@ struct Config {
// desc = the threshold of margin in early-stopping prediction // desc = the threshold of margin in early-stopping prediction
double pred_early_stop_margin = 10.0; double pred_early_stop_margin = 10.0;
// desc = used only in ``prediction`` task
// desc = control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
// desc = if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
// desc = if ``true``, LightGBM will attempt to predict on whatever data you provide. This is dangerous because you might get incorrect predictions, but you could use it in situations where it is difficult or expensive to generate some features and you are very confident that they were never chosen for splits in the model
// desc = **Note**: be very careful setting this parameter to ``true``
bool predict_disable_shape_check = false;
// desc = used only in ``convert_model`` task // desc = used only in ``convert_model`` task
// desc = only ``cpp`` is supported yet; for conversion model to other languages consider using `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility // desc = only ``cpp`` is supported yet; for conversion model to other languages consider using `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility
// desc = if ``convert_model_language`` is set and ``task=train``, the model will be also converted // desc = if ``convert_model_language`` is set and ``task=train``, the model will be also converted
......
...@@ -215,7 +215,7 @@ void Application::Predict() { ...@@ -215,7 +215,7 @@ void Application::Predict() {
if (config_.task == TaskType::KRefitTree) { if (config_.task == TaskType::KRefitTree) {
// create predictor // create predictor
Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1); Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header); predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
TextReader<int> result_reader(config_.output_result.c_str(), false); TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines(); result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size()); std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
...@@ -245,7 +245,7 @@ void Application::Predict() { ...@@ -245,7 +245,7 @@ void Application::Predict() {
config_.pred_early_stop, config_.pred_early_stop_freq, config_.pred_early_stop, config_.pred_early_stop_freq,
config_.pred_early_stop_margin); config_.pred_early_stop_margin);
predictor.Predict(config_.data.c_str(), predictor.Predict(config_.data.c_str(),
config_.output_result.c_str(), config_.header); config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
Log::Info("Finished prediction"); Log::Info("Finished prediction");
} }
} }
......
...@@ -130,7 +130,7 @@ class Predictor { ...@@ -130,7 +130,7 @@ class Predictor {
* \param data_filename Filename of data * \param data_filename Filename of data
* \param result_filename Filename of output result * \param result_filename Filename of output result
*/ */
void Predict(const char* data_filename, const char* result_filename, bool header) { void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check) {
auto writer = VirtualFileWriter::Make(result_filename); auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) { if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be found", result_filename); Log::Fatal("Prediction results file %s cannot be found", result_filename);
...@@ -141,8 +141,9 @@ class Predictor { ...@@ -141,8 +141,9 @@ class Predictor {
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename); Log::Fatal("Could not recognize the data format of data file %s", data_filename);
} }
if (!header && parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) { if (!header && !disable_shape_check && parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1); Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1);
} }
TextReader<data_size_t> predict_data_reader(data_filename, header); TextReader<data_size_t> predict_data_reader(data_filename, header);
std::vector<int> feature_remapper(parser->NumFeatures(), -1); std::vector<int> feature_remapper(parser->NumFeatures(), -1);
......
...@@ -253,8 +253,9 @@ class Booster { ...@@ -253,8 +253,9 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config, const Config& config,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) { if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1); Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
} }
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr || if (single_row_predictor_[predict_type].get() == nullptr ||
...@@ -274,8 +275,9 @@ class Booster { ...@@ -274,8 +275,9 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config, const Config& config,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) { if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1); Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
} }
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
...@@ -327,7 +329,7 @@ class Booster { ...@@ -327,7 +329,7 @@ class Booster {
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib, Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header); predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
} }
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
......
...@@ -256,6 +256,7 @@ std::unordered_set<std::string> Config::parameter_set({ ...@@ -256,6 +256,7 @@ std::unordered_set<std::string> Config::parameter_set({
"pred_early_stop", "pred_early_stop",
"pred_early_stop_freq", "pred_early_stop_freq",
"pred_early_stop_margin", "pred_early_stop_margin",
"predict_disable_shape_check",
"convert_model_language", "convert_model_language",
"convert_model", "convert_model",
"num_class", "num_class",
...@@ -509,6 +510,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -509,6 +510,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin); GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetBool(params, "predict_disable_shape_check", &predict_disable_shape_check);
GetString(params, "convert_model_language", &convert_model_language); GetString(params, "convert_model_language", &convert_model_language);
GetString(params, "convert_model", &convert_model); GetString(params, "convert_model", &convert_model);
...@@ -669,6 +672,7 @@ std::string Config::SaveMembersToString() const { ...@@ -669,6 +672,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[pred_early_stop: " << pred_early_stop << "]\n"; str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n"; str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n"; str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
str_buf << "[predict_disable_shape_check: " << predict_disable_shape_check << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n"; str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n"; str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[num_class: " << num_class << "]\n"; str_buf << "[num_class: " << num_class << "]\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment