Commit 6d4c7b03 authored by Guolin Ke, committed by GitHub
Browse files

Support early stopping of prediction in CLI (#565)

* fix multi-threading.

* fix name style.

* support in CLI version.

* remove warnings.

* Not default parameters.

* fix if...else... style.

* fix bug.

* fix warning.

* refine c_api.

* fix R-package.

* fix R's warning.

* fix tests.

* fix pep8.
parent e04a8bb4
...@@ -16,21 +16,10 @@ ...@@ -16,21 +16,10 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
namespace
{
/// Singleton used when earlyStop is nullptr in PredictRaw()
const auto noEarlyStop = LightGBM::createPredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
}
namespace LightGBM { namespace LightGBM {
void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* earlyStop) const { void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
if (earlyStop == nullptr) int early_stop_round_counter = 0;
{
earlyStop = &noEarlyStop;
}
int earlyStopRoundCounter = 0;
for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
// predict all the trees for one iteration // predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
...@@ -38,18 +27,18 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa ...@@ -38,18 +27,18 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa
} }
// check early stopping // check early stopping
++earlyStopRoundCounter; ++early_stop_round_counter;
if (earlyStop->roundPeriod == earlyStopRoundCounter) { if (early_stop->round_period == early_stop_round_counter) {
if (earlyStop->callbackFunction(output, num_tree_per_iteration_)) { if (early_stop->callback_function(output, num_tree_per_iteration_)) {
return; return;
} }
earlyStopRoundCounter = 0; early_stop_round_counter = 0;
} }
} }
} }
void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* earlyStop) const { void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
PredictRaw(features, output, earlyStop); PredictRaw(features, output, early_stop);
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output); objective_function_->ConvertOutput(output, output);
...@@ -58,7 +47,6 @@ void GBDT::Predict(const double* features, double* output, const PredictionEarly ...@@ -58,7 +47,6 @@ void GBDT::Predict(const double* features, double* output, const PredictionEarly
void GBDT::PredictLeafIndex(const double* features, double* output) const { void GBDT::PredictLeafIndex(const double* features, double* output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_; int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
#pragma omp parallel for schedule(static)
for (int i = 0; i < total_tree; ++i) { for (int i = 0; i < total_tree; ++i) {
output[i] = models_[i]->PredictLeafIndex(features); output[i] = models_[i]->PredictLeafIndex(features);
} }
......
#include <LightGBM/prediction_early_stop.h> #include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
using namespace LightGBM;
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
namespace namespace {
{
PredictionEarlyStopInstance createNone(const PredictionEarlyStopConfig&) using namespace LightGBM;
{
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig&) {
return PredictionEarlyStopInstance{ return PredictionEarlyStopInstance{
[](const double*, int) [](const double*, int) {
{
return false; return false;
}, },
std::numeric_limits<int>::max() // make sure the lambda is almost never called std::numeric_limits<int>::max() // make sure the lambda is almost never called
}; };
} }
PredictionEarlyStopInstance createMulticlass(const PredictionEarlyStopConfig& config) PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
{ // margin_threshold will be captured by value
// marginThreshold will be captured by value const double margin_threshold = config.margin_threshold;
const double marginThreshold = config.marginThreshold;
return PredictionEarlyStopInstance{ return PredictionEarlyStopInstance{
[marginThreshold](const double* pred, int sz) [margin_threshold](const double* pred, int sz) {
{ if (sz < 2) {
if(sz < 2) {
Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger"); Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
} }
// copy and sort // copy and sort
std::vector<double> votes(static_cast<size_t>(sz)); std::vector<double> votes(static_cast<size_t>(sz));
for (int i=0; i < sz; ++i) { for (int i = 0; i < sz; ++i) {
votes[i] = pred[i]; votes[i] = pred[i];
} }
std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>()); std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());
const auto margin = votes[0] - votes[1]; const auto margin = votes[0] - votes[1];
if (margin > marginThreshold) { if (margin > margin_threshold) {
return true; return true;
} }
return false; return false;
}, },
config.roundPeriod config.round_period
}; };
} }
PredictionEarlyStopInstance createBinary(const PredictionEarlyStopConfig& config) PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
{ // margin_threshold will be captured by value
// marginThreshold will be captured by value const double margin_threshold = config.margin_threshold;
const double marginThreshold = config.marginThreshold;
return PredictionEarlyStopInstance{ return PredictionEarlyStopInstance{
[marginThreshold](const double* pred, int sz) [margin_threshold](const double* pred, int sz) {
{ if (sz != 1) {
if(sz != 1) {
Log::Fatal("Binary early stopping needs predictions to be of length one"); Log::Fatal("Binary early stopping needs predictions to be of length one");
} }
const auto margin = 2.0 * fabs(pred[0]); const auto margin = 2.0 * fabs(pred[0]);
if (margin > marginThreshold) { if (margin > margin_threshold) {
return true; return true;
} }
return false; return false;
}, },
config.roundPeriod config.round_period
}; };
}
} }
namespace LightGBM }
{
PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type, namespace LightGBM {
const PredictionEarlyStopConfig& config)
{ PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
if (type == "none") const PredictionEarlyStopConfig& config) {
{ if (type == "none") {
return createNone(config); return CreateNone(config);
} } else if (type == "multiclass") {
else if (type == "multiclass") return CreateMulticlass(config);
{ } else if (type == "binary") {
return createMulticlass(config); return CreateBinary(config);
} } else {
else if (type == "binary")
{
return createBinary(config);
}
else
{
throw std::runtime_error("Unknown early stopping type: " + type); throw std::runtime_error("Unknown early stopping type: " + type);
} }
} }
} }
...@@ -163,7 +163,7 @@ public: ...@@ -163,7 +163,7 @@ public:
void Predict(int num_iteration, int predict_type, int nrow, void Predict(int num_iteration, int predict_type, int nrow,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
...@@ -175,21 +175,28 @@ public: ...@@ -175,21 +175,28 @@ public:
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
auto param = ConfigBase::Str2Map(parameter);
IOConfig config;
config.Set(param);
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle)); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf); int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
auto pred_fun = predictor.GetPredictFunction(); auto pred_fun = predictor.GetPredictFunction();
auto pred_wrt_ptr = out_result; OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
auto one_row = get_row_fun(i); auto one_row = get_row_fun(i);
auto pred_wrt_ptr = out_result + static_cast<size_t>(num_preb_in_one_row) * i;
pred_fun(one_row, pred_wrt_ptr); pred_fun(one_row, pred_wrt_ptr);
pred_wrt_ptr += num_preb_in_one_row; OMP_LOOP_EX_END();
} }
OMP_THROW_EX();
*out_len = nrow * num_preb_in_one_row; *out_len = nrow * num_preb_in_one_row;
} }
void Predict(int num_iteration, int predict_type, const char* data_filename, void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const PredictionEarlyStoppingHandle early_stop_handle, int data_has_header, const char* parameter,
const char* result_filename) { const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
...@@ -201,8 +208,11 @@ public: ...@@ -201,8 +208,11 @@ public:
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
auto param = ConfigBase::Str2Map(parameter);
IOConfig config;
config.Set(param);
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle)); config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header); predictor.Predict(data_filename, result_filename, bool_data_has_header);
} }
...@@ -244,6 +254,7 @@ public: ...@@ -244,6 +254,7 @@ public:
return ret; return ret;
} }
#pragma warning(disable : 4996)
int GetEvalNames(char** out_strs) const { int GetEvalNames(char** out_strs) const {
int idx = 0; int idx = 0;
for (const auto& metric : train_metric_) { for (const auto& metric : train_metric_) {
...@@ -255,6 +266,7 @@ public: ...@@ -255,6 +266,7 @@ public:
return idx; return idx;
} }
#pragma warning(disable : 4996)
int GetFeatureNames(char** out_strs) const { int GetFeatureNames(char** out_strs) const {
int idx = 0; int idx = 0;
for (const auto& name : boosting_->FeatureNames()) { for (const auto& name : boosting_->FeatureNames()) {
...@@ -681,6 +693,7 @@ int LGBM_DatasetSetFeatureNames( ...@@ -681,6 +693,7 @@ int LGBM_DatasetSetFeatureNames(
API_END(); API_END();
} }
#pragma warning(disable : 4996)
int LGBM_DatasetGetFeatureNames( int LGBM_DatasetGetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
char** feature_names, char** feature_names,
...@@ -695,6 +708,7 @@ int LGBM_DatasetGetFeatureNames( ...@@ -695,6 +708,7 @@ int LGBM_DatasetGetFeatureNames(
API_END(); API_END();
} }
#pragma warning(disable : 4702)
int LGBM_DatasetFree(DatasetHandle handle) { int LGBM_DatasetFree(DatasetHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<Dataset*>(handle); delete reinterpret_cast<Dataset*>(handle);
...@@ -802,6 +816,7 @@ int LGBM_BoosterLoadModelFromString( ...@@ -802,6 +816,7 @@ int LGBM_BoosterLoadModelFromString(
API_END(); API_END();
} }
#pragma warning(disable : 4702)
int LGBM_BoosterFree(BoosterHandle handle) { int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<Booster*>(handle); delete reinterpret_cast<Booster*>(handle);
...@@ -955,12 +970,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -955,12 +970,12 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header, ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
early_stop_handle, result_filename); parameter, result_filename);
API_END(); API_END();
} }
...@@ -987,7 +1002,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -987,7 +1002,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t, int64_t,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
...@@ -995,7 +1010,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -995,7 +1010,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
early_stop_handle, out_result, out_len); parameter, out_result, out_len);
API_END(); API_END();
} }
...@@ -1010,7 +1025,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1010,7 +1025,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
...@@ -1021,7 +1036,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1021,7 +1036,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j); iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
} }
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun = std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
[&iterators, ncol](int i) { [&iterators, ncol] (int i) {
std::vector<std::pair<int, double>> one_row; std::vector<std::pair<int, double>> one_row;
for (int j = 0; j < ncol; ++j) { for (int j = 0; j < ncol; ++j) {
auto val = iterators[j].Get(i); auto val = iterators[j].Get(i);
...@@ -1031,7 +1046,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1031,7 +1046,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
} }
return one_row; return one_row;
}; };
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, early_stop_handle, ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, parameter,
out_result, out_len); out_result, out_len);
API_END(); API_END();
} }
...@@ -1044,14 +1059,14 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1044,14 +1059,14 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
early_stop_handle, out_result, out_len); parameter, out_result, out_len);
API_END(); API_END();
} }
...@@ -1064,6 +1079,7 @@ int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -1064,6 +1079,7 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
API_END(); API_END();
} }
#pragma warning(disable : 4996)
int LGBM_BoosterSaveModelToString(BoosterHandle handle, int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
...@@ -1079,6 +1095,7 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1079,6 +1095,7 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
API_END(); API_END();
} }
#pragma warning(disable : 4996)
int LGBM_BoosterDumpModel(BoosterHandle handle, int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
...@@ -1114,31 +1131,6 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle, ...@@ -1114,31 +1131,6 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle,
API_END(); API_END();
} }
int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
int round_period,
double margin_threshold,
PredictionEarlyStoppingHandle* out)
{
API_BEGIN();
PredictionEarlyStopConfig config;
config.marginThreshold = margin_threshold;
config.roundPeriod = round_period;
auto earlyStop = createPredictionEarlyStopInstance(type, config);
// create new by copying
*out = new PredictionEarlyStopInstance(earlyStop);
API_END();
}
int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle)
{
API_BEGIN();
delete reinterpret_cast<const PredictionEarlyStopInstance*>(handle);
API_END();
}
// ---- start of some help functions // ---- start of some help functions
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
...@@ -1146,7 +1138,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1146,7 +1138,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx; auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1158,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1158,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
...@@ -1172,7 +1164,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1172,7 +1164,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx; auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1184,7 +1176,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1184,7 +1176,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
...@@ -1203,7 +1195,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)> ...@@ -1203,7 +1195,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major); auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
if (inner_function != nullptr) { if (inner_function != nullptr) {
return [inner_function](int row_idx) { return [inner_function] (int row_idx) {
auto raw_values = inner_function(row_idx); auto raw_values = inner_function(row_idx);
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) { for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
...@@ -1223,7 +1215,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1223,7 +1215,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1236,7 +1228,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1236,7 +1228,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1252,7 +1244,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1252,7 +1244,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1265,7 +1257,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1265,7 +1257,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1290,7 +1282,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1290,7 +1282,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1304,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1304,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1321,7 +1313,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1321,7 +1313,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1335,7 +1327,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1335,7 +1327,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
......
...@@ -230,6 +230,11 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -230,6 +230,11 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
CHECK(min_data_in_bin > 0); CHECK(min_data_in_bin > 0);
GetDouble(params, "max_conflict_rate", &max_conflict_rate); GetDouble(params, "max_conflict_rate", &max_conflict_rate);
GetBool(params, "enable_bundle", &enable_bundle); GetBool(params, "enable_bundle", &enable_bundle);
GetBool(params, "pred_early_stop", &pred_early_stop);
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetDeviceType(params); GetDeviceType(params);
} }
......
...@@ -129,6 +129,8 @@ public: ...@@ -129,6 +129,8 @@ public:
bool SkipEmptyClass() const override { return true; } bool SkipEmptyClass() const override { return true; }
bool NeedAccuratePrediction() const override { return false; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
......
...@@ -118,6 +118,8 @@ public: ...@@ -118,6 +118,8 @@ public:
int NumPredictOneRow() const override { return num_class_; } int NumPredictOneRow() const override { return num_class_; }
bool NeedAccuratePrediction() const override { return false; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
...@@ -208,6 +210,8 @@ public: ...@@ -208,6 +210,8 @@ public:
int NumPredictOneRow() const override { return num_class_; } int NumPredictOneRow() const override { return num_class_; }
bool NeedAccuratePrediction() const override { return false; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
......
...@@ -207,6 +207,8 @@ public: ...@@ -207,6 +207,8 @@ public:
return str_buf.str(); return str_buf.str();
} }
bool NeedAccuratePrediction() const override { return false; }
private: private:
/*! \brief Gains for labels */ /*! \brief Gains for labels */
std::vector<double> label_gain_; std::vector<double> label_gain_;
......
...@@ -228,8 +228,8 @@ def test_booster(): ...@@ -228,8 +228,8 @@ def test_booster():
1, 1,
1, 1,
50, 50,
ctypes.c_void_p(), c_str(''),
ctypes.byref(num_preb), ctypes.byref(num_preb),
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, ctypes.c_void_p(), c_str('preb.txt')) LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, c_str(''), c_str('preb.txt'))
LIB.LGBM_BoosterFree(booster2) LIB.LGBM_BoosterFree(booster2)
...@@ -54,8 +54,8 @@ class TestBasic(unittest.TestCase): ...@@ -54,8 +54,8 @@ class TestBasic(unittest.TestCase):
self.assertEqual(*preds) self.assertEqual(*preds)
# check early stopping is working. Make it stop very early, so the scores should be very close to zero # check early stopping is working. Make it stop very early, so the scores should be very close to zero
estop = lgb.PredictionEarlyStopInstance("binary", round_period=5, margin_threshold=1.5) pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
pred_early_stopping = bst.predict(X_test, early_stop_instance=estop) pred_early_stopping = bst.predict(X_test, pred_parameter=pred_parameter)
self.assertEqual(len(pred_from_matr), len(pred_early_stopping)) self.assertEqual(len(pred_from_matr), len(pred_early_stopping))
for preds in zip(pred_early_stopping, pred_from_matr): for preds in zip(pred_early_stopping, pred_from_matr):
# scores likely to be different, but prediction should still be the same # scores likely to be different, but prediction should still be the same
......
...@@ -108,13 +108,13 @@ class TestEngine(unittest.TestCase): ...@@ -108,13 +108,13 @@ class TestEngine(unittest.TestCase):
verbose_eval=False, verbose_eval=False,
evals_result=evals_result) evals_result=evals_result)
estop = lgb.PredictionEarlyStopInstance("multiclass", round_period=5, margin_threshold=1.5) pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
ret = multi_logloss(y_test, gbm.predict(X_test, early_stop_instance=estop)) ret = multi_logloss(y_test, gbm.predict(X_test, pred_parameter=pred_parameter))
self.assertLess(ret, 0.8) self.assertLess(ret, 0.8)
self.assertGreater(ret, 0.5) # loss will be higher than when evaluating the full model self.assertGreater(ret, 0.5) # loss will be higher than when evaluating the full model
estop = lgb.PredictionEarlyStopInstance("multiclass", round_period=5, margin_threshold=5.5) pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 5.5}
ret = multi_logloss(y_test, gbm.predict(X_test, early_stop_instance=estop)) ret = multi_logloss(y_test, gbm.predict(X_test, pred_parameter=pred_parameter))
self.assertLess(ret, 0.2) self.assertLess(ret, 0.2)
def test_early_stopping(self): def test_early_stopping(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment