Commit ac73638f authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug for csc prediction.

parent 5aa3ef4d
......@@ -171,7 +171,7 @@ public:
void Predict(int num_iteration, int predict_type, int nrow,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const char* parameter,
const IOConfig& config,
double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
......@@ -183,9 +183,7 @@ public:
} else {
is_raw_score = false;
}
auto param = ConfigBase::Str2Map(parameter);
IOConfig config;
config.Set(param);
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
......@@ -204,7 +202,7 @@ public:
}
void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const char* parameter,
int data_has_header, const IOConfig& config,
const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
......@@ -216,9 +214,6 @@ public:
} else {
is_raw_score = false;
}
auto param = ConfigBase::Str2Map(parameter);
IOConfig config;
config.Set(param);
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false;
......@@ -981,9 +976,15 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* parameter,
const char* result_filename) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameter);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
parameter, result_filename);
config.io_config, result_filename);
API_END();
}
......@@ -1014,11 +1015,17 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t* out_len,
double* out_result) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameter);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
parameter, out_result, out_len);
config.io_config, out_result, out_len);
API_END();
}
......@@ -1038,23 +1045,38 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto param = ConfigBase::Str2Map(parameter);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
int num_threads = 1;
#pragma omp parallel
#pragma omp master
{
num_threads = omp_get_num_threads();
}
int ncol = static_cast<int>(ncol_ptr - 1);
std::vector<CSC_RowIterator> iterators;
for (int j = 0; j < ncol; ++j) {
iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
for (int i = 0; i < num_threads; ++i) {
for (int j = 0; j < ncol; ++j) {
iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
}
}
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
[&iterators, ncol] (int i) {
std::vector<std::pair<int, double>> one_row;
const int tid = omp_get_thread_num();
for (int j = 0; j < ncol; ++j) {
auto val = iterators[j].Get(i);
auto val = iterators[tid][j].Get(i);
if (std::fabs(val) > kEpsilon) {
one_row.emplace_back(j, val);
}
}
return one_row;
};
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, parameter,
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config.io_config,
out_result, out_len);
API_END();
}
......@@ -1071,10 +1093,16 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
int64_t* out_len,
double* out_result) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameter);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
parameter, out_result, out_len);
config.io_config, out_result, out_len);
API_END();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment