#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "./application/predictor.hpp" #include "./boosting/gbdt.h" namespace LightGBM { class Booster { public: explicit Booster(const char* filename) { boosting_.reset(Boosting::CreateBoosting(filename)); } Booster(const Dataset* train_data, const char* parameters) { auto param = ConfigBase::Str2Map(parameters); config_.Set(param); if (config_.num_threads > 0) { omp_set_num_threads(config_.num_threads); } // create boosting if (config_.io_config.input_model.size() > 0) { Log::Warning("continued train from model is not support for c_api, \ please use continued train with input score"); } boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr)); // initialize the boosting boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); ResetTrainingData(train_data); } void MergeFrom(const Booster* other) { std::lock_guard lock(mutex_); boosting_->MergeFrom(other->boosting_.get()); } ~Booster() { } void ResetTrainingData(const Dataset* train_data) { std::lock_guard lock(mutex_); train_data_ = train_data; // create objective function objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, config_.objective_config)); if (objective_fun_ == nullptr) { Log::Warning("Using self-defined objective function"); } // initialize the objective function if (objective_fun_ != nullptr) { objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); } // create training metric train_metric_.clear(); for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr( Metric::CreateMetric(metric_type, config_.metric_config)); if (metric == nullptr) { continue; } metric->Init(train_data_->metadata(), train_data_->num_data()); train_metric_.push_back(std::move(metric)); } train_metric_.shrink_to_fit(); // reset the boosting boosting_->ResetTrainingData(&config_.boosting_config, train_data_, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); } void ResetConfig(const char* parameters) { std::lock_guard lock(mutex_); auto param = ConfigBase::Str2Map(parameters); if (param.count("num_class")) { Log::Fatal("cannot change num class during training"); } if (param.count("boosting_type")) { Log::Fatal("cannot change boosting_type during training"); } if (param.count("metric")) { Log::Fatal("cannot change metric during training"); } config_.Set(param); if (config_.num_threads > 0) { omp_set_num_threads(config_.num_threads); } if (param.count("objective")) { // create objective function objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, config_.objective_config)); if (objective_fun_ == nullptr) { Log::Warning("Using self-defined objective function"); } // initialize the objective function if (objective_fun_ != nullptr) { objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); } } boosting_->ResetTrainingData(&config_.boosting_config, train_data_, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); } void AddValidData(const Dataset* valid_data) { std::lock_guard lock(mutex_); valid_metrics_.emplace_back(); for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); if (metric == nullptr) { continue; } metric->Init(valid_data->metadata(), valid_data->num_data()); valid_metrics_.back().push_back(std::move(metric)); } valid_metrics_.back().shrink_to_fit(); boosting_->AddValidDataset(valid_data, Common::ConstPtrInVectorWrapper(valid_metrics_.back())); } bool TrainOneIter() { std::lock_guard lock(mutex_); return boosting_->TrainOneIter(nullptr, nullptr, false); } bool TrainOneIter(const float* gradients, const float* hessians) { std::lock_guard lock(mutex_); return boosting_->TrainOneIter(gradients, hessians, false); } void RollbackOneIter() { std::lock_guard lock(mutex_); boosting_->RollbackOneIter(); } Predictor NewPredictor(int num_iteration, int predict_type) { std::lock_guard lock(mutex_); boosting_->SetNumIterationForPred(num_iteration); bool is_predict_leaf = false; bool is_raw_score = false; if (predict_type == C_API_PREDICT_LEAF_INDEX) { is_predict_leaf = true; } else if (predict_type == C_API_PREDICT_RAW_SCORE) { is_raw_score = true; } else { is_raw_score = false; } // not threading safe now // boosting_->SetNumIterationForPred may be set by other thread during prediction. return Predictor(boosting_.get(), is_raw_score, is_predict_leaf); } void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { boosting_->GetPredictAt(data_idx, out_result, out_len); } void SaveModelToFile(int num_iteration, const char* filename) { boosting_->SaveModelToFile(num_iteration, filename); } std::string DumpModel(int num_iteration) { return boosting_->DumpModel(num_iteration); } double GetLeafValue(int tree_idx, int leaf_idx) const { return dynamic_cast(boosting_.get())->GetLeafValue(tree_idx, leaf_idx); } void SetLeafValue(int tree_idx, int leaf_idx, double val) { std::lock_guard lock(mutex_); dynamic_cast(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val); } int GetEvalCounts() const { int ret = 0; for (const auto& metric : train_metric_) { ret += static_cast(metric->GetName().size()); } return ret; } int GetEvalNames(char** out_strs) const { int idx = 0; for (const auto& metric : train_metric_) { for (const auto& name : metric->GetName()) { std::strcpy(out_strs[idx], name.c_str()); ++idx; } } return idx; } const Boosting* GetBoosting() const { return boosting_.get(); } private: const Dataset* train_data_; std::unique_ptr boosting_; /*! \brief All configs */ OverallConfig config_; /*! \brief Metric for training data */ std::vector> train_metric_; /*! \brief Metrics for validation data */ std::vector>> valid_metrics_; /*! \brief Training objective function */ std::unique_ptr objective_fun_; /*! \brief mutex for threading safe call */ std::mutex mutex_; }; } using namespace LightGBM; // some help functions used to convert data std::function(int row_idx)> RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major); std::function>(int row_idx)> RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major); std::function>(int idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem); // Row iterator of on column for CSC matrix class CSC_RowIterator { public: CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx); ~CSC_RowIterator() {} // return value at idx, only can access by ascent order double Get(int idx); // return next non-zero pair, if index < 0, means no more data std::pair NextNonZero(); private: int nonzero_idx_ = 0; int cur_idx_ = -1; double cur_val_ = 0.0f; bool is_end_ = false; std::function(int idx)> iter_fun_; }; // start of c_api functions LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() { return LastErrorMsg(); } LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, const char* parameters, const DatasetHandle reference, DatasetHandle* out) { API_BEGIN(); auto param = ConfigBase::Str2Map(parameters); IOConfig io_config; io_config.Set(param); DatasetLoader loader(io_config, nullptr, 1, filename); if (reference == nullptr) { *out = loader.LoadFromFile(filename); } else { *out = loader.LoadFromFileAlignWithOtherDataset(filename, reinterpret_cast(reference)); } API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, int data_type, int32_t nrow, int32_t ncol, int is_row_major, const char* parameters, const DatasetHandle reference, DatasetHandle* out) { API_BEGIN(); auto param = ConfigBase::Str2Map(parameters); IOConfig io_config; io_config.Set(param); std::unique_ptr ret; auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); if (reference == nullptr) { // sample data first Random rand(io_config.data_random_seed); const int sample_cnt = static_cast(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt); std::vector> sample_values(ncol); for (size_t i = 0; i < sample_indices.size(); ++i) { auto idx = sample_indices[i]; auto row = get_row_fun(static_cast(idx)); for (size_t j = 0; j < row.size(); ++j) { if (std::fabs(row[j]) > 1e-15) { sample_values[j].push_back(row[j]); } } } DatasetLoader loader(io_config, nullptr, 1, nullptr); ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); } else { ret.reset(new Dataset(nrow)); ret->CopyFeatureMapperFrom( reinterpret_cast(reference), io_config.is_enable_sparse); } #pragma omp parallel for schedule(guided) for (int i = 0; i < nrow; ++i) { const int tid = omp_get_thread_num(); auto one_row = get_row_fun(i); ret->PushOneRow(tid, i, one_row); } ret->FinishLoad(); *out = ret.release(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem, int64_t num_col, const char* parameters, const DatasetHandle reference, DatasetHandle* out) { API_BEGIN(); auto param = ConfigBase::Str2Map(parameters); IOConfig io_config; io_config.Set(param); std::unique_ptr ret; auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); if (reference == nullptr) { // sample data first Random rand(io_config.data_random_seed); const int sample_cnt = static_cast(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt); std::vector> sample_values; for (size_t i = 0; i < sample_indices.size(); ++i) { auto idx = sample_indices[i]; auto row = get_row_fun(static_cast(idx)); for (std::pair& inner_data : row) { if (static_cast(inner_data.first) >= sample_values.size()) { // if need expand feature set size_t need_size = inner_data.first - sample_values.size() + 1; for (size_t j = 0; j < need_size; ++j) { sample_values.emplace_back(); } } if (std::fabs(inner_data.second) > 1e-15) { // edit the feature value sample_values[inner_data.first].push_back(inner_data.second); } } } CHECK(num_col >= static_cast(sample_values.size())); DatasetLoader loader(io_config, nullptr, 1, nullptr); ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); } else { ret.reset(new Dataset(nrow)); ret->CopyFeatureMapperFrom( reinterpret_cast(reference), io_config.is_enable_sparse); } #pragma omp parallel for schedule(guided) for (int i = 0; i < nindptr - 1; ++i) { const int tid = omp_get_thread_num(); auto one_row = get_row_fun(i); ret->PushOneRow(tid, i, one_row); } ret->FinishLoad(); *out = ret.release(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int64_t num_row, const char* parameters, const DatasetHandle reference, DatasetHandle* out) { API_BEGIN(); auto param = ConfigBase::Str2Map(parameters); IOConfig io_config; io_config.Set(param); std::unique_ptr ret; int32_t nrow = static_cast(num_row); if (reference == nullptr) { // sample data first Random rand(io_config.data_random_seed); const int sample_cnt = static_cast(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt); std::vector> sample_values(ncol_ptr - 1); #pragma omp parallel for schedule(guided) for (int i = 0; i < static_cast(sample_values.size()); ++i) { CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i); for (int j = 0; j < sample_cnt; j++) { auto val = col_it.Get(sample_indices[j]); if (std::fabs(val) > kEpsilon) { sample_values[i].push_back(val); } } } DatasetLoader loader(io_config, nullptr, 1, nullptr); ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); } else { ret.reset(new Dataset(nrow)); ret->CopyFeatureMapperFrom( reinterpret_cast(reference), io_config.is_enable_sparse); } #pragma omp parallel for schedule(guided) for (int i = 0; i < ncol_ptr - 1; ++i) { const int tid = omp_get_thread_num(); int feature_idx = ret->GetInnerFeatureIndex(i); if (feature_idx < 0) { continue; } CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i); int row_idx = 0; while (row_idx < nrow) { auto pair = col_it.NextNonZero(); row_idx = pair.first; // no more data if (row_idx < 0) { break; } ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second); } } ret->FinishLoad(); *out = ret.release(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset( const DatasetHandle handle, const int32_t* used_row_indices, int32_t num_used_row_indices, const char* parameters, DatasetHandle* out) { API_BEGIN(); auto param = ConfigBase::Str2Map(parameters); IOConfig io_config; io_config.Set(param); auto full_dataset = reinterpret_cast(handle); auto ret = std::unique_ptr( full_dataset->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse)); ret->FinishLoad(); *out = ret.release(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames( DatasetHandle handle, const char** feature_names, int num_feature_names) { API_BEGIN(); auto dataset = reinterpret_cast(handle); std::vector feature_names_str; for (int i = 0; i < num_feature_names; ++i) { feature_names_str.emplace_back(feature_names[i]); } dataset->set_feature_names(feature_names_str); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames( DatasetHandle handle, char** feature_names, int* num_feature_names) { API_BEGIN(); auto dataset = reinterpret_cast(handle); auto inside_feature_name = dataset->feature_names(); *num_feature_names = static_cast(inside_feature_name.size()); for (int i = 0; i < *num_feature_names; ++i) { std::strcpy(feature_names[i], inside_feature_name[i].c_str()); } API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) { API_BEGIN(); delete reinterpret_cast(handle); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, const char* filename) { API_BEGIN(); auto dataset = reinterpret_cast(handle); dataset->SaveBinaryFile(filename); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, const char* field_name, const void* field_data, int num_element, int type) { API_BEGIN(); auto dataset = reinterpret_cast(handle); bool is_success = false; if (type == C_API_DTYPE_FLOAT32) { is_success = dataset->SetFloatField(field_name, reinterpret_cast(field_data), static_cast(num_element)); } else if (type == C_API_DTYPE_INT32) { is_success = dataset->SetIntField(field_name, reinterpret_cast(field_data), static_cast(num_element)); } else if (type == C_API_DTYPE_FLOAT64) { is_success = dataset->SetDoubleField(field_name, reinterpret_cast(field_data), static_cast(num_element)); } if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); } API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, const char* field_name, int* out_len, const void** out_ptr, int* out_type) { API_BEGIN(); auto dataset = reinterpret_cast(handle); bool is_success = false; if (dataset->GetFloatField(field_name, out_len, reinterpret_cast(out_ptr))) { *out_type = C_API_DTYPE_FLOAT32; is_success = true; } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast(out_ptr))) { *out_type = C_API_DTYPE_INT32; is_success = true; } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast(out_ptr))) { *out_type = C_API_DTYPE_FLOAT64; is_success = true; } if (!is_success) { throw std::runtime_error("Field not found"); } if (*out_ptr == nullptr) { *out_len = 0; } API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, int* out) { API_BEGIN(); auto dataset = reinterpret_cast(handle); *out = dataset->num_data(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, int* out) { API_BEGIN(); auto dataset = reinterpret_cast(handle); *out = dataset->num_total_features(); API_END(); } // ---- start of booster LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data, const char* parameters, BoosterHandle* out) { API_BEGIN(); const Dataset* p_train_data = reinterpret_cast(train_data); auto ret = std::unique_ptr(new Booster(p_train_data, parameters)); *out = ret.release(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile( const char* filename, int* out_num_iterations, BoosterHandle* out) { API_BEGIN(); auto ret = std::unique_ptr(new Booster(filename)); *out_num_iterations = ret->GetBoosting()->GetCurrentIteration(); *out = ret.release(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) { API_BEGIN(); delete reinterpret_cast(handle); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, BoosterHandle other_handle) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); Booster* ref_other_booster = reinterpret_cast(other_handle); ref_booster->MergeFrom(ref_other_booster); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, const DatasetHandle valid_data) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); const Dataset* p_dataset = reinterpret_cast(valid_data); ref_booster->AddValidData(p_dataset); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle, const DatasetHandle train_data) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); const Dataset* p_dataset = reinterpret_cast(train_data); ref_booster->ResetTrainingData(p_dataset); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); ref_booster->ResetConfig(parameters); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_len = ref_booster->GetBoosting()->NumberOfClasses(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); if (ref_booster->TrainOneIter()) { *is_finished = 1; } else { *is_finished = 0; } API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, const float* grad, const float* hess, int* is_finished) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); if (ref_booster->TrainOneIter(grad, hess)) { *is_finished = 1; } else { *is_finished = 0; } API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); ref_booster->RollbackOneIter(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_iteration = ref_booster->GetBoosting()->GetCurrentIteration(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_len = ref_booster->GetEvalCounts(); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_len = ref_booster->GetEvalNames(out_strs); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle, int data_idx, int* out_len, double* out_results) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); auto boosting = ref_booster->GetBoosting(); auto result_buf = boosting->GetEvalAt(data_idx); *out_len = static_cast(result_buf.size()); for (size_t i = 0; i < result_buf.size(); ++i) { (out_results)[i] = static_cast(result_buf[i]); } API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle, int data_idx, int64_t* out_len) { API_BEGIN(); auto boosting = reinterpret_cast(handle)->GetBoosting(); *out_len = boosting->GetNumPredictAt(data_idx); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, int data_idx, int64_t* out_len, double* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); ref_booster->GetPredictAt(data_idx, out_result, out_len); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, const char* data_filename, int data_has_header, int predict_type, int num_iteration, const char* result_filename) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); auto predictor = ref_booster->NewPredictor(static_cast(num_iteration), predict_type); bool bool_data_has_header = data_has_header > 0 ? true : false; predictor.Predict(data_filename, result_filename, bool_data_has_header); API_END(); } int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) { int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses(); if (predict_type == C_API_PREDICT_LEAF_INDEX) { int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration(); if (num_iteration > 0) { num_preb_in_one_row *= static_cast(std::min(max_iteration, num_iteration)); } else { num_preb_in_one_row *= max_iteration; } } return num_preb_in_one_row; } LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, int num_row, int predict_type, int num_iteration, int64_t* out_len) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_len = static_cast(num_row * GetNumPredOneRow(ref_booster, predict_type, num_iteration)); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem, int64_t, int predict_type, int num_iteration, int64_t* out_len, double* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); auto predictor = ref_booster->NewPredictor(static_cast(num_iteration), predict_type); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration); int nrow = static_cast(nindptr - 1); #pragma omp parallel for schedule(guided) for (int i = 0; i < nrow; ++i) { auto one_row = get_row_fun(i); auto predicton_result = predictor.GetPredictFunction()(one_row); for (int j = 0; j < static_cast(predicton_result.size()); ++j) { out_result[i * num_preb_in_one_row + j] = static_cast(predicton_result[j]); } } *out_len = nrow * num_preb_in_one_row; API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int64_t num_row, int predict_type, int num_iteration, int64_t* out_len, double* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); auto predictor = ref_booster->NewPredictor(static_cast(num_iteration), predict_type); int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration); int ncol = static_cast(ncol_ptr - 1); Threading::For(0, num_row, [&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem] (int, data_size_t start, data_size_t end) { std::vector iterators; for (int j = 0; j < ncol; ++j) { iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j); } std::vector> one_row; for (int64_t i = start; i < end; ++i) { one_row.clear(); for (int j = 0; j < ncol; ++j) { auto val = iterators[j].Get(static_cast(i)); if (std::fabs(val) > kEpsilon) { one_row.emplace_back(j, val); } } auto predicton_result = predictor.GetPredictFunction()(one_row); for (int j = 0; j < static_cast(predicton_result.size()); ++j) { out_result[i * num_preb_in_one_row + j] = static_cast(predicton_result[j]); } } }); *out_len = num_row * num_preb_in_one_row; API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, const void* data, int data_type, int32_t nrow, int32_t ncol, int is_row_major, int predict_type, int num_iteration, int64_t* out_len, double* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); auto predictor = ref_booster->NewPredictor(static_cast(num_iteration), predict_type); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration); #pragma omp parallel for schedule(guided) for (int i = 0; i < nrow; ++i) { auto one_row = get_row_fun(i); auto predicton_result = predictor.GetPredictFunction()(one_row); for (int j = 0; j < static_cast(predicton_result.size()); ++j) { out_result[i * num_preb_in_one_row + j] = static_cast(predicton_result[j]); } } *out_len = nrow * num_preb_in_one_row; API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, int num_iteration, const char* filename) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); ref_booster->SaveModelToFile(num_iteration, filename); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, int num_iteration, int buffer_len, int* out_len, char* out_str) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); std::string model = ref_booster->DumpModel(num_iteration); *out_len = static_cast(model.size()) + 1; if (*out_len <= buffer_len) { std::strcpy(out_str, model.c_str()); } API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, int tree_idx, int leaf_idx, double* out_val) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_val = static_cast(ref_booster->GetLeafValue(tree_idx, leaf_idx)); API_END(); } LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, int tree_idx, int leaf_idx, double val) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); ref_booster->SetLeafValue(tree_idx, leaf_idx, val); API_END(); } // ---- start of some help functions std::function(int row_idx)> RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { if (data_type == C_API_DTYPE_FLOAT32) { const float* data_ptr = reinterpret_cast(data); if (is_row_major) { return [data_ptr, num_col, num_row](int row_idx) { std::vector ret(num_col); auto tmp_ptr = data_ptr + num_col * row_idx; for (int i = 0; i < num_col; ++i) { ret[i] = static_cast(*(tmp_ptr + i)); } return ret; }; } else { return [data_ptr, num_col, num_row](int row_idx) { std::vector ret(num_col); for (int i = 0; i < num_col; ++i) { ret[i] = static_cast(*(data_ptr + num_row * i + row_idx)); } return ret; }; } } else if (data_type == C_API_DTYPE_FLOAT64) { const double* data_ptr = reinterpret_cast(data); if (is_row_major) { return [data_ptr, num_col, num_row](int row_idx) { std::vector ret(num_col); auto tmp_ptr = data_ptr + num_col * row_idx; for (int i = 0; i < num_col; ++i) { ret[i] = static_cast(*(tmp_ptr + i)); } return ret; }; } else { return [data_ptr, num_col, num_row](int row_idx) { std::vector ret(num_col); for (int i = 0; i < num_col; ++i) { ret[i] = static_cast(*(data_ptr + num_row * i + row_idx)); } return ret; }; } } throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric"); } std::function>(int row_idx)> RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major); if (inner_function != nullptr) { return [inner_function](int row_idx) { auto raw_values = inner_function(row_idx); std::vector> ret; for (int i = 0; i < static_cast(raw_values.size()); ++i) { if (std::fabs(raw_values[i]) > 1e-15) { ret.emplace_back(i, raw_values[i]); } } return ret; }; } return nullptr; } std::function>(int idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) { if (data_type == C_API_DTYPE_FLOAT32) { const float* data_ptr = reinterpret_cast(data); if (indptr_type == C_API_DTYPE_INT32) { const int32_t* ptr_indptr = reinterpret_cast(indptr); return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; for (int64_t i = start; i < end; ++i) { ret.emplace_back(indices[i], data_ptr[i]); } return ret; }; } else if (indptr_type == C_API_DTYPE_INT64) { const int64_t* ptr_indptr = reinterpret_cast(indptr); return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; for (int64_t i = start; i < end; ++i) { ret.emplace_back(indices[i], data_ptr[i]); } return ret; }; } } else if (data_type == C_API_DTYPE_FLOAT64) { const double* data_ptr = reinterpret_cast(data); if (indptr_type == C_API_DTYPE_INT32) { const int32_t* ptr_indptr = reinterpret_cast(indptr); return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; for (int64_t i = start; i < end; ++i) { ret.emplace_back(indices[i], data_ptr[i]); } return ret; }; } else if (indptr_type == C_API_DTYPE_INT64) { const int64_t* ptr_indptr = reinterpret_cast(indptr); return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; for (int64_t i = start; i < end; ++i) { ret.emplace_back(indices[i], data_ptr[i]); } return ret; }; } } throw std::runtime_error("unknown data type in RowFunctionFromCSR"); } std::function(int idx)> IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) { CHECK(col_idx < ncol_ptr && col_idx >= 0); if (data_type == C_API_DTYPE_FLOAT32) { const float* data_ptr = reinterpret_cast(data); if (col_ptr_type == C_API_DTYPE_INT32) { const int32_t* ptr_col_ptr = reinterpret_cast(col_ptr); int64_t start = ptr_col_ptr[col_idx]; int64_t end = ptr_col_ptr[col_idx + 1]; return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { int64_t i = static_cast(start + bias); if (i >= end) { return std::make_pair(-1, 0.0); } int idx = static_cast(indices[i]); double val = static_cast(data_ptr[i]); return std::make_pair(idx, val); }; } else if (col_ptr_type == C_API_DTYPE_INT64) { const int64_t* ptr_col_ptr = reinterpret_cast(col_ptr); int64_t start = ptr_col_ptr[col_idx]; int64_t end = ptr_col_ptr[col_idx + 1]; return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { int64_t i = static_cast(start + bias); if (i >= end) { return std::make_pair(-1, 0.0); } int idx = static_cast(indices[i]); double val = static_cast(data_ptr[i]); return std::make_pair(idx, val); }; } } else if (data_type == C_API_DTYPE_FLOAT64) { const double* data_ptr = reinterpret_cast(data); if (col_ptr_type == C_API_DTYPE_INT32) { const int32_t* ptr_col_ptr = reinterpret_cast(col_ptr); int64_t start = ptr_col_ptr[col_idx]; int64_t end = ptr_col_ptr[col_idx + 1]; return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { int64_t i = static_cast(start + bias); if (i >= end) { return std::make_pair(-1, 0.0); } int idx = static_cast(indices[i]); double val = static_cast(data_ptr[i]); return std::make_pair(idx, val); }; } else if (col_ptr_type == C_API_DTYPE_INT64) { const int64_t* ptr_col_ptr = reinterpret_cast(col_ptr); int64_t start = ptr_col_ptr[col_idx]; int64_t end = ptr_col_ptr[col_idx + 1]; return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { int64_t i = static_cast(start + bias); if (i >= end) { return std::make_pair(-1, 0.0); } int idx = static_cast(indices[i]); double val = static_cast(data_ptr[i]); return std::make_pair(idx, val); }; } } throw std::runtime_error("unknown data type in CSC matrix"); } CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) { iter_fun_ = IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx); } double CSC_RowIterator::Get(int idx) { while (idx > cur_idx_ && !is_end_) { auto ret = iter_fun_(nonzero_idx_); if (ret.first < 0) { is_end_ = true; break; } cur_idx_ = ret.first; cur_val_ = ret.second; ++nonzero_idx_; } if (idx == cur_idx_) { return cur_val_; } else { return 0.0f; } } std::pair CSC_RowIterator::NextNonZero() { if (!is_end_) { auto ret = iter_fun_(nonzero_idx_); ++nonzero_idx_; if (ret.first < 0) { is_end_ = true; } return ret; } else { return std::make_pair(-1, 0.0); } }