Commit 317b1bfb authored by Ilya Matiach, committed by Guolin Ke
Browse files

[java][mmlspark] Fix cached predictor causing bad values for predicted probabilities (#2356)

* [mmlspark] Fix cached predictor causing bad values for predicted probabilities

* updated based on comments

* removed tabs
parent 254a8699
...@@ -46,6 +46,55 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \ ...@@ -46,6 +46,55 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \ catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0; return 0;
// Size of the per-predict_type array of cached single-row predictors kept by Booster
// (the array is indexed directly by predict_type).
// NOTE(review): presumably one slot per C_API_PREDICT_* value - confirm the two stay in sync.
const int PREDICTOR_TYPES = 4;
/*!
 * \brief Single row predictor to abstract away caching logic.
 *
 * Wraps a Predictor together with the settings it was built from (early-stopping
 * options, iteration count and the number of models in the booster at build time),
 * so a caller can ask via IsPredictorEqual() whether a cached instance is still
 * valid for a new request or has to be rebuilt.
 */
class SingleRowPredictor {
 public:
  /*! \brief Prediction function obtained from the wrapped Predictor */
  PredictFunction predict_function;
  /*! \brief Number of values written per row, as reported by Boosting::NumPredictOneRow */
  int64_t num_pred_in_one_row;

  /*!
   * \brief Build a predictor for one prediction type.
   * \param predict_type One of the C_API_PREDICT_* constants; any value other than
   *        leaf-index, raw-score or contrib leaves all three mode flags false
   * \param boosting Model to predict with
   * \param config Source of the pred_early_stop* settings
   * \param iter Iteration count forwarded to Predictor and NumPredictOneRow
   */
  SingleRowPredictor(int predict_type, Boosting& boosting, const Config& config, int iter) {
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    }
    // Remember everything the predictor depends on so IsPredictorEqual() can
    // detect a stale cache entry later.
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
    predictor_.reset(new Predictor(&boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
                                   early_stop_, early_stop_freq_, early_stop_margin_));
    num_pred_in_one_row = boosting.NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
    predict_function = predictor_->GetPredictFunction();
    num_total_model_ = boosting.NumberOfTotalModel();
  }

  ~SingleRowPredictor() {}

  /*!
   * \brief Check whether this cached predictor matches the requested settings.
   * \param config Config whose pred_early_stop* values are compared
   * \param iter Requested iteration count
   * \param boosting Model whose current NumberOfTotalModel() is compared
   * \return true only if every recorded setting equals the requested one, i.e. the
   *         cached predictor can be reused; false means it must be rebuilt
   */
  bool IsPredictorEqual(const Config& config, int iter, Boosting& boosting) {
    // All settings must match. The previous `!=` / `||` form returned the opposite
    // answer, so callers testing `!IsPredictorEqual(...)` rebuilt on every unchanged
    // request and kept a stale predictor when something had actually changed.
    return early_stop_ == config.pred_early_stop &&
           early_stop_freq_ == config.pred_early_stop_freq &&
           early_stop_margin_ == config.pred_early_stop_margin &&
           iter_ == iter &&
           num_total_model_ == boosting.NumberOfTotalModel();
  }

 private:
  /*! \brief Owns the Predictor that backs predict_function */
  std::unique_ptr<Predictor> predictor_;
  /*! \brief Settings captured at construction time, compared in IsPredictorEqual */
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};
class Booster { class Booster {
public: public:
explicit Booster(const char* filename) { explicit Booster(const char* filename) {
...@@ -205,33 +254,17 @@ class Booster { ...@@ -205,33 +254,17 @@ class Booster {
const Config& config, const Config& config,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr ||
if (single_row_predictor_.get() == nullptr) { !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, *boosting_.get())) {
bool is_predict_leaf = false; single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, *boosting_.get(),
bool is_raw_score = false; config, num_iteration));
bool predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true;
} else if (predict_type == C_API_PREDICT_CONTRIB) {
predict_contrib = true;
} else {
is_raw_score = false;
}
// TODO(eisber): config could be optimized away... (maybe using lambda callback?)
single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
} }
auto one_row = get_row_fun(0); auto one_row = get_row_fun(0);
auto pred_wrt_ptr = out_result; auto pred_wrt_ptr = out_result;
single_row_predict_function_(one_row, pred_wrt_ptr); single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
*out_len = single_row_num_pred_in_one_row_; *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
} }
...@@ -364,9 +397,7 @@ class Booster { ...@@ -364,9 +397,7 @@ class Booster {
private: private:
const Dataset* train_data_; const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
std::unique_ptr<Predictor> single_row_predictor_; std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
PredictFunction single_row_predict_function_;
int64_t single_row_num_pred_in_one_row_;
/*! \brief All configs */ /*! \brief All configs */
Config config_; Config config_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment