Commit f6b8ecf6 authored by Guolin Ke's avatar Guolin Ke Committed by Nikita Titov
Browse files

option to disable the shape checking in prediction (#2669)



* implement

* better documentation

* Update include/LightGBM/config.h
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* Apply suggestions from code review
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* fix document

* regenerate docs
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 65455588
......@@ -725,6 +725,18 @@ IO Parameters
- the threshold of margin in early-stopping prediction
- ``predict_disable_shape_check`` :raw-html:`<a id="predict_disable_shape_check" title="Permalink to this parameter" href="#predict_disable_shape_check">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool
- used only in ``prediction`` task
- control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
- if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
- if ``true``, LightGBM will attempt to predict on whatever data you provide. This is dangerous because you might get incorrect predictions, but you could use it in situations where it is difficult or expensive to generate some features and you are very confident that they were never chosen for splits in the model
- **Note**: be very careful setting this parameter to ``true``
- ``convert_model_language`` :raw-html:`<a id="convert_model_language" title="Permalink to this parameter" href="#convert_model_language">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
- used only in ``convert_model`` task
......
......@@ -658,6 +658,13 @@ struct Config {
// desc = the threshold of margin in early-stopping prediction
double pred_early_stop_margin = 10.0;
// desc = used only in ``prediction`` task
// desc = control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
// desc = if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
// desc = if ``true``, LightGBM will attempt to predict on whatever data you provide. This is dangerous because you might get incorrect predictions, but you could use it in situations where it is difficult or expensive to generate some features and you are very confident that they were never chosen for splits in the model
// desc = **Note**: be very careful setting this parameter to ``true``
bool predict_disable_shape_check = false;
// desc = used only in ``convert_model`` task
// desc = only ``cpp`` is supported yet; for conversion model to other languages consider using `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility
// desc = if ``convert_model_language`` is set and ``task=train``, the model will be also converted
......
......@@ -215,7 +215,7 @@ void Application::Predict() {
if (config_.task == TaskType::KRefitTree) {
// create predictor
Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
......@@ -245,7 +245,7 @@ void Application::Predict() {
config_.pred_early_stop, config_.pred_early_stop_freq,
config_.pred_early_stop_margin);
predictor.Predict(config_.data.c_str(),
config_.output_result.c_str(), config_.header);
config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
Log::Info("Finished prediction");
}
}
......
......@@ -130,7 +130,7 @@ class Predictor {
* \param data_filename Filename of data
* \param result_filename Filename of output result
*/
void Predict(const char* data_filename, const char* result_filename, bool header) {
void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check) {
auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be found", result_filename);
......@@ -141,8 +141,9 @@ class Predictor {
if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
}
if (!header && parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1);
if (!header && !disable_shape_check && parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1);
}
TextReader<data_size_t> predict_data_reader(data_filename, header);
std::vector<int> feature_remapper(parser->NumFeatures(), -1);
......
......@@ -253,8 +253,9 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr ||
......@@ -274,8 +275,9 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
......@@ -327,7 +329,7 @@ class Booster {
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header);
predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
}
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
......
......@@ -256,6 +256,7 @@ std::unordered_set<std::string> Config::parameter_set({
"pred_early_stop",
"pred_early_stop_freq",
"pred_early_stop_margin",
"predict_disable_shape_check",
"convert_model_language",
"convert_model",
"num_class",
......@@ -509,6 +510,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetBool(params, "predict_disable_shape_check", &predict_disable_shape_check);
GetString(params, "convert_model_language", &convert_model_language);
GetString(params, "convert_model", &convert_model);
......@@ -669,6 +672,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
str_buf << "[predict_disable_shape_check: " << predict_disable_shape_check << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[num_class: " << num_class << "]\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment